transformer
dreem.models.transformer
DETR Transformer class.
Copyright © Facebook, Inc. and its affiliates. All Rights Reserved.
- Modified from https://github.com/facebookresearch/detr/blob/main/models/transformer.py
- Modified from https://github.com/xingyizhou/GTR/blob/master/gtr/modeling/roi_heads/transformer.py
- Modifications:
    - positional encodings are passed into multi-head attention
    - the extra LayerNorm at the end of the encoder is removed
    - the decoder returns a stack of activations from all decoding layers
    - fixed embeddings over boxes are added
Transformer
Bases: Module
Transformer class.
Source code in dreem/models/transformer.py
class Transformer(torch.nn.Module):
"""Transformer class."""
def __init__(
self,
d_model: int = 1024,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dropout: float = 0.1,
activation: str = "relu",
return_intermediate_dec: bool = False,
norm: bool = False,
num_layers_attn_head: int = 2,
dropout_attn_head: float = 0.1,
embedding_meta: dict = None,
return_embedding: bool = False,
decoder_self_attn: bool = False,
) -> None:
"""Initialize Transformer.
Args:
d_model: The number of features in the encoder/decoder inputs.
nhead: The number of heads in the transformer encoder/decoder.
num_encoder_layers: The number of encoder-layers in the encoder.
num_decoder_layers: The number of decoder-layers in the decoder.
dropout: Dropout value applied to the output of transformer layers.
activation: Activation function to use.
return_intermediate_dec: Return intermediate layers from decoder.
norm: If True, normalize output of encoder and decoder.
num_layers_attn_head: The number of layers in the attention head.
dropout_attn_head: Dropout value for the attention_head.
embedding_meta: Metadata for positional embeddings. See below.
return_embedding: Whether to return the positional embeddings
decoder_self_attn: If True, use decoder self attention.
More details on `embedding_meta`:
By default this is None (or an empty dict), indicating
that no positional embeddings should be used. To use positional embeddings,
pass in a dictionary containing "pos" and/or "temp" keys with sub-dictionaries of embedding parameters, e.g.:
{"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
"temp": {'mode': 'learned', 'emb_num': 16}}. (See `dreem.models.embeddings.Embedding.EMB_TYPES`
and `dreem.models.embeddings.Embedding.EMB_MODES` for valid embedding parameters.)
"""
super().__init__()
self.d_model = dim_feedforward = feature_dim_attn_head = d_model
self.embedding_meta = embedding_meta
self.return_embedding = return_embedding
self.pos_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
if self.embedding_meta:
if "pos" in self.embedding_meta:
pos_emb_cfg = self.embedding_meta["pos"]
if pos_emb_cfg:
self.pos_emb = Embedding(
emb_type="pos", features=self.d_model, **pos_emb_cfg
)
if "temp" in self.embedding_meta:
temp_emb_cfg = self.embedding_meta["temp"]
if temp_emb_cfg:
self.temp_emb = Embedding(
emb_type="temp", features=self.d_model, **temp_emb_cfg
)
# Transformer Encoder
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, norm
)
encoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
)
# Transformer Decoder
decoder_layer = TransformerDecoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
activation,
norm,
decoder_self_attn,
)
decoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm
)
# Transformer attention head
self.attn_head = ATTWeightHead(
feature_dim=feature_dim_attn_head,
num_layers=num_layers_attn_head,
dropout=dropout_attn_head,
)
self._reset_parameters()
def _reset_parameters(self):
"""Initialize model weights from xavier distribution."""
for p in self.parameters():
if not torch.nn.parameter.is_lazy(p) and p.dim() > 1:
try:
nn.init.xavier_uniform_(p)
except ValueError as e:
print(f"Failed Trying to initialize {p}")
raise (e)
def forward(
self,
ref_instances: list["dreem.io.Instance"],
query_instances: list["dreem.io.Instance"] = None,
) -> list[AssociationMatrix]:
"""Execute a forward pass through the transformer and attention head.
Args:
ref_instances: A list of instance objects (See `dreem.io.Instance` for more info.)
query_instances: A set of instances to be used as decoder queries.
Returns:
asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where:
L: number of decoder blocks
n_query: number of instances in current query/frame
total_instances: number of instances in window
"""
ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
# window_length = len(frames)
# instances_per_frame = [frame.num_detected for frame in frames]
total_instances = len(ref_instances)
embed_dim = ref_features.shape[-1]
# print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
ref_boxes = get_boxes(ref_instances) # total_instances, 4
ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
ref_times, query_times = get_times(ref_instances, query_instances)
window_length = len(ref_times.unique())
ref_temp_emb = self.temp_emb(ref_times / window_length)
ref_pos_emb = self.pos_emb(ref_boxes)
if self.return_embedding:
for i, instance in enumerate(ref_instances):
instance.add_embedding("pos", ref_pos_emb[i])
instance.add_embedding("temp", ref_temp_emb[i])
ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0
ref_emb = ref_emb.view(1, total_instances, embed_dim)
ref_emb = ref_emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim)
batch_size, total_instances, embed_dim = ref_features.shape
ref_features = ref_features.permute(
1, 0, 2
) # (total_instances, batch_size, embed_dim)
encoder_queries = ref_features
encoder_features = self.encoder(
encoder_queries, pos_emb=ref_emb
) # (total_instances, batch_size, embed_dim)
n_query = total_instances
query_features = ref_features
query_pos_emb = ref_pos_emb
query_temp_emb = ref_temp_emb
query_emb = ref_emb
if query_instances is not None:
n_query = len(query_instances)
query_features = torch.cat(
[instance.features for instance in query_instances], dim=0
).unsqueeze(0)
query_features = query_features.permute(
1, 0, 2
) # (n_query, batch_size, embed_dim)
query_boxes = get_boxes(query_instances)
query_boxes = torch.nan_to_num(query_boxes, -1.0)
query_temp_emb = self.temp_emb(query_times / window_length)
query_pos_emb = self.pos_emb(query_boxes)
query_emb = (query_pos_emb + query_temp_emb) / 2.0
query_emb = query_emb.view(1, n_query, embed_dim)
query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim)
else:
query_instances = ref_instances
if self.return_embedding:
for i, instance in enumerate(query_instances):
instance.add_embedding("pos", query_pos_emb[i])
instance.add_embedding("temp", query_temp_emb[i])
decoder_features = self.decoder(
query_features,
encoder_features,
ref_pos_emb=ref_emb,
query_pos_emb=query_emb,
) # (L, n_query, batch_size, embed_dim)
decoder_features = decoder_features.transpose(
1, 2
) # (L, batch_size, n_query, embed_dim)
encoder_features = encoder_features.permute(1, 0, 2).view(
batch_size, total_instances, embed_dim
) # (batch_size, total_instances, embed_dim)
asso_output = []
for frame_features in decoder_features:
asso_matrix = self.attn_head(frame_features, encoder_features).view(
n_query, total_instances
)
asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances)
asso_output.append(asso_matrix)
# (L=1, n_query, total_instances)
return asso_output
__init__(d_model=1024, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dropout=0.1, activation='relu', return_intermediate_dec=False, norm=False, num_layers_attn_head=2, dropout_attn_head=0.1, embedding_meta=None, return_embedding=False, decoder_self_attn=False)
Initialize Transformer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | The number of features in the encoder/decoder inputs. | 1024 |
| nhead | int | The number of heads in the transformer encoder/decoder. | 8 |
| num_encoder_layers | int | The number of encoder layers in the encoder. | 6 |
| num_decoder_layers | int | The number of decoder layers in the decoder. | 6 |
| dropout | float | Dropout value applied to the output of transformer layers. | 0.1 |
| activation | str | Activation function to use. | 'relu' |
| return_intermediate_dec | bool | Return intermediate layers from decoder. | False |
| norm | bool | If True, normalize output of encoder and decoder. | False |
| num_layers_attn_head | int | The number of layers in the attention head. | 2 |
| dropout_attn_head | float | Dropout value for the attention head. | 0.1 |
| embedding_meta | dict | Metadata for positional embeddings. See below. | None |
| return_embedding | bool | Whether to return the positional embeddings. | False |
| decoder_self_attn | bool | If True, use decoder self attention. | False |
Source code in dreem/models/transformer.py
def __init__(
self,
d_model: int = 1024,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dropout: float = 0.1,
activation: str = "relu",
return_intermediate_dec: bool = False,
norm: bool = False,
num_layers_attn_head: int = 2,
dropout_attn_head: float = 0.1,
embedding_meta: dict = None,
return_embedding: bool = False,
decoder_self_attn: bool = False,
) -> None:
"""Initialize Transformer.
Args:
d_model: The number of features in the encoder/decoder inputs.
nhead: The number of heads in the transformer encoder/decoder.
num_encoder_layers: The number of encoder-layers in the encoder.
num_decoder_layers: The number of decoder-layers in the decoder.
dropout: Dropout value applied to the output of transformer layers.
activation: Activation function to use.
return_intermediate_dec: Return intermediate layers from decoder.
norm: If True, normalize output of encoder and decoder.
num_layers_attn_head: The number of layers in the attention head.
dropout_attn_head: Dropout value for the attention_head.
embedding_meta: Metadata for positional embeddings. See below.
return_embedding: Whether to return the positional embeddings
decoder_self_attn: If True, use decoder self attention.
More details on `embedding_meta`:
By default this is None (or an empty dict), indicating
that no positional embeddings should be used. To use positional embeddings,
pass in a dictionary containing "pos" and/or "temp" keys with sub-dictionaries of embedding parameters, e.g.:
{"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
"temp": {'mode': 'learned', 'emb_num': 16}}. (See `dreem.models.embeddings.Embedding.EMB_TYPES`
and `dreem.models.embeddings.Embedding.EMB_MODES` for valid embedding parameters.)
"""
super().__init__()
self.d_model = dim_feedforward = feature_dim_attn_head = d_model
self.embedding_meta = embedding_meta
self.return_embedding = return_embedding
self.pos_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
if self.embedding_meta:
if "pos" in self.embedding_meta:
pos_emb_cfg = self.embedding_meta["pos"]
if pos_emb_cfg:
self.pos_emb = Embedding(
emb_type="pos", features=self.d_model, **pos_emb_cfg
)
if "temp" in self.embedding_meta:
temp_emb_cfg = self.embedding_meta["temp"]
if temp_emb_cfg:
self.temp_emb = Embedding(
emb_type="temp", features=self.d_model, **temp_emb_cfg
)
# Transformer Encoder
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, norm
)
encoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
)
# Transformer Decoder
decoder_layer = TransformerDecoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
activation,
norm,
decoder_self_attn,
)
decoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm
)
# Transformer attention head
self.attn_head = ATTWeightHead(
feature_dim=feature_dim_attn_head,
num_layers=num_layers_attn_head,
dropout=dropout_attn_head,
)
self._reset_parameters()
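For illustration, here is a minimal, hypothetical construction of a `Transformer` with learned positional and temporal embeddings. The `embedding_meta` keys (`mode`, `emb_num`, `over_boxes`) follow the docstring example above; the specific values are assumptions, not recommended settings.

```python
import torch
from dreem.models.transformer import Transformer

# embedding_meta layout taken from the docstring example above; values are arbitrary.
embedding_meta = {
    "pos": {"mode": "learned", "emb_num": 16, "over_boxes": True},
    "temp": {"mode": "learned", "emb_num": 16},
}

model = Transformer(
    d_model=512,              # feature dim; also reused for the feedforward and attention-head dims
    nhead=8,
    num_encoder_layers=1,
    num_decoder_layers=1,
    norm=True,
    embedding_meta=embedding_meta,
)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count
```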
forward(ref_instances, query_instances=None)
Execute a forward pass through the transformer and attention head.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ref_instances | list[Instance] | A list of instance objects (see `dreem.io.Instance` for more info). | required |
| query_instances | list[Instance] | A set of instances to be used as decoder queries. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| asso_output | list[AssociationMatrix] | A list of torch.Tensors of shape (L, n_query, total_instances), where L is the number of decoder blocks, n_query is the number of instances in the current query/frame, and total_instances is the number of instances in the window. |
Source code in dreem/models/transformer.py
def forward(
self,
ref_instances: list["dreem.io.Instance"],
query_instances: list["dreem.io.Instance"] = None,
) -> list[AssociationMatrix]:
"""Execute a forward pass through the transformer and attention head.
Args:
ref_instances: A list of instance objects (See `dreem.io.Instance` for more info.)
query_instances: A set of instances to be used as decoder queries.
Returns:
asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where:
L: number of decoder blocks
n_query: number of instances in current query/frame
total_instances: number of instances in window
"""
ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
# window_length = len(frames)
# instances_per_frame = [frame.num_detected for frame in frames]
total_instances = len(ref_instances)
embed_dim = ref_features.shape[-1]
# print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
ref_boxes = get_boxes(ref_instances) # total_instances, 4
ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
ref_times, query_times = get_times(ref_instances, query_instances)
window_length = len(ref_times.unique())
ref_temp_emb = self.temp_emb(ref_times / window_length)
ref_pos_emb = self.pos_emb(ref_boxes)
if self.return_embedding:
for i, instance in enumerate(ref_instances):
instance.add_embedding("pos", ref_pos_emb[i])
instance.add_embedding("temp", ref_temp_emb[i])
ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0
ref_emb = ref_emb.view(1, total_instances, embed_dim)
ref_emb = ref_emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim)
batch_size, total_instances, embed_dim = ref_features.shape
ref_features = ref_features.permute(
1, 0, 2
) # (total_instances, batch_size, embed_dim)
encoder_queries = ref_features
encoder_features = self.encoder(
encoder_queries, pos_emb=ref_emb
) # (total_instances, batch_size, embed_dim)
n_query = total_instances
query_features = ref_features
query_pos_emb = ref_pos_emb
query_temp_emb = ref_temp_emb
query_emb = ref_emb
if query_instances is not None:
n_query = len(query_instances)
query_features = torch.cat(
[instance.features for instance in query_instances], dim=0
).unsqueeze(0)
query_features = query_features.permute(
1, 0, 2
) # (n_query, batch_size, embed_dim)
query_boxes = get_boxes(query_instances)
query_boxes = torch.nan_to_num(query_boxes, -1.0)
query_temp_emb = self.temp_emb(query_times / window_length)
query_pos_emb = self.pos_emb(query_boxes)
query_emb = (query_pos_emb + query_temp_emb) / 2.0
query_emb = query_emb.view(1, n_query, embed_dim)
query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim)
else:
query_instances = ref_instances
if self.return_embedding:
for i, instance in enumerate(query_instances):
instance.add_embedding("pos", query_pos_emb[i])
instance.add_embedding("temp", query_temp_emb[i])
decoder_features = self.decoder(
query_features,
encoder_features,
ref_pos_emb=ref_emb,
query_pos_emb=query_emb,
) # (L, n_query, batch_size, embed_dim)
decoder_features = decoder_features.transpose(
1, 2
) # (L, batch_size, n_query, embed_dim)
encoder_features = encoder_features.permute(1, 0, 2).view(
batch_size, total_instances, embed_dim
) # (batch_size, total_instances, embed_dim)
asso_output = []
for frame_features in decoder_features:
asso_matrix = self.attn_head(frame_features, encoder_features).view(
n_query, total_instances
)
asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances)
asso_output.append(asso_matrix)
# (L=1, n_query, total_instances)
return asso_output
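To make the shape handling above concrete, the sketch below reproduces the embedding averaging and permutation with dummy tensors standing in for `Instance` features (plain torch only; nothing here calls the dreem API).

```python
import torch

total_instances, embed_dim = 5, 1024
ref_features = torch.randn(1, total_instances, embed_dim)  # (batch_size, total_instances, embed_dim)
ref_pos_emb = torch.randn(total_instances, embed_dim)      # stand-in for self.pos_emb(ref_boxes)
ref_temp_emb = torch.randn(total_instances, embed_dim)     # stand-in for self.temp_emb(ref_times / window_length)

# Average the positional and temporal embeddings, then move everything to the
# (total_instances, batch_size, embed_dim) layout expected by the encoder/decoder.
ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0
ref_emb = ref_emb.view(1, total_instances, embed_dim).permute(1, 0, 2)
ref_features = ref_features.permute(1, 0, 2)
print(ref_emb.shape, ref_features.shape)  # torch.Size([5, 1, 1024]) for both
```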
TransformerDecoder
Bases: Module
Transformer Decoder Block composed of Transformer Decoder Layers.
Source code in dreem/models/transformer.py
class TransformerDecoder(nn.Module):
"""Transformer Decoder Block composed of Transformer Decoder Layers."""
def __init__(
self,
decoder_layer: TransformerDecoderLayer,
num_layers: int,
return_intermediate: bool = False,
norm: nn.Module = None,
) -> None:
"""Initialize transformer decoder block.
Args:
decoder_layer: An instance of TransformerDecoderLayer.
num_layers: The number of decoder layers to be stacked.
return_intermediate: Return intermediate layers from decoder.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.norm = norm if norm is not None else nn.Identity()
def forward(
self,
decoder_queries: torch.Tensor,
encoder_features: torch.Tensor,
ref_pos_emb: torch.Tensor = None,
query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
"""Execute a forward pass of the decoder block.
Args:
decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim).
encoder_features: Output from encoder, that decoder uses to attend to relevant
parts of input sequence (total_instances, batch_size, embed_dim)
ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim)
Returns:
The output tensor of shape (L, n_query, batch_size, embed_dim).
"""
decoder_features = decoder_queries
intermediate = []
for layer in self.layers:
decoder_features = layer(
decoder_features,
encoder_features,
ref_pos_emb=ref_pos_emb,
query_pos_emb=query_pos_emb,
)
if self.return_intermediate:
intermediate.append(self.norm(decoder_features))
decoder_features = self.norm(decoder_features)
if self.return_intermediate:
intermediate.pop()
intermediate.append(decoder_features)
return torch.stack(intermediate)
return decoder_features.unsqueeze(0)
__init__(decoder_layer, num_layers, return_intermediate=False, norm=None)
Initialize transformer decoder block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| decoder_layer | TransformerDecoderLayer | An instance of TransformerDecoderLayer. | required |
| num_layers | int | The number of decoder layers to be stacked. | required |
| return_intermediate | bool | Return intermediate layers from decoder. | False |
| norm | Module | The normalization layer to be applied. | None |
Source code in dreem/models/transformer.py
def __init__(
self,
decoder_layer: TransformerDecoderLayer,
num_layers: int,
return_intermediate: bool = False,
norm: nn.Module = None,
) -> None:
"""Initialize transformer decoder block.
Args:
decoder_layer: An instance of TransformerDecoderLayer.
num_layers: The number of decoder layers to be stacked.
return_intermediate: Return intermediate layers from decoder.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.norm = norm if norm is not None else nn.Identity()
forward(decoder_queries, encoder_features, ref_pos_emb=None, query_pos_emb=None)
Execute a forward pass of the decoder block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| decoder_queries | Tensor | Query sequence for the decoder to generate (n_query, batch_size, embed_dim). | required |
| encoder_features | Tensor | Output from the encoder that the decoder attends to (total_instances, batch_size, embed_dim). | required |
| ref_pos_emb | Tensor | The input positional embedding tensor of shape (total_instances, batch_size, embed_dim). | None |
| query_pos_emb | Tensor | The query positional embedding of shape (n_query, batch_size, embed_dim). | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (L, n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
self,
decoder_queries: torch.Tensor,
encoder_features: torch.Tensor,
ref_pos_emb: torch.Tensor = None,
query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
"""Execute a forward pass of the decoder block.
Args:
decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim).
encoder_features: Output from encoder, that decoder uses to attend to relevant
parts of input sequence (total_instances, batch_size, embed_dim)
ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim)
Returns:
The output tensor of shape (L, n_query, batch_size, embed_dim).
"""
decoder_features = decoder_queries
intermediate = []
for layer in self.layers:
decoder_features = layer(
decoder_features,
encoder_features,
ref_pos_emb=ref_pos_emb,
query_pos_emb=query_pos_emb,
)
if self.return_intermediate:
intermediate.append(self.norm(decoder_features))
decoder_features = self.norm(decoder_features)
if self.return_intermediate:
intermediate.pop()
intermediate.append(decoder_features)
return torch.stack(intermediate)
return decoder_features.unsqueeze(0)
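A minimal sketch of using the decoder block on its own, based on the constructor and forward signatures above (tensor sizes are arbitrary). With `return_intermediate=True`, one activation per layer is stacked along a leading dimension.

```python
import torch
from dreem.models.transformer import TransformerDecoder, TransformerDecoderLayer

layer = TransformerDecoderLayer(d_model=256, nhead=8, dim_feedforward=256, norm=True)
decoder = TransformerDecoder(
    layer, num_layers=6, return_intermediate=True, norm=torch.nn.LayerNorm(256)
)

decoder_queries = torch.randn(4, 1, 256)    # (n_query, batch_size, embed_dim)
encoder_features = torch.randn(10, 1, 256)  # (total_instances, batch_size, embed_dim)

out = decoder(decoder_queries, encoder_features)
print(out.shape)  # torch.Size([6, 4, 1, 256]) -> (L, n_query, batch_size, embed_dim)
```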
TransformerDecoderLayer
Bases: Module
A single transformer decoder layer.
Source code in dreem/models/transformer.py
class TransformerDecoderLayer(nn.Module):
"""A single transformer decoder layer."""
def __init__(
self,
d_model: int = 1024,
nhead: int = 6,
dim_feedforward: int = 1024,
dropout: float = 0.1,
activation: str = "relu",
norm: bool = False,
decoder_self_attn: bool = False,
) -> None:
"""Initialize transformer decoder layer.
Args:
d_model: The number of features in the decoder inputs.
nhead: The number of heads for the decoder.
dim_feedforward: Dimension of the feedforward layers of decoder.
dropout: Dropout value applied to the output of decoder.
activation: Activation function to use.
norm: If True, normalize output of decoder.
decoder_self_attn: If True, use decoder self attention
"""
super().__init__()
self.decoder_self_attn = decoder_self_attn
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
if self.decoder_self_attn:
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm3 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def forward(
self,
decoder_queries: torch.Tensor,
encoder_features: torch.Tensor,
ref_pos_emb: torch.Tensor = None,
query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
"""Execute forward pass of decoder layer.
Args:
decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim).
encoder_features: Output from encoder, that decoder uses to attend to relevant
parts of input sequence (total_instances, batch_size, embed_dim)
ref_pos_emb: The input positional embedding tensor of shape (n_query, embed_dim).
query_pos_emb: The target positional embedding of shape (n_query, embed_dim)
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
if query_pos_emb is None:
query_pos_emb = torch.zeros_like(decoder_queries)
if ref_pos_emb is None:
ref_pos_emb = torch.zeros_like(encoder_features)
decoder_queries = decoder_queries + query_pos_emb
encoder_features = encoder_features + ref_pos_emb
if self.decoder_self_attn:
self_attn_features = self.self_attn(
query=decoder_queries, key=decoder_queries, value=decoder_queries
)[0]
decoder_queries = decoder_queries + self.dropout1(self_attn_features)
decoder_queries = self.norm1(decoder_queries)
x_attn_features = self.multihead_attn(
query=decoder_queries, # (n_query, batch_size, embed_dim)
key=encoder_features, # (total_instances, batch_size, embed_dim)
value=encoder_features, # (total_instances, batch_size, embed_dim)
)[
0
] # (n_query, batch_size, embed_dim)
decoder_queries = decoder_queries + self.dropout2(
x_attn_features
) # (n_query, batch_size, embed_dim)
decoder_queries = self.norm2(
decoder_queries
) # (n_query, batch_size, embed_dim)
projection = self.linear2(
self.dropout(self.activation(self.linear1(decoder_queries)))
) # (n_query, batch_size, embed_dim)
decoder_queries = decoder_queries + self.dropout3(
projection
) # (n_query, batch_size, embed_dim)
decoder_features = self.norm3(decoder_queries)
return decoder_features # (n_query, batch_size, embed_dim)
__init__(d_model=1024, nhead=6, dim_feedforward=1024, dropout=0.1, activation='relu', norm=False, decoder_self_attn=False)
Initialize transformer decoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | The number of features in the decoder inputs. | 1024 |
| nhead | int | The number of heads for the decoder. | 6 |
| dim_feedforward | int | Dimension of the feedforward layers of the decoder. | 1024 |
| dropout | float | Dropout value applied to the output of the decoder. | 0.1 |
| activation | str | Activation function to use. | 'relu' |
| norm | bool | If True, normalize output of decoder. | False |
| decoder_self_attn | bool | If True, use decoder self attention. | False |
Source code in dreem/models/transformer.py
def __init__(
self,
d_model: int = 1024,
nhead: int = 6,
dim_feedforward: int = 1024,
dropout: float = 0.1,
activation: str = "relu",
norm: bool = False,
decoder_self_attn: bool = False,
) -> None:
"""Initialize transformer decoder layer.
Args:
d_model: The number of features in the decoder inputs.
nhead: The number of heads for the decoder.
dim_feedforward: Dimension of the feedforward layers of decoder.
dropout: Dropout value applied to the output of decoder.
activation: Activation function to use.
norm: If True, normalize output of decoder.
decoder_self_attn: If True, use decoder self attention
"""
super().__init__()
self.decoder_self_attn = decoder_self_attn
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
if self.decoder_self_attn:
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm3 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
forward(decoder_queries, encoder_features, ref_pos_emb=None, query_pos_emb=None)
Execute forward pass of decoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| decoder_queries | Tensor | Target sequence for the decoder to generate (n_query, batch_size, embed_dim). | required |
| encoder_features | Tensor | Output from the encoder that the decoder attends to (total_instances, batch_size, embed_dim). | required |
| ref_pos_emb | Tensor | The input positional embedding, added to encoder_features; shape (total_instances, batch_size, embed_dim). | None |
| query_pos_emb | Tensor | The target positional embedding, added to decoder_queries; shape (n_query, batch_size, embed_dim). | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
self,
decoder_queries: torch.Tensor,
encoder_features: torch.Tensor,
ref_pos_emb: torch.Tensor = None,
query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
"""Execute forward pass of decoder layer.
Args:
decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim).
encoder_features: Output from encoder, that decoder uses to attend to relevant
parts of input sequence (total_instances, batch_size, embed_dim)
ref_pos_emb: The input positional embedding tensor of shape (n_query, embed_dim).
query_pos_emb: The target positional embedding of shape (n_query, embed_dim)
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
if query_pos_emb is None:
query_pos_emb = torch.zeros_like(decoder_queries)
if ref_pos_emb is None:
ref_pos_emb = torch.zeros_like(encoder_features)
decoder_queries = decoder_queries + query_pos_emb
encoder_features = encoder_features + ref_pos_emb
if self.decoder_self_attn:
self_attn_features = self.self_attn(
query=decoder_queries, key=decoder_queries, value=decoder_queries
)[0]
decoder_queries = decoder_queries + self.dropout1(self_attn_features)
decoder_queries = self.norm1(decoder_queries)
x_attn_features = self.multihead_attn(
query=decoder_queries, # (n_query, batch_size, embed_dim)
key=encoder_features, # (total_instances, batch_size, embed_dim)
value=encoder_features, # (total_instances, batch_size, embed_dim)
)[
0
] # (n_query, batch_size, embed_dim)
decoder_queries = decoder_queries + self.dropout2(
x_attn_features
) # (n_query, batch_size, embed_dim)
decoder_queries = self.norm2(
decoder_queries
) # (n_query, batch_size, embed_dim)
projection = self.linear2(
self.dropout(self.activation(self.linear1(decoder_queries)))
) # (n_query, batch_size, embed_dim)
decoder_queries = decoder_queries + self.dropout3(
projection
) # (n_query, batch_size, embed_dim)
decoder_features = self.norm3(decoder_queries)
return decoder_features # (n_query, batch_size, embed_dim)
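The positional embeddings are consumed inside the layer: they are added to the decoder queries and to the encoder features before attention rather than returned. A small sketch with arbitrary sizes, including the optional self-attention branch:

```python
import torch
from dreem.models.transformer import TransformerDecoderLayer

layer = TransformerDecoderLayer(
    d_model=256, nhead=8, dim_feedforward=256, norm=True, decoder_self_attn=True
)

decoder_queries = torch.randn(4, 1, 256)    # (n_query, batch_size, embed_dim)
encoder_features = torch.randn(10, 1, 256)  # (total_instances, batch_size, embed_dim)
query_pos_emb = torch.randn(4, 1, 256)      # added to decoder_queries
ref_pos_emb = torch.randn(10, 1, 256)       # added to encoder_features

out = layer(
    decoder_queries, encoder_features, ref_pos_emb=ref_pos_emb, query_pos_emb=query_pos_emb
)
print(out.shape)  # torch.Size([4, 1, 256])
```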
TransformerEncoder
Bases: Module
A transformer encoder block composed of encoder layers.
Source code in dreem/models/transformer.py
class TransformerEncoder(nn.Module):
"""A transformer encoder block composed of encoder layers."""
def __init__(
self,
encoder_layer: TransformerEncoderLayer,
num_layers: int,
norm: nn.Module = None,
) -> None:
"""Initialize transformer encoder.
Args:
encoder_layer: An instance of the TransformerEncoderLayer.
num_layers: The number of encoder layers to be stacked.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm if norm is not None else nn.Identity()
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of encoder layer.
Args:
queries: The input tensor of shape (n_query, batch_size, embed_dim).
pos_emb: The positional embedding tensor of shape (n_query, embed_dim).
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
for layer in self.layers:
queries = layer(queries, pos_emb=pos_emb)
encoder_features = self.norm(queries)
return encoder_features
__init__(encoder_layer, num_layers, norm=None)
Initialize transformer encoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| encoder_layer | TransformerEncoderLayer | An instance of the TransformerEncoderLayer. | required |
| num_layers | int | The number of encoder layers to be stacked. | required |
| norm | Module | The normalization layer to be applied. | None |
Source code in dreem/models/transformer.py
def __init__(
self,
encoder_layer: TransformerEncoderLayer,
num_layers: int,
norm: nn.Module = None,
) -> None:
"""Initialize transformer encoder.
Args:
encoder_layer: An instance of the TransformerEncoderLayer.
num_layers: The number of encoder layers to be stacked.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm if norm is not None else nn.Identity()
forward(queries, pos_emb=None)
Execute a forward pass of encoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| queries | Tensor | The input tensor of shape (n_query, batch_size, embed_dim). | required |
| pos_emb | Tensor | The positional embedding tensor of shape (n_query, batch_size, embed_dim), added to the queries in each layer. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of encoder layer.
Args:
queries: The input tensor of shape (n_query, batch_size, embed_dim).
pos_emb: The positional embedding tensor of shape (n_query, embed_dim).
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
for layer in self.layers:
queries = layer(queries, pos_emb=pos_emb)
encoder_features = self.norm(queries)
return encoder_features
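A minimal sketch of stacking encoder layers with this block, using the constructor and forward signatures above (sizes are arbitrary):

```python
import torch
from dreem.models.transformer import TransformerEncoder, TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=256, norm=True)
encoder = TransformerEncoder(layer, num_layers=6, norm=torch.nn.LayerNorm(256))

queries = torch.randn(10, 1, 256)  # (n_query, batch_size, embed_dim)
pos_emb = torch.randn(10, 1, 256)  # added to the queries inside each layer
out = encoder(queries, pos_emb=pos_emb)
print(out.shape)  # torch.Size([10, 1, 256])
```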
TransformerEncoderLayer
Bases: Module
A single transformer encoder layer.
Source code in dreem/models/transformer.py
class TransformerEncoderLayer(nn.Module):
"""A single transformer encoder layer."""
def __init__(
self,
d_model: int = 1024,
nhead: int = 6,
dim_feedforward: int = 1024,
dropout: float = 0.1,
activation: str = "relu",
norm: bool = False,
) -> None:
"""Initialize a transformer encoder layer.
Args:
d_model: The number of features in the encoder inputs.
nhead: The number of heads for the encoder.
dim_feedforward: Dimension of the feedforward layers of encoder.
dropout: Dropout value applied to the output of encoder.
activation: Activation function to use.
norm: If True, normalize output of encoder.
"""
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of the encoder layer.
Args:
queries: Input sequence for encoder (n_query, batch_size, embed_dim).
pos_emb: Position embedding, if provided is added to src
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
if pos_emb is None:
pos_emb = torch.zeros_like(queries)
queries = queries + pos_emb
# q = k = src
attn_features = self.self_attn(
query=queries,
key=queries,
value=queries,
)[0]
queries = queries + self.dropout1(attn_features)
queries = self.norm1(queries)
projection = self.linear2(self.dropout(self.activation(self.linear1(queries))))
queries = queries + self.dropout2(projection)
encoder_features = self.norm2(queries)
return encoder_features
__init__(d_model=1024, nhead=6, dim_feedforward=1024, dropout=0.1, activation='relu', norm=False)
Initialize a transformer encoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | The number of features in the encoder inputs. | 1024 |
| nhead | int | The number of heads for the encoder. | 6 |
| dim_feedforward | int | Dimension of the feedforward layers of the encoder. | 1024 |
| dropout | float | Dropout value applied to the output of the encoder. | 0.1 |
| activation | str | Activation function to use. | 'relu' |
| norm | bool | If True, normalize output of encoder. | False |
Source code in dreem/models/transformer.py
def __init__(
self,
d_model: int = 1024,
nhead: int = 6,
dim_feedforward: int = 1024,
dropout: float = 0.1,
activation: str = "relu",
norm: bool = False,
) -> None:
"""Initialize a transformer encoder layer.
Args:
d_model: The number of features in the encoder inputs.
nhead: The number of heads for the encoder.
dim_feedforward: Dimension of the feedforward layers of encoder.
dropout: Dropout value applied to the output of encoder.
activation: Activation function to use.
norm: If True, normalize output of encoder.
"""
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
forward(queries, pos_emb=None)
Execute a forward pass of the encoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| queries | Tensor | Input sequence for the encoder (n_query, batch_size, embed_dim). | required |
| pos_emb | Tensor | Position embedding; if provided, it is added to the queries before self-attention. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of the encoder layer.
Args:
queries: Input sequence for encoder (n_query, batch_size, embed_dim).
pos_emb: Position embedding, if provided is added to src
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
if pos_emb is None:
pos_emb = torch.zeros_like(queries)
queries = queries + pos_emb
# q = k = src
attn_features = self.self_attn(
query=queries,
key=queries,
value=queries,
)[0]
queries = queries + self.dropout1(attn_features)
queries = self.norm1(queries)
projection = self.linear2(self.dropout(self.activation(self.linear1(queries))))
queries = queries + self.dropout2(projection)
encoder_features = self.norm2(queries)
return encoder_features
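As with the decoder layer, the position embedding is additive: if `pos_emb` is omitted it defaults to zeros, so the layer reduces to plain self-attention plus a feedforward block. A short sketch with arbitrary sizes:

```python
import torch
from dreem.models.transformer import TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=256, norm=True)

queries = torch.randn(10, 1, 256)                                  # (n_query, batch_size, embed_dim)
out_no_pos = layer(queries)                                        # pos_emb defaults to zeros
out_with_pos = layer(queries, pos_emb=torch.randn_like(queries))   # pos_emb added to queries
print(out_no_pos.shape, out_with_pos.shape)                        # both torch.Size([10, 1, 256])
```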