attention_head

dreem.models.attention_head

Module containing different components of multi-head attention heads.

ATTWeightHead

Bases: Module

Single attention head.

Source code in dreem/models/attention_head.py
class ATTWeightHead(torch.nn.Module):
    """Single attention head."""

    def __init__(
        self,
        feature_dim: int,
        num_layers: int,
        dropout: float,
    ):
        """Initialize an instance of ATTWeightHead.

        Args:
            feature_dim: The dimensionality of input features.
            num_layers: The number of hidden layers in the MLP.
            dropout: Dropout probability.
        """
        super().__init__()

        self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
        self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the attention weights of a query tensor using the key tensor.

        Args:
            query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
            key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

        Returns:
            Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
        """
        k = self.k_proj(key)
        q = self.q_proj(query)
        attn_weights = torch.bmm(q, k.transpose(1, 2))

        return attn_weights  # (B, N_t, N)
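
A minimal usage sketch (not taken from the source page): it assumes `ATTWeightHead` is imported from `dreem.models.attention_head` as documented above, and the batch size, instance counts, and constructor values below are arbitrary placeholders.

import torch

from dreem.models.attention_head import ATTWeightHead

# Arbitrary sizes and hyperparameters, for illustration only.
batch_size, n_frame, n_window, feature_dim = 1, 20, 100, 512

head = ATTWeightHead(feature_dim=feature_dim, num_layers=3, dropout=0.1)

query = torch.randn(batch_size, n_frame, feature_dim)   # instances in the current frame
key = torch.randn(batch_size, n_window, feature_dim)    # instances across the window

attn_weights = head(query, key)
print(attn_weights.shape)  # torch.Size([1, 20, 100]) -> (B, N_t, N)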

__init__(feature_dim, num_layers, dropout)

Initialize an instance of ATTWeightHead.

Parameters:

Name         Type   Description                               Default
-----------  -----  ----------------------------------------  --------
feature_dim  int    The dimensionality of input features.     required
num_layers   int    The number of hidden layers in the MLP.   required
dropout      float  Dropout probability.                      required
Source code in dreem/models/attention_head.py
def __init__(
    self,
    feature_dim: int,
    num_layers: int,
    dropout: float,
):
    """Initialize an instance of ATTWeightHead.

    Args:
        feature_dim: The dimensionality of input features.
        num_layers: The number of hidden layers in the MLP.
        dropout: Dropout probability.
    """
    super().__init__()

    self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
    self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
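
The MLP used for the query and key projections is defined elsewhere in dreem.models and is not shown on this page. The sketch below illustrates one plausible implementation, assuming the positional arguments are (input_dim, hidden_dim, output_dim, num_layers, dropout) as the calls above suggest; it is an assumption for illustration, not the package's actual MLP.

import torch
import torch.nn.functional as F

class SketchMLP(torch.nn.Module):
    """Illustrative stand-in for dreem's MLP; the real implementation may differ."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_layers: int, dropout: float):
        super().__init__()
        # Chain of linear layers: input -> hidden (num_layers - 1 times) -> output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:  # ReLU + dropout on hidden layers only
                x = self.dropout(F.relu(x))
        return x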

forward(query, key)

Compute the attention weights of a query tensor using the key tensor.

Parameters:

Name   Type    Description                                                              Default
-----  ------  -----------------------------------------------------------------------  --------
query  Tensor  Input tensor of shape (batch_size, num_frame_instances, feature_dim).    required
key    Tensor  Input tensor of shape (batch_size, num_window_instances, feature_dim).   required

Returns:

Type    Description
------  --------------------------------------------------------------------------------
Tensor  Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
Source code in dreem/models/attention_head.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
) -> torch.Tensor:
    """Compute the attention weights of a query tensor using the key tensor.

    Args:
        query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
        key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

    Returns:
        Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
    """
    k = self.k_proj(key)
    q = self.q_proj(query)
    attn_weights = torch.bmm(q, k.transpose(1, 2))

    return attn_weights  # (B, N_t, N)
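
As a worked check of the shape contract: torch.bmm multiplies the projected query of shape (B, N_t, D) with the transposed key of shape (B, D, N), giving the (B, N_t, N) weight matrix. The head returns these raw scores without normalization; the short sketch below (using random stand-ins for the projected tensors) demonstrates the shapes and an optional softmax over the window dimension.

import torch

B, N_t, N, D = 2, 3, 7, 16            # batch, frame instances, window instances, feature dim

q = torch.randn(B, N_t, D)            # stands in for self.q_proj(query)
k = torch.randn(B, N, D)              # stands in for self.k_proj(key)

attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.shape == (B, N_t, N)

# Optional: normalize over the window dimension to obtain per-query probabilities.
probs = torch.softmax(attn_weights, dim=-1)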