attention_head
dreem.models.attention_head
Module containing different components of multi-head attention heads.
ATTWeightHead
Bases: Module
Single attention head.
Source code in dreem/models/attention_head.py
```python
class ATTWeightHead(torch.nn.Module):
    """Single attention head."""

    def __init__(
        self,
        feature_dim: int,
        num_layers: int,
        dropout: float,
    ):
        """Initialize an instance of ATTWeightHead.

        Args:
            feature_dim: The dimensionality of input features.
            num_layers: The number of hidden layers in the MLP.
            dropout: Dropout probability.
        """
        super().__init__()

        self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
        self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the attention weights of a query tensor using the key tensor.

        Args:
            query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
            key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

        Returns:
            Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
        """
        k = self.k_proj(key)
        q = self.q_proj(query)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        return attn_weights  # (B, N_t, N)
```
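As a quick orientation, the hedged sketch below shows typical usage of this head: project per-instance query and key features, then take a batched dot product to get one association score per (frame instance, window instance) pair. The class path, constructor arguments, and tensor shapes follow the documentation above; the sizes and values are made up for illustration.

```python
import torch

from dreem.models.attention_head import ATTWeightHead

# Hypothetical sizes for illustration: 2 clips, 5 instances in the current
# frame, 12 instances across the window, 64-dimensional features.
batch_size, num_frame_instances, num_window_instances, feature_dim = 2, 5, 12, 64

head = ATTWeightHead(feature_dim=feature_dim, num_layers=2, dropout=0.1)

query = torch.randn(batch_size, num_frame_instances, feature_dim)
key = torch.randn(batch_size, num_window_instances, feature_dim)

attn_weights = head(query, key)
print(attn_weights.shape)  # torch.Size([2, 5, 12])
```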
__init__(feature_dim, num_layers, dropout)
Initialize an instance of ATTWeightHead.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `feature_dim` | `int` | The dimensionality of input features. | *required* |
| `num_layers` | `int` | The number of hidden layers in the MLP. | *required* |
| `dropout` | `float` | Dropout probability. | *required* |
Source code in dreem/models/attention_head.py
```python
def __init__(
    self,
    feature_dim: int,
    num_layers: int,
    dropout: float,
):
    """Initialize an instance of ATTWeightHead.

    Args:
        feature_dim: The dimensionality of input features.
        num_layers: The number of hidden layers in the MLP.
        dropout: Dropout probability.
    """
    super().__init__()

    self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
    self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
```
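The two projection heads are built from dreem's `MLP` helper, whose implementation is not shown on this page. Purely as a hedged sketch of what a module with the call signature `MLP(input_dim, hidden_dim, output_dim, num_layers, dropout)` might look like (a plain Linear/ReLU/Dropout stack; the real implementation may differ):

```python
import torch

# Hedged sketch only: this is NOT dreem's MLP, just an assumption about a
# module matching the call MLP(input_dim, hidden_dim, output_dim, num_layers, dropout).
class SketchMLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(torch.nn.Linear(d_in, d_out))
            if i < len(dims) - 2:  # no activation/dropout after the final layer
                layers.append(torch.nn.ReLU())
                layers.append(torch.nn.Dropout(dropout))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
```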
forward(query, key)
Compute the attention weights of a query tensor using the key tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `query` | `Tensor` | Input tensor of shape (batch_size, num_frame_instances, feature_dim). | *required* |
| `key` | `Tensor` | Input tensor of shape (batch_size, num_window_instances, feature_dim). | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Output tensor of shape (batch_size, num_frame_instances, num_window_instances). |
Source code in dreem/models/attention_head.py
```python
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
) -> torch.Tensor:
    """Compute the attention weights of a query tensor using the key tensor.

    Args:
        query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
        key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

    Returns:
        Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
    """
    k = self.k_proj(key)
    q = self.q_proj(query)

    attn_weights = torch.bmm(q, k.transpose(1, 2))

    return attn_weights  # (B, N_t, N)
```
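Note that `forward` returns the raw scores from the batched matrix multiply; there is no softmax inside the head. The small check below, using made-up shapes, confirms that the `torch.bmm` step computes one dot product per (frame instance, window instance) pair, matching the documented output shape.

```python
import torch

# Made-up shapes: (B, N_t, D) query against (B, N, D) key.
B, N_t, N, D = 2, 3, 7, 16
q = torch.randn(B, N_t, D)
k = torch.randn(B, N, D)

# Same operation as ATTWeightHead.forward applies after its MLP projections.
attn_weights = torch.bmm(q, k.transpose(1, 2))  # (B, N_t, N)

# Each entry is the dot product of one projected query with one projected key.
assert attn_weights.shape == (B, N_t, N)
assert torch.allclose(attn_weights[0, 1, 2], q[0, 1] @ k[0, 2], atol=1e-5)
```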