Model Parts
dreem.models.VisualEncoder
Bases: Module
Class wrapping around a visual feature extractor backbone.
Currently CNN only.
Source code in dreem/models/visual_encoder.py
class VisualEncoder(torch.nn.Module):
"""Class wrapping around a visual feature extractor backbone.
Currently CNN only.
"""
def __init__(
self,
model_name: str = "resnet18",
d_model: int = 512,
in_chans: int = 3,
backend: str = "timm",
**kwargs: Optional[Any],
):
"""Initialize Visual Encoder.
Args:
model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
d_model (int): Output embedding dimension.
in_chans: the number of input channels of the image.
backend: Which model backend to use. One of {"timm", "torchvision"}
kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.
"""
super().__init__()
self.model_name = model_name.lower()
self.d_model = d_model
self.backend = backend
if in_chans == 1:
self.in_chans = 3
else:
self.in_chans = in_chans
self.feature_extractor = self.select_feature_extractor(
model_name=self.model_name,
in_chans=self.in_chans,
backend=self.backend,
**kwargs,
)
self.out_layer = torch.nn.Linear(
self.encoder_dim(self.feature_extractor), self.d_model
)
def select_feature_extractor(
self, model_name: str, in_chans: int, backend: str, **kwargs: Optional[Any]
) -> torch.nn.Module:
"""Select the appropriate feature extractor based on config.
Args:
model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
in_chans: the number of input channels of the image.
backend: Which model backend to use. One of {"timm", "torchvision"}
kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.
Returns:
a CNN encoder based on the config and backend selected.
"""
if "timm" in backend.lower():
feature_extractor = timm.create_model(
model_name=self.model_name,
in_chans=self.in_chans,
num_classes=0,
**kwargs,
)
elif "torch" in backend.lower():
if model_name.lower() == "resnet18":
feature_extractor = torchvision.models.resnet18(**kwargs)
elif model_name.lower() == "resnet50":
feature_extractor = torchvision.models.resnet50(**kwargs)
else:
raise ValueError(
f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}"
)
feature_extractor = torch.nn.Sequential(
*list(feature_extractor.children())[:-1]
)
input_layer = feature_extractor[0]
if in_chans != 3:
feature_extractor[0] = torch.nn.Conv2d(
in_channels=in_chans,
out_channels=input_layer.out_channels,
kernel_size=input_layer.kernel_size,
stride=input_layer.stride,
padding=input_layer.padding,
dilation=input_layer.dilation,
groups=input_layer.groups,
bias=input_layer.bias,
padding_mode=input_layer.padding_mode,
)
else:
raise ValueError(
f"Only ['timm', 'torch'] backends are available! Found {backend}."
)
return feature_extractor
def encoder_dim(self, model: torch.nn.Module) -> int:
"""Compute dummy forward pass of encoder model and get embedding dimension.
Args:
model: a vision encoder model.
Returns:
The embedding dimension size.
"""
_ = model.eval()
dummy_output = model(torch.randn(1, self.in_chans, 224, 224)).squeeze()
_ = model.train() # to be safe
return dummy_output.shape[-1]
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""Forward pass of feature extractor to get feature vector.
Args:
img: Input image tensor of shape (B, C, H, W).
Returns:
feats: Normalized output tensor of shape (B, d_model).
"""
# If grayscale, tile the image to 3 channels.
if img.shape[1] == 1:
img = img.repeat([1, 3, 1, 1]) # (B, nc=3, H, W)
# Extract image features
feats = self.feature_extractor(
img
) # (B, out_dim, 1, 1) if using resnet18 backbone.
# Reshape feature vectors
feats = feats.reshape([img.shape[0], -1]) # (B, out_dim)
# Map feature vectors to output dimension using linear layer.
feats = self.out_layer(feats) # (B, d_model)
# Normalize output feature vectors.
feats = F.normalize(feats) # (B, d_model)
return feats
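Example usage (a minimal sketch, not part of the library source; it assumes `timm` is installed and uses the default ResNet-18 backbone):

```python
import torch

from dreem.models import VisualEncoder

# Grayscale crops: in_chans=1 is promoted to 3 internally, and forward()
# tiles the single channel to 3 channels before the backbone.
encoder = VisualEncoder(model_name="resnet18", d_model=512, in_chans=1, backend="timm")

crops = torch.randn(8, 1, 128, 128)  # (B, C, H, W) instance crops
feats = encoder(crops)               # (B, d_model), L2-normalized feature vectors
assert feats.shape == (8, 512)
```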
__init__(model_name='resnet18', d_model=512, in_chans=3, backend='timm', **kwargs)
Initialize Visual Encoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | Name of the CNN architecture to use (e.g. "resnet18", "resnet50"). | 'resnet18' |
| d_model | int | Output embedding dimension. | 512 |
| in_chans | int | The number of input channels of the image. | 3 |
| backend | str | Which model backend to use. One of {"timm", "torchvision"}. | 'timm' |
| kwargs | Optional[Any] | See `timm.create_model` and `torchvision.models.resnetX` for accepted keyword arguments. | {} |
Source code in dreem/models/visual_encoder.py
def __init__(
self,
model_name: str = "resnet18",
d_model: int = 512,
in_chans: int = 3,
backend: str = "timm",
**kwargs: Optional[Any],
):
"""Initialize Visual Encoder.
Args:
model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
d_model (int): Output embedding dimension.
in_chans: the number of input channels of the image.
backend: Which model backend to use. One of {"timm", "torchvision"}
kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.
"""
super().__init__()
self.model_name = model_name.lower()
self.d_model = d_model
self.backend = backend
if in_chans == 1:
self.in_chans = 3
else:
self.in_chans = in_chans
self.feature_extractor = self.select_feature_extractor(
model_name=self.model_name,
in_chans=self.in_chans,
backend=self.backend,
**kwargs,
)
self.out_layer = torch.nn.Linear(
self.encoder_dim(self.feature_extractor), self.d_model
)
encoder_dim(model)
Compute dummy forward pass of encoder model and get embedding dimension.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | a vision encoder model. | required |

Returns:

| Type | Description |
|---|---|
| int | The embedding dimension size. |
Source code in dreem/models/visual_encoder.py
def encoder_dim(self, model: torch.nn.Module) -> int:
"""Compute dummy forward pass of encoder model and get embedding dimension.
Args:
model: a vision encoder model.
Returns:
The embedding dimension size.
"""
_ = model.eval()
dummy_output = model(torch.randn(1, self.in_chans, 224, 224)).squeeze()
_ = model.train() # to be safe
return dummy_output.shape[-1]
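The dummy forward pass above is how the encoder infers the input width of its linear projection. A standalone sketch of the same logic, assuming a timm ResNet-18 backbone:

```python
import timm
import torch

backbone = timm.create_model("resnet18", in_chans=3, num_classes=0)  # pooled features, no classifier
backbone.eval()
with torch.no_grad():
    out_dim = backbone(torch.randn(1, 3, 224, 224)).squeeze().shape[-1]  # 512 for resnet18
out_layer = torch.nn.Linear(out_dim, 512)  # project backbone features to d_model
```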
forward(img)
Forward pass of feature extractor to get feature vector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| img | Tensor | Input image tensor of shape (B, C, H, W). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| feats | Tensor | Normalized output tensor of shape (B, d_model). |
Source code in dreem/models/visual_encoder.py
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""Forward pass of feature extractor to get feature vector.
Args:
img: Input image tensor of shape (B, C, H, W).
Returns:
feats: Normalized output tensor of shape (B, d_model).
"""
# If grayscale, tile the image to 3 channels.
if img.shape[1] == 1:
img = img.repeat([1, 3, 1, 1]) # (B, nc=3, H, W)
# Extract image features
feats = self.feature_extractor(
img
) # (B, out_dim, 1, 1) if using resnet18 backbone.
# Reshape feature vectors
feats = feats.reshape([img.shape[0], -1]) # (B, out_dim)
# Map feature vectors to output dimension using linear layer.
feats = self.out_layer(feats) # (B, d_model)
# Normalize output feature vectors.
feats = F.normalize(feats) # (B, d_model)
return feats
select_feature_extractor(model_name, in_chans, backend, **kwargs)
Select the appropriate feature extractor based on config.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | Name of the CNN architecture to use (e.g. "resnet18", "resnet50"). | required |
| in_chans | int | The number of input channels of the image. | required |
| backend | str | Which model backend to use. One of {"timm", "torchvision"}. | required |
| kwargs | Optional[Any] | See `timm.create_model` and `torchvision.models.resnetX` for accepted keyword arguments. | {} |

Returns:

| Type | Description |
|---|---|
| Module | a CNN encoder based on the config and backend selected. |
Source code in dreem/models/visual_encoder.py
def select_feature_extractor(
self, model_name: str, in_chans: int, backend: str, **kwargs: Optional[Any]
) -> torch.nn.Module:
"""Select the appropriate feature extractor based on config.
Args:
model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
in_chans: the number of input channels of the image.
backend: Which model backend to use. One of {"timm", "torchvision"}
kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.
Returns:
a CNN encoder based on the config and backend selected.
"""
if "timm" in backend.lower():
feature_extractor = timm.create_model(
model_name=self.model_name,
in_chans=self.in_chans,
num_classes=0,
**kwargs,
)
elif "torch" in backend.lower():
if model_name.lower() == "resnet18":
feature_extractor = torchvision.models.resnet18(**kwargs)
elif model_name.lower() == "resnet50":
feature_extractor = torchvision.models.resnet50(**kwargs)
else:
raise ValueError(
f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}"
)
feature_extractor = torch.nn.Sequential(
*list(feature_extractor.children())[:-1]
)
input_layer = feature_extractor[0]
if in_chans != 3:
feature_extractor[0] = torch.nn.Conv2d(
in_channels=in_chans,
out_channels=input_layer.out_channels,
kernel_size=input_layer.kernel_size,
stride=input_layer.stride,
padding=input_layer.padding,
dilation=input_layer.dilation,
groups=input_layer.groups,
bias=input_layer.bias,
padding_mode=input_layer.padding_mode,
)
else:
raise ValueError(
f"Only ['timm', 'torch'] backends are available! Found {backend}."
)
return feature_extractor
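For reference, a hedged sketch of the torchvision path: the classifier head is stripped and, for non-RGB inputs, the first conv layer is swapped. The `weights` keyword below is illustrative; it is simply forwarded to `torchvision.models.resnet50` via `**kwargs`.

```python
from dreem.models import VisualEncoder

# backend="torchvision" routes model construction through torchvision.models.resnet50(**kwargs)
encoder = VisualEncoder(
    model_name="resnet50",
    d_model=512,
    in_chans=3,
    backend="torchvision",
    weights=None,  # illustrative kwarg forwarded to the torchvision constructor
)
```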
dreem.models.Transformer
Bases: Module
Transformer class.
Source code in dreem/models/transformer.py
class Transformer(torch.nn.Module):
"""Transformer class."""
def __init__(
self,
d_model: int = 1024,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dropout: float = 0.1,
activation: str = "relu",
return_intermediate_dec: bool = False,
norm: bool = False,
num_layers_attn_head: int = 2,
dropout_attn_head: float = 0.1,
embedding_meta: dict = None,
return_embedding: bool = False,
decoder_self_attn: bool = False,
) -> None:
"""Initialize Transformer.
Args:
d_model: The number of features in the encoder/decoder inputs.
nhead: The number of heads in the transformer encoder/decoder.
num_encoder_layers: The number of encoder-layers in the encoder.
num_decoder_layers: The number of decoder-layers in the decoder.
dropout: Dropout value applied to the output of transformer layers.
activation: Activation function to use.
return_intermediate_dec: Return intermediate layers from decoder.
norm: If True, normalize output of encoder and decoder.
num_layers_attn_head: The number of layers in the attention head.
dropout_attn_head: Dropout value for the attention_head.
embedding_meta: Metadata for positional embeddings. See below.
return_embedding: Whether to return the positional embeddings
decoder_self_attn: If True, use decoder self attention.
More details on `embedding_meta`:
By default this is `None`, indicating that no positional embeddings should be used. To use positional embeddings,
pass in a dictionary containing "pos" and/or "temp" keys with sub-dictionaries of the relevant parameters, e.g.:
{"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
"temp": {'mode': 'learned', 'emb_num': 16}}. (See `dreem.models.embeddings.Embedding.EMB_TYPES`
and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters.)
"""
super().__init__()
self.d_model = dim_feedforward = feature_dim_attn_head = d_model
self.embedding_meta = embedding_meta
self.return_embedding = return_embedding
self.pos_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
if self.embedding_meta:
if "pos" in self.embedding_meta:
pos_emb_cfg = self.embedding_meta["pos"]
if pos_emb_cfg:
self.pos_emb = Embedding(
emb_type="pos", features=self.d_model, **pos_emb_cfg
)
if "temp" in self.embedding_meta:
temp_emb_cfg = self.embedding_meta["temp"]
if temp_emb_cfg:
self.temp_emb = Embedding(
emb_type="temp", features=self.d_model, **temp_emb_cfg
)
# Transformer Encoder
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, norm
)
encoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
)
# Transformer Decoder
decoder_layer = TransformerDecoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
activation,
norm,
decoder_self_attn,
)
decoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm
)
# Transformer attention head
self.attn_head = ATTWeightHead(
feature_dim=feature_dim_attn_head,
num_layers=num_layers_attn_head,
dropout=dropout_attn_head,
)
self._reset_parameters()
def _reset_parameters(self):
"""Initialize model weights from xavier distribution."""
for p in self.parameters():
if not torch.nn.parameter.is_lazy(p) and p.dim() > 1:
try:
nn.init.xavier_uniform_(p)
except ValueError as e:
print(f"Failed Trying to initialize {p}")
raise (e)
def forward(
self,
ref_instances: list["dreem.io.Instance"],
query_instances: list["dreem.io.Instance"] = None,
) -> list[AssociationMatrix]:
"""Execute a forward pass through the transformer and attention head.
Args:
ref_instances: A list of instance objects (See `dreem.io.Instance` for more info.)
query_instances: A set of instances to be used as decoder queries.
Returns:
asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where:
L: number of decoder blocks
n_query: number of instances in current query/frame
total_instances: number of instances in window
"""
ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
# window_length = len(frames)
# instances_per_frame = [frame.num_detected for frame in frames]
total_instances = len(ref_instances)
embed_dim = ref_features.shape[-1]
# print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
ref_boxes = get_boxes(ref_instances) # total_instances, 4
ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
ref_times, query_times = get_times(ref_instances, query_instances)
window_length = len(ref_times.unique())
ref_temp_emb = self.temp_emb(ref_times / window_length)
ref_pos_emb = self.pos_emb(ref_boxes)
if self.return_embedding:
for i, instance in enumerate(ref_instances):
instance.add_embedding("pos", ref_pos_emb[i])
instance.add_embedding("temp", ref_temp_emb[i])
ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0
ref_emb = ref_emb.view(1, total_instances, embed_dim)
ref_emb = ref_emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim)
batch_size, total_instances, embed_dim = ref_features.shape
ref_features = ref_features.permute(
1, 0, 2
) # (total_instances, batch_size, embed_dim)
encoder_queries = ref_features
encoder_features = self.encoder(
encoder_queries, pos_emb=ref_emb
) # (total_instances, batch_size, embed_dim)
n_query = total_instances
query_features = ref_features
query_pos_emb = ref_pos_emb
query_temp_emb = ref_temp_emb
query_emb = ref_emb
if query_instances is not None:
n_query = len(query_instances)
query_features = torch.cat(
[instance.features for instance in query_instances], dim=0
).unsqueeze(0)
query_features = query_features.permute(
1, 0, 2
) # (n_query, batch_size, embed_dim)
query_boxes = get_boxes(query_instances)
query_boxes = torch.nan_to_num(query_boxes, -1.0)
query_temp_emb = self.temp_emb(query_times / window_length)
query_pos_emb = self.pos_emb(query_boxes)
query_emb = (query_pos_emb + query_temp_emb) / 2.0
query_emb = query_emb.view(1, n_query, embed_dim)
query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim)
else:
query_instances = ref_instances
if self.return_embedding:
for i, instance in enumerate(query_instances):
instance.add_embedding("pos", query_pos_emb[i])
instance.add_embedding("temp", query_temp_emb[i])
decoder_features = self.decoder(
query_features,
encoder_features,
ref_pos_emb=ref_emb,
query_pos_emb=query_emb,
) # (L, n_query, batch_size, embed_dim)
decoder_features = decoder_features.transpose(
1, 2
) # # (L, batch_size, n_query, embed_dim)
encoder_features = encoder_features.permute(1, 0, 2).view(
batch_size, total_instances, embed_dim
) # (batch_size, total_instances, embed_dim)
asso_output = []
for frame_features in decoder_features:
asso_matrix = self.attn_head(frame_features, encoder_features).view(
n_query, total_instances
)
asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances)
asso_output.append(asso_matrix)
# (L=1, n_query, total_instances)
return asso_output
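A hedged configuration sketch showing how `embedding_meta` is typically passed; the specific values are illustrative, not canonical:

```python
from dreem.models import Transformer

embedding_meta = {
    "pos": {"mode": "learned", "emb_num": 16, "over_boxes": True},
    "temp": {"mode": "learned", "emb_num": 16},
}

# d_model must match the feature dimension produced by the visual encoder.
transformer = Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=1,
    num_decoder_layers=1,
    embedding_meta=embedding_meta,
)
```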
__init__(d_model=1024, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dropout=0.1, activation='relu', return_intermediate_dec=False, norm=False, num_layers_attn_head=2, dropout_attn_head=0.1, embedding_meta=None, return_embedding=False, decoder_self_attn=False)
Initialize Transformer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | The number of features in the encoder/decoder inputs. | 1024 |
| nhead | int | The number of heads in the transformer encoder/decoder. | 8 |
| num_encoder_layers | int | The number of encoder layers in the encoder. | 6 |
| num_decoder_layers | int | The number of decoder layers in the decoder. | 6 |
| dropout | float | Dropout value applied to the output of transformer layers. | 0.1 |
| activation | str | Activation function to use. | 'relu' |
| return_intermediate_dec | bool | Return intermediate layers from decoder. | False |
| norm | bool | If True, normalize output of encoder and decoder. | False |
| num_layers_attn_head | int | The number of layers in the attention head. | 2 |
| dropout_attn_head | float | Dropout value for the attention head. | 0.1 |
| embedding_meta | dict | Metadata for positional embeddings. See the note below. | None |
| return_embedding | bool | Whether to return the positional embeddings. | False |
| decoder_self_attn | bool | If True, use decoder self attention. | False |

More details on `embedding_meta`: by default this is `None`, indicating that no positional embeddings should be used. To use positional embeddings, pass a dictionary containing "pos" and/or "temp" keys with sub-dictionaries of the relevant parameters, e.g. `{"pos": {"mode": "learned", "emb_num": 16, "over_boxes": True}, "temp": {"mode": "learned", "emb_num": 16}}` (see `dreem.models.embeddings.Embedding.EMB_TYPES` and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters).
Source code in dreem/models/transformer.py
def __init__(
self,
d_model: int = 1024,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dropout: float = 0.1,
activation: str = "relu",
return_intermediate_dec: bool = False,
norm: bool = False,
num_layers_attn_head: int = 2,
dropout_attn_head: float = 0.1,
embedding_meta: dict = None,
return_embedding: bool = False,
decoder_self_attn: bool = False,
) -> None:
"""Initialize Transformer.
Args:
d_model: The number of features in the encoder/decoder inputs.
nhead: The number of heads in the transformer encoder/decoder.
num_encoder_layers: The number of encoder-layers in the encoder.
num_decoder_layers: The number of decoder-layers in the decoder.
dropout: Dropout value applied to the output of transformer layers.
activation: Activation function to use.
return_intermediate_dec: Return intermediate layers from decoder.
norm: If True, normalize output of encoder and decoder.
num_layers_attn_head: The number of layers in the attention head.
dropout_attn_head: Dropout value for the attention_head.
embedding_meta: Metadata for positional embeddings. See below.
return_embedding: Whether to return the positional embeddings
decoder_self_attn: If True, use decoder self attention.
More details on `embedding_meta`:
By default this is `None`, indicating that no positional embeddings should be used. To use positional embeddings,
pass in a dictionary containing "pos" and/or "temp" keys with sub-dictionaries of the relevant parameters, e.g.:
{"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
"temp": {'mode': 'learned', 'emb_num': 16}}. (See `dreem.models.embeddings.Embedding.EMB_TYPES`
and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters.)
"""
super().__init__()
self.d_model = dim_feedforward = feature_dim_attn_head = d_model
self.embedding_meta = embedding_meta
self.return_embedding = return_embedding
self.pos_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
if self.embedding_meta:
if "pos" in self.embedding_meta:
pos_emb_cfg = self.embedding_meta["pos"]
if pos_emb_cfg:
self.pos_emb = Embedding(
emb_type="pos", features=self.d_model, **pos_emb_cfg
)
if "temp" in self.embedding_meta:
temp_emb_cfg = self.embedding_meta["temp"]
if temp_emb_cfg:
self.temp_emb = Embedding(
emb_type="temp", features=self.d_model, **temp_emb_cfg
)
# Transformer Encoder
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, norm
)
encoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
)
# Transformer Decoder
decoder_layer = TransformerDecoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
activation,
norm,
decoder_self_attn,
)
decoder_norm = nn.LayerNorm(d_model) if (norm) else None
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm
)
# Transformer attention head
self.attn_head = ATTWeightHead(
feature_dim=feature_dim_attn_head,
num_layers=num_layers_attn_head,
dropout=dropout_attn_head,
)
self._reset_parameters()
forward(ref_instances, query_instances=None)
Execute a forward pass through the transformer and attention head.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ref_instances | list[Instance] | A list of instance objects (see `dreem.io.Instance` for more info). | required |
| query_instances | list[Instance] | A set of instances to be used as decoder queries. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| asso_output | list[AssociationMatrix] | A list of torch.Tensors of shape (L, n_query, total_instances), where L is the number of decoder blocks, n_query is the number of instances in the current query/frame, and total_instances is the number of instances in the window. |
Source code in dreem/models/transformer.py
def forward(
self,
ref_instances: list["dreem.io.Instance"],
query_instances: list["dreem.io.Instance"] = None,
) -> list[AssociationMatrix]:
"""Execute a forward pass through the transformer and attention head.
Args:
ref_instances: A list of instance objects (See `dreem.io.Instance` for more info.)
query_instances: A set of instances to be used as decoder queries.
Returns:
asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where:
L: number of decoder blocks
n_query: number of instances in current query/frame
total_instances: number of instances in window
"""
ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
# window_length = len(frames)
# instances_per_frame = [frame.num_detected for frame in frames]
total_instances = len(ref_instances)
embed_dim = ref_features.shape[-1]
# print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
ref_boxes = get_boxes(ref_instances) # total_instances, 4
ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
ref_times, query_times = get_times(ref_instances, query_instances)
window_length = len(ref_times.unique())
ref_temp_emb = self.temp_emb(ref_times / window_length)
ref_pos_emb = self.pos_emb(ref_boxes)
if self.return_embedding:
for i, instance in enumerate(ref_instances):
instance.add_embedding("pos", ref_pos_emb[i])
instance.add_embedding("temp", ref_temp_emb[i])
ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0
ref_emb = ref_emb.view(1, total_instances, embed_dim)
ref_emb = ref_emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim)
batch_size, total_instances, embed_dim = ref_features.shape
ref_features = ref_features.permute(
1, 0, 2
) # (total_instances, batch_size, embed_dim)
encoder_queries = ref_features
encoder_features = self.encoder(
encoder_queries, pos_emb=ref_emb
) # (total_instances, batch_size, embed_dim)
n_query = total_instances
query_features = ref_features
query_pos_emb = ref_pos_emb
query_temp_emb = ref_temp_emb
query_emb = ref_emb
if query_instances is not None:
n_query = len(query_instances)
query_features = torch.cat(
[instance.features for instance in query_instances], dim=0
).unsqueeze(0)
query_features = query_features.permute(
1, 0, 2
) # (n_query, batch_size, embed_dim)
query_boxes = get_boxes(query_instances)
query_boxes = torch.nan_to_num(query_boxes, -1.0)
query_temp_emb = self.temp_emb(query_times / window_length)
query_pos_emb = self.pos_emb(query_boxes)
query_emb = (query_pos_emb + query_temp_emb) / 2.0
query_emb = query_emb.view(1, n_query, embed_dim)
query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim)
else:
query_instances = ref_instances
if self.return_embedding:
for i, instance in enumerate(query_instances):
instance.add_embedding("pos", query_pos_emb[i])
instance.add_embedding("temp", query_temp_emb[i])
decoder_features = self.decoder(
query_features,
encoder_features,
ref_pos_emb=ref_emb,
query_pos_emb=query_emb,
) # (L, n_query, batch_size, embed_dim)
decoder_features = decoder_features.transpose(
1, 2
) # # (L, batch_size, n_query, embed_dim)
encoder_features = encoder_features.permute(1, 0, 2).view(
batch_size, total_instances, embed_dim
) # (batch_size, total_instances, embed_dim)
asso_output = []
for frame_features in decoder_features:
asso_matrix = self.attn_head(frame_features, encoder_features).view(
n_query, total_instances
)
asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances)
asso_output.append(asso_matrix)
# (L=1, n_query, total_instances)
return asso_output
dreem.models.transformer.TransformerEncoder
Bases: Module
A transformer encoder block composed of encoder layers.
Source code in dreem/models/transformer.py
class TransformerEncoder(nn.Module):
"""A transformer encoder block composed of encoder layers."""
def __init__(
self,
encoder_layer: TransformerEncoderLayer,
num_layers: int,
norm: nn.Module = None,
) -> None:
"""Initialize transformer encoder.
Args:
encoder_layer: An instance of the TransformerEncoderLayer.
num_layers: The number of encoder layers to be stacked.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm if norm is not None else nn.Identity()
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of encoder layer.
Args:
queries: The input tensor of shape (n_query, batch_size, embed_dim).
pos_emb: The positional embedding tensor of shape (n_query, embed_dim).
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
for layer in self.layers:
queries = layer(queries, pos_emb=pos_emb)
encoder_features = self.norm(queries)
return encoder_features
__init__(encoder_layer, num_layers, norm=None)
Initialize transformer encoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| encoder_layer | TransformerEncoderLayer | An instance of the TransformerEncoderLayer. | required |
| num_layers | int | The number of encoder layers to be stacked. | required |
| norm | Module | The normalization layer to be applied. | None |
Source code in dreem/models/transformer.py
def __init__(
self,
encoder_layer: TransformerEncoderLayer,
num_layers: int,
norm: nn.Module = None,
) -> None:
"""Initialize transformer encoder.
Args:
encoder_layer: An instance of the TransformerEncoderLayer.
num_layers: The number of encoder layers to be stacked.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm if norm is not None else nn.Identity()
forward(queries, pos_emb=None)
Execute a forward pass of encoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| queries | Tensor | The input tensor of shape (n_query, batch_size, embed_dim). | required |
| pos_emb | Tensor | The positional embedding tensor of shape (n_query, embed_dim). | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of encoder layer.
Args:
queries: The input tensor of shape (n_query, batch_size, embed_dim).
pos_emb: The positional embedding tensor of shape (n_query, embed_dim).
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
for layer in self.layers:
queries = layer(queries, pos_emb=pos_emb)
encoder_features = self.norm(queries)
return encoder_features
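A minimal sketch of the encoder block in isolation (shapes follow the `(n_query, batch_size, embed_dim)` convention used above; the sizes are illustrative):

```python
import torch

from dreem.models.transformer import TransformerEncoder, TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=256, norm=True)
encoder = TransformerEncoder(layer, num_layers=2, norm=torch.nn.LayerNorm(256))

queries = torch.randn(10, 1, 256)        # (n_query, batch_size, embed_dim)
pos_emb = torch.zeros_like(queries)      # stand-in for an Embedding output
out = encoder(queries, pos_emb=pos_emb)  # (10, 1, 256)
```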
dreem.models.transformer.TransformerEncoderLayer
Bases: Module
A single transformer encoder layer.
Source code in dreem/models/transformer.py
class TransformerEncoderLayer(nn.Module):
"""A single transformer encoder layer."""
def __init__(
self,
d_model: int = 1024,
nhead: int = 6,
dim_feedforward: int = 1024,
dropout: float = 0.1,
activation: str = "relu",
norm: bool = False,
) -> None:
"""Initialize a transformer encoder layer.
Args:
d_model: The number of features in the encoder inputs.
nhead: The number of heads for the encoder.
dim_feedforward: Dimension of the feedforward layers of encoder.
dropout: Dropout value applied to the output of encoder.
activation: Activation function to use.
norm: If True, normalize output of encoder.
"""
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of the encoder layer.
Args:
queries: Input sequence for encoder (n_query, batch_size, embed_dim).
pos_emb: Position embedding, if provided is added to src
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
if pos_emb is None:
pos_emb = torch.zeros_like(queries)
queries = queries + pos_emb
# q = k = src
attn_features = self.self_attn(
query=queries,
key=queries,
value=queries,
)[0]
queries = queries + self.dropout1(attn_features)
queries = self.norm1(queries)
projection = self.linear2(self.dropout(self.activation(self.linear1(queries))))
queries = queries + self.dropout2(projection)
encoder_features = self.norm2(queries)
return encoder_features
__init__(d_model=1024, nhead=6, dim_feedforward=1024, dropout=0.1, activation='relu', norm=False)
Initialize a transformer encoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | The number of features in the encoder inputs. | 1024 |
| nhead | int | The number of heads for the encoder. | 6 |
| dim_feedforward | int | Dimension of the feedforward layers of encoder. | 1024 |
| dropout | float | Dropout value applied to the output of encoder. | 0.1 |
| activation | str | Activation function to use. | 'relu' |
| norm | bool | If True, normalize output of encoder. | False |
Source code in dreem/models/transformer.py
def __init__(
self,
d_model: int = 1024,
nhead: int = 6,
dim_feedforward: int = 1024,
dropout: float = 0.1,
activation: str = "relu",
norm: bool = False,
) -> None:
"""Initialize a transformer encoder layer.
Args:
d_model: The number of features in the encoder inputs.
nhead: The number of heads for the encoder.
dim_feedforward: Dimension of the feedforward layers of encoder.
dropout: Dropout value applied to the output of encoder.
activation: Activation function to use.
norm: If True, normalize output of encoder.
"""
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
forward(queries, pos_emb=None)
Execute a forward pass of the encoder layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| queries | Tensor | Input sequence for encoder (n_query, batch_size, embed_dim). | required |
| pos_emb | Tensor | Position embedding; if provided, it is added to the queries. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
"""Execute a forward pass of the encoder layer.
Args:
queries: Input sequence for encoder (n_query, batch_size, embed_dim).
pos_emb: Position embedding, if provided is added to src
Returns:
The output tensor of shape (n_query, batch_size, embed_dim).
"""
if pos_emb is None:
pos_emb = torch.zeros_like(queries)
queries = queries + pos_emb
# q = k = src
attn_features = self.self_attn(
query=queries,
key=queries,
value=queries,
)[0]
queries = queries + self.dropout1(attn_features)
queries = self.norm1(queries)
projection = self.linear2(self.dropout(self.activation(self.linear1(queries))))
queries = queries + self.dropout2(projection)
encoder_features = self.norm2(queries)
return encoder_features
dreem.models.Embedding
Bases: Module
Class that wraps around different embedding types.
Used for both learned and fixed embeddings.
Source code in dreem/models/embedding.py
class Embedding(torch.nn.Module):
"""Class that wraps around different embedding types.
Used for both learned and fixed embeddings.
"""
EMB_TYPES = {
"temp": {},
"pos": {"over_boxes"},
"off": {},
None: {},
} # dict of valid args:keyword params
EMB_MODES = {
"fixed": {"temperature", "scale", "normalize"},
"learned": {"emb_num"},
"off": {},
} # dict of valid args:keyword params
def __init__(
self,
emb_type: str,
mode: str,
features: int,
n_points: Optional[int] = 1,
emb_num: Optional[int] = 16,
over_boxes: Optional[bool] = True,
temperature: Optional[int] = 10000,
normalize: Optional[bool] = False,
scale: Optional[float] = None,
mlp_cfg: dict = None,
):
"""Initialize embeddings.
Args:
emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
mode: The mode or function used to map positions to vector embeddings.
Must be one of `{"fixed", "learned", "off"}`
features: The embedding dimensions. Must match the dimension of the
input vectors for the transformer model.
n_points: the number of points that will be embedded.
emb_num: the number of embeddings in the `self.lookup` table (Only used in learned embeddings).
over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh).
temperature: the temperature constant to be used when computing the sinusoidal position embedding
normalize: whether or not to normalize the positions (Only used in fixed embeddings).
scale: factor by which to scale the positions after normalizing (Only used in fixed embeddings).
mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space.
Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
"""
self._check_init_args(emb_type, mode)
super().__init__()
self.emb_type = emb_type
self.mode = mode
self.features = features
self.emb_num = emb_num
self.over_boxes = over_boxes
self.temperature = temperature
self.normalize = normalize
self.scale = scale
self.n_points = n_points
if self.normalize and self.scale is None:
self.scale = 2 * math.pi
if self.emb_type == "pos" and mlp_cfg is not None and mlp_cfg["num_layers"] > 0:
if self.mode == "fixed":
self.mlp = MLP(
input_dim=n_points * self.features,
output_dim=self.features,
**mlp_cfg,
)
else:
in_dim = (self.features // (4 * n_points)) * (4 * n_points)
self.mlp = MLP(
input_dim=in_dim,
output_dim=self.features,
**mlp_cfg,
)
else:
self.mlp = torch.nn.Identity()
self._emb_func = lambda tensor: torch.zeros(
(tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device
) # turn off embedding by returning zeros
self.lookup = None
if self.mode == "learned":
if self.emb_type == "pos":
self.lookup = torch.nn.Embedding(
self.emb_num * 4 * self.n_points, self.features // (4 * n_points)
)
self._emb_func = self._learned_pos_embedding
elif self.emb_type == "temp":
self.lookup = torch.nn.Embedding(self.emb_num, self.features)
self._emb_func = self._learned_temp_embedding
elif self.mode == "fixed":
if self.emb_type == "pos":
self._emb_func = self._sine_box_embedding
elif self.emb_type == "temp":
self._emb_func = self._sine_temp_embedding
def _check_init_args(self, emb_type: str, mode: str):
"""Check whether the correct arguments were passed to initialization.
Args:
emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
mode: The mode or function used to map positions to vector embeddings.
Must be one of `{"fixed", "learned", "off"}`
Raises:
ValueError: if an incorrect `emb_type` or `mode` string is passed.
"""
if emb_type.lower() not in self.EMB_TYPES:
raise ValueError(
f"Embedding `emb_type` must be one of {self.EMB_TYPES} not {emb_type}"
)
if mode.lower() not in self.EMB_MODES:
raise ValueError(
f"Embedding `mode` must be one of {self.EMB_MODES} not {mode}"
)
def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
"""Get the sequence positional embeddings.
Args:
seq_positions:
* An (`N`, 1) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence.
* An (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence.
Returns:
An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
"""
emb = self._emb_func(seq_positions)
if emb.shape[-1] != self.features:
raise RuntimeError(
(
f"Output embedding dimension is {emb.shape[-1]} but requested {self.features} dimensions! \n"
f"hint: Try turning the MLP on by passing `mlp_cfg` to the constructor to project to the correct embedding dimensions."
)
)
return emb
def _torch_int_div(
self, tensor1: torch.Tensor, tensor2: torch.Tensor
) -> torch.Tensor:
"""Perform integer division of two tensors.
Args:
tensor1: dividend tensor.
tensor2: divisor tensor.
Returns:
torch.Tensor, resulting tensor.
"""
return torch.div(tensor1, tensor2, rounding_mode="floor")
def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
"""Compute sine positional embeddings for boxes using given parameters.
Args:
boxes: the input boxes of shape N, n_anchors, 4 or B, N, n_anchors, 4
where the last dimension is the bbox coords in [y1, x1, y2, x2].
(Note currently `B=batch_size=1`).
Returns:
torch.Tensor, the sine positional embeddings
(embedding[:, 4i] = sin(x)
embedding[:, 4i+1] = cos(x)
embedding[:, 4i+2] = sin(y)
embedding[:, 4i+3] = cos(y)
)
"""
if self.scale is not None and self.normalize is False:
raise ValueError("normalize should be True if scale is passed")
if len(boxes.size()) == 3:
boxes = boxes.unsqueeze(0)
if self.normalize:
boxes = boxes / (boxes[:, :, -1:] + 1e-6) * self.scale
dim_t = torch.arange(self.features // 4, dtype=torch.float32)
dim_t = self.temperature ** (
2 * self._torch_int_div(dim_t, 2) / (self.features // 4)
)
# (b, n_t, n_anchors, 4, D//4)
pos_emb = boxes[:, :, :, :, None] / dim_t.to(boxes.device)
pos_emb = torch.stack(
(pos_emb[:, :, :, :, 0::2].sin(), pos_emb[:, :, :, :, 1::2].cos()), dim=4
)
pos_emb = pos_emb.flatten(2).squeeze(0) # (N_t, n_anchors * D)
pos_emb = self.mlp(pos_emb)
pos_emb = pos_emb.view(boxes.shape[1], self.features)
return pos_emb
def _sine_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
"""Compute fixed sine temporal embeddings.
Args:
times: the input times of shape (N,) or (N,1) where N = (sum(instances_per_frame))
which is the frame index of the instance relative
to the batch size
(e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2,..., B, B, ...B])`).
Returns:
an n_instances x D embedding representing the temporal embedding.
"""
T = times.int().max().item() + 1
d = self.features
n = self.temperature
positions = torch.arange(0, T).unsqueeze(1)
temp_lookup = torch.zeros(T, d, device=times.device)
denominators = torch.pow(
n, 2 * torch.arange(0, d // 2) / d
) # 10000^(2i/d_model), i is the index of embedding
temp_lookup[:, 0::2] = torch.sin(
positions / denominators
) # sin(pos/10000^(2i/d_model))
temp_lookup[:, 1::2] = torch.cos(
positions / denominators
) # cos(pos/10000^(2i/d_model))
temp_emb = temp_lookup[times.int()]
return temp_emb # .view(len(times), self.features)
def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
"""Compute learned positional embeddings for boxes using given parameters.
Args:
boxes: the input boxes of shape N x 4 or B x N x 4
where the last dimension is the bbox coords in [y1, x1, y2, x2].
(Note currently `B=batch_size=1`).
Returns:
torch.Tensor, the learned positional embeddings.
"""
pos_lookup = self.lookup
N, n_anchors, _ = boxes.shape
boxes = boxes.view(N, n_anchors, 4)
if self.over_boxes:
xywh = boxes
else:
xywh = torch.cat(
[
(boxes[:, :, 2:] + boxes[:, :, :2]) / 2,
(boxes[:, :, 2:] - boxes[:, :, :2]),
],
dim=1,
)
left_ind, right_ind, left_weight, right_weight = self._compute_weights(xywh)
f = pos_lookup.weight.shape[1] # self.features // 4
try:
pos_emb_table = pos_lookup.weight.view(
self.emb_num, n_anchors, 4, f
) # T x 4 x (D * 4)
except RuntimeError as e:
print(f"Hint: `n_points` ({self.n_points}) may be set incorrectly!")
raise (e)
left_emb = pos_emb_table.gather(
0,
left_ind[:, :, :, None].to(pos_emb_table.device).expand(N, n_anchors, 4, f),
) # N x 4 x d
right_emb = pos_emb_table.gather(
0,
right_ind[:, :, :, None]
.to(pos_emb_table.device)
.expand(N, n_anchors, 4, f),
) # N x 4 x d
pos_emb = left_weight[:, :, :, None] * right_emb.to(
left_weight.device
) + right_weight[:, :, :, None] * left_emb.to(right_weight.device)
pos_emb = pos_emb.flatten(1)
pos_emb = self.mlp(pos_emb)
return pos_emb.view(N, self.features)
def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
"""Compute learned temporal embeddings for times using given parameters.
Args:
times: the input times of shape (N,) or (N,1) where N = (sum(instances_per_frame))
which is the frame index of the instance relative
to the batch size
(e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2,..., B, B, ...B])`).
Returns:
torch.Tensor, the learned temporal embeddings.
"""
temp_lookup = self.lookup
N = times.shape[0]
left_ind, right_ind, left_weight, right_weight = self._compute_weights(times)
left_emb = temp_lookup.weight[
left_ind.to(temp_lookup.weight.device)
] # T x D --> N x D
right_emb = temp_lookup.weight[right_ind.to(temp_lookup.weight.device)]
temp_emb = left_weight[:, None] * right_emb.to(
left_weight.device
) + right_weight[:, None] * left_emb.to(right_weight.device)
return temp_emb.view(N, self.features)
def _compute_weights(self, data: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Compute left and right learned embedding weights.
Args:
data: the input data (e.g boxes or times).
Returns:
A torch.Tensor for each of the left/right indices and weights, respectively
"""
data = data * self.emb_num
left_ind = data.clamp(min=0, max=self.emb_num - 1).long() # N x 4
right_ind = (left_ind + 1).clamp(min=0, max=self.emb_num - 1).long() # N x 4
left_weight = data - left_ind.float() # N x 4
right_weight = 1.0 - left_weight
return left_ind, right_ind, left_weight, right_weight
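A short sketch of both embedding flavors, matching the shapes described in `forward` (the sizes and values are illustrative):

```python
import torch

from dreem.models import Embedding

pos_emb = Embedding(emb_type="pos", mode="fixed", features=512, normalize=True)
temp_emb = Embedding(emb_type="temp", mode="learned", features=512, emb_num=16)

boxes = torch.rand(8, 1, 4)            # (N, n_anchors, 4) boxes as [y1, x1, y2, x2]
times = torch.arange(8).float() / 8.0  # (N,) normalized frame indices
print(pos_emb(boxes).shape)            # torch.Size([8, 512])
print(temp_emb(times).shape)           # torch.Size([8, 512])
```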
__init__(emb_type, mode, features, n_points=1, emb_num=16, over_boxes=True, temperature=10000, normalize=False, scale=None, mlp_cfg=None)
Initialize embeddings.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| emb_type | str | The type of embedding to compute. Must be one of {"temp", "pos", "off"}. | required |
| mode | str | The mode or function used to map positions to vector embeddings. Must be one of {"fixed", "learned", "off"}. | required |
| features | int | The embedding dimensions. Must match the dimension of the input vectors for the transformer model. | required |
| n_points | Optional[int] | The number of points that will be embedded. | 1 |
| emb_num | Optional[int] | The number of embeddings in the `self.lookup` table (only used in learned embeddings). | 16 |
| over_boxes | Optional[bool] | Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh). | True |
| temperature | Optional[int] | The temperature constant to be used when computing the sinusoidal position embedding. | 10000 |
| normalize | Optional[bool] | Whether or not to normalize the positions (only used in fixed embeddings). | False |
| scale | Optional[float] | Factor by which to scale the positions after normalizing (only used in fixed embeddings). | None |
| mlp_cfg | dict | A dictionary of MLP hyperparameters for projecting the embedding to the correct space. Example: {"hidden_dims": 256, "num_layers": 3, "dropout": 0.3} | None |
Source code in dreem/models/embedding.py
def __init__(
self,
emb_type: str,
mode: str,
features: int,
n_points: Optional[int] = 1,
emb_num: Optional[int] = 16,
over_boxes: Optional[bool] = True,
temperature: Optional[int] = 10000,
normalize: Optional[bool] = False,
scale: Optional[float] = None,
mlp_cfg: dict = None,
):
"""Initialize embeddings.
Args:
emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
mode: The mode or function used to map positions to vector embeddings.
Must be one of `{"fixed", "learned", "off"}`
features: The embedding dimensions. Must match the dimension of the
input vectors for the transformer model.
n_points: the number of points that will be embedded.
emb_num: the number of embeddings in the `self.lookup` table (Only used in learned embeddings).
over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh).
temperature: the temperature constant to be used when computing the sinusoidal position embedding
normalize: whether or not to normalize the positions (Only used in fixed embeddings).
scale: factor by which to scale the positions after normalizing (Only used in fixed embeddings).
mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space.
Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
"""
self._check_init_args(emb_type, mode)
super().__init__()
self.emb_type = emb_type
self.mode = mode
self.features = features
self.emb_num = emb_num
self.over_boxes = over_boxes
self.temperature = temperature
self.normalize = normalize
self.scale = scale
self.n_points = n_points
if self.normalize and self.scale is None:
self.scale = 2 * math.pi
if self.emb_type == "pos" and mlp_cfg is not None and mlp_cfg["num_layers"] > 0:
if self.mode == "fixed":
self.mlp = MLP(
input_dim=n_points * self.features,
output_dim=self.features,
**mlp_cfg,
)
else:
in_dim = (self.features // (4 * n_points)) * (4 * n_points)
self.mlp = MLP(
input_dim=in_dim,
output_dim=self.features,
**mlp_cfg,
)
else:
self.mlp = torch.nn.Identity()
self._emb_func = lambda tensor: torch.zeros(
(tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device
) # turn off embedding by returning zeros
self.lookup = None
if self.mode == "learned":
if self.emb_type == "pos":
self.lookup = torch.nn.Embedding(
self.emb_num * 4 * self.n_points, self.features // (4 * n_points)
)
self._emb_func = self._learned_pos_embedding
elif self.emb_type == "temp":
self.lookup = torch.nn.Embedding(self.emb_num, self.features)
self._emb_func = self._learned_temp_embedding
elif self.mode == "fixed":
if self.emb_type == "pos":
self._emb_func = self._sine_box_embedding
elif self.emb_type == "temp":
self._emb_func = self._sine_temp_embedding
forward(seq_positions)
Get the sequence positional embeddings.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| seq_positions | Tensor | Either an (`N`, 1) tensor where seq_positions[i] is the temporal position of instance_i in the sequence, or an (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] is the [y1, x1, y2, x2] spatial location of the jth point of instance_i. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding. |
Source code in dreem/models/embedding.py
def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
"""Get the sequence positional embeddings.
Args:
seq_positions:
* An (`N`, 1) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence.
* An (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence.
Returns:
An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
"""
emb = self._emb_func(seq_positions)
if emb.shape[-1] != self.features:
raise RuntimeError(
(
f"Output embedding dimension is {emb.shape[-1]} but requested {self.features} dimensions! \n"
f"hint: Try turning the MLP on by passing `mlp_cfg` to the constructor to project to the correct embedding dimensions."
)
)
return emb
dreem.models.transformer.TransformerDecoder
Bases: Module
Transformer Decoder Block composed of Transformer Decoder Layers.
Source code in dreem/models/transformer.py
class TransformerDecoder(nn.Module):
"""Transformer Decoder Block composed of Transformer Decoder Layers."""
def __init__(
self,
decoder_layer: TransformerDecoderLayer,
num_layers: int,
return_intermediate: bool = False,
norm: nn.Module = None,
) -> None:
"""Initialize transformer decoder block.
Args:
decoder_layer: An instance of TransformerDecoderLayer.
num_layers: The number of decoder layers to be stacked.
return_intermediate: Return intermediate layers from decoder.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.norm = norm if norm is not None else nn.Identity()
def forward(
self,
decoder_queries: torch.Tensor,
encoder_features: torch.Tensor,
ref_pos_emb: torch.Tensor = None,
query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
"""Execute a forward pass of the decoder block.
Args:
decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim).
encoder_features: Output from encoder, that decoder uses to attend to relevant
parts of input sequence (total_instances, batch_size, embed_dim)
ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim)
Returns:
The output tensor of shape (L, n_query, batch_size, embed_dim).
"""
decoder_features = decoder_queries
intermediate = []
for layer in self.layers:
decoder_features = layer(
decoder_features,
encoder_features,
ref_pos_emb=ref_pos_emb,
query_pos_emb=query_pos_emb,
)
if self.return_intermediate:
intermediate.append(self.norm(decoder_features))
decoder_features = self.norm(decoder_features)
if self.return_intermediate:
intermediate.pop()
intermediate.append(decoder_features)
return torch.stack(intermediate)
return decoder_features.unsqueeze(0)
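A minimal sketch wiring a decoder block to encoder output (shapes follow the conventions above; the sizes are illustrative):

```python
import torch

from dreem.models.transformer import TransformerDecoder, TransformerDecoderLayer

layer = TransformerDecoderLayer(d_model=256, nhead=8, norm=True)
decoder = TransformerDecoder(
    layer, num_layers=2, return_intermediate=True, norm=torch.nn.LayerNorm(256)
)

encoder_features = torch.randn(20, 1, 256)        # (total_instances, batch_size, embed_dim)
decoder_queries = torch.randn(5, 1, 256)          # (n_query, batch_size, embed_dim)
out = decoder(decoder_queries, encoder_features)  # (L=2, n_query, batch_size, embed_dim)
```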
__init__(decoder_layer, num_layers, return_intermediate=False, norm=None)
Initialize transformer decoder block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| decoder_layer | TransformerDecoderLayer | An instance of TransformerDecoderLayer. | required |
| num_layers | int | The number of decoder layers to be stacked. | required |
| return_intermediate | bool | Return intermediate layers from decoder. | False |
| norm | Module | The normalization layer to be applied. | None |
Source code in dreem/models/transformer.py
def __init__(
self,
decoder_layer: TransformerDecoderLayer,
num_layers: int,
return_intermediate: bool = False,
norm: nn.Module = None,
) -> None:
"""Initialize transformer decoder block.
Args:
decoder_layer: An instance of TransformerDecoderLayer.
num_layers: The number of decoder layers to be stacked.
return_intermediate: Return intermediate layers from decoder.
norm: The normalization layer to be applied.
"""
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.norm = norm if norm is not None else nn.Identity()
forward(decoder_queries, encoder_features, ref_pos_emb=None, query_pos_emb=None)
Execute a forward pass of the decoder block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| decoder_queries | Tensor | Query sequence for decoder to generate (n_query, batch_size, embed_dim). | required |
| encoder_features | Tensor | Output from encoder that the decoder uses to attend to relevant parts of the input sequence (total_instances, batch_size, embed_dim). | required |
| ref_pos_emb | Tensor | The input positional embedding tensor of shape (total_instances, batch_size, embed_dim). | None |
| query_pos_emb | Tensor | The query positional embedding of shape (n_query, batch_size, embed_dim). | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (L, n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
self,
decoder_queries: torch.Tensor,
encoder_features: torch.Tensor,
ref_pos_emb: torch.Tensor = None,
query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
"""Execute a forward pass of the decoder block.
Args:
decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim).
encoder_features: Output from encoder, that decoder uses to attend to relevant
parts of input sequence (total_instances, batch_size, embed_dim)
ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim)
Returns:
The output tensor of shape (L, n_query, batch_size, embed_dim).
"""
decoder_features = decoder_queries
intermediate = []
for layer in self.layers:
decoder_features = layer(
decoder_features,
encoder_features,
ref_pos_emb=ref_pos_emb,
query_pos_emb=query_pos_emb,
)
if self.return_intermediate:
intermediate.append(self.norm(decoder_features))
decoder_features = self.norm(decoder_features)
if self.return_intermediate:
intermediate.pop()
intermediate.append(decoder_features)
return torch.stack(intermediate)
return decoder_features.unsqueeze(0)
dreem.models.transformer.TransformerDecoderLayer
Bases: Module
A single transformer decoder layer.
Source code in dreem/models/transformer.py
class TransformerDecoderLayer(nn.Module):
    """A single transformer decoder layer."""

    def __init__(
        self,
        d_model: int = 1024,
        nhead: int = 6,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        norm: bool = False,
        decoder_self_attn: bool = False,
    ) -> None:
        """Initialize transformer decoder layer.

        Args:
            d_model: The number of features in the decoder inputs.
            nhead: The number of heads for the decoder.
            dim_feedforward: Dimension of the feedforward layers of decoder.
            dropout: Dropout value applied to the output of decoder.
            activation: Activation function to use.
            norm: If True, normalize output of decoder.
            decoder_self_attn: If True, use decoder self attention
        """
        super().__init__()
        self.decoder_self_attn = decoder_self_attn

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        if self.decoder_self_attn:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model) if norm else nn.Identity()

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(
        self,
        decoder_queries: torch.Tensor,
        encoder_features: torch.Tensor,
        ref_pos_emb: torch.Tensor = None,
        query_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Execute forward pass of decoder layer.

        Args:
            decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim).
            encoder_features: Output from the encoder that the decoder uses to attend to
                relevant parts of the input sequence (total_instances, batch_size, embed_dim).
            ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
            query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim).

        Returns:
            The output tensor of shape (n_query, batch_size, embed_dim).
        """
        if query_pos_emb is None:
            query_pos_emb = torch.zeros_like(decoder_queries)
        if ref_pos_emb is None:
            ref_pos_emb = torch.zeros_like(encoder_features)

        decoder_queries = decoder_queries + query_pos_emb
        encoder_features = encoder_features + ref_pos_emb

        if self.decoder_self_attn:
            self_attn_features = self.self_attn(
                query=decoder_queries, key=decoder_queries, value=decoder_queries
            )[0]
            decoder_queries = decoder_queries + self.dropout1(self_attn_features)
            decoder_queries = self.norm1(decoder_queries)

        x_attn_features = self.multihead_attn(
            query=decoder_queries,  # (n_query, batch_size, embed_dim)
            key=encoder_features,  # (total_instances, batch_size, embed_dim)
            value=encoder_features,  # (total_instances, batch_size, embed_dim)
        )[0]  # (n_query, batch_size, embed_dim)

        decoder_queries = decoder_queries + self.dropout2(x_attn_features)  # (n_query, batch_size, embed_dim)
        decoder_queries = self.norm2(decoder_queries)  # (n_query, batch_size, embed_dim)

        projection = self.linear2(
            self.dropout(self.activation(self.linear1(decoder_queries)))
        )  # (n_query, batch_size, embed_dim)

        decoder_queries = decoder_queries + self.dropout3(projection)  # (n_query, batch_size, embed_dim)
        decoder_features = self.norm3(decoder_queries)

        return decoder_features  # (n_query, batch_size, embed_dim)
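A minimal usage sketch (all sizes below are arbitrary assumptions; note that `d_model` must be divisible by `nhead` for the underlying `nn.MultiheadAttention`):

```python
import torch
from dreem.models.transformer import TransformerDecoderLayer

# Arbitrary example sizes: 5 query instances attending over 12 reference instances.
n_query, total_instances, batch_size, d_model = 5, 12, 2, 256

layer = TransformerDecoderLayer(
    d_model=d_model, nhead=8, dim_feedforward=512, norm=True, decoder_self_attn=True
)

decoder_queries = torch.randn(n_query, batch_size, d_model)
encoder_features = torch.randn(total_instances, batch_size, d_model)

out = layer(decoder_queries, encoder_features)
print(out.shape)  # torch.Size([5, 2, 256]) -> (n_query, batch_size, d_model)
```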
__init__(d_model=1024, nhead=6, dim_feedforward=1024, dropout=0.1, activation='relu', norm=False, decoder_self_attn=False)
¶
Initialize transformer decoder layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | The number of features in the decoder inputs. | 1024 |
| nhead | int | The number of heads for the decoder. | 6 |
| dim_feedforward | int | Dimension of the feedforward layers of decoder. | 1024 |
| dropout | float | Dropout value applied to the output of decoder. | 0.1 |
| activation | str | Activation function to use. | 'relu' |
| norm | bool | If True, normalize output of decoder. | False |
| decoder_self_attn | bool | If True, use decoder self attention. | False |
Source code in dreem/models/transformer.py
def __init__(
    self,
    d_model: int = 1024,
    nhead: int = 6,
    dim_feedforward: int = 1024,
    dropout: float = 0.1,
    activation: str = "relu",
    norm: bool = False,
    decoder_self_attn: bool = False,
) -> None:
    """Initialize transformer decoder layer.

    Args:
        d_model: The number of features in the decoder inputs.
        nhead: The number of heads for the decoder.
        dim_feedforward: Dimension of the feedforward layers of decoder.
        dropout: Dropout value applied to the output of decoder.
        activation: Activation function to use.
        norm: If True, normalize output of decoder.
        decoder_self_attn: If True, use decoder self attention
    """
    super().__init__()
    self.decoder_self_attn = decoder_self_attn

    self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    self.linear1 = nn.Linear(d_model, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, d_model)

    if self.decoder_self_attn:
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
    self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
    self.norm3 = nn.LayerNorm(d_model) if norm else nn.Identity()

    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

    self.activation = _get_activation_fn(activation)
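To see what the `norm` and `decoder_self_attn` flags change structurally, it can help to inspect two instances side by side (the sizes below are arbitrary):

```python
from dreem.models.transformer import TransformerDecoderLayer

plain = TransformerDecoderLayer(d_model=128, nhead=4)  # norm=False, no self-attention
normed = TransformerDecoderLayer(d_model=128, nhead=4, norm=True, decoder_self_attn=True)

print(type(plain.norm1).__name__)    # Identity  (norm=False keeps pass-through norms)
print(type(normed.norm1).__name__)   # LayerNorm
print(hasattr(plain, "self_attn"))   # False     (self-attention only built when requested)
print(hasattr(normed, "self_attn"))  # True
```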
forward(decoder_queries, encoder_features, ref_pos_emb=None, query_pos_emb=None)
¶
Execute forward pass of decoder layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| decoder_queries | Tensor | Target sequence for decoder to generate (n_query, batch_size, embed_dim). | required |
| encoder_features | Tensor | Output from the encoder that the decoder uses to attend to relevant parts of the input sequence (total_instances, batch_size, embed_dim). | required |
| ref_pos_emb | Tensor | The input positional embedding tensor of shape (total_instances, batch_size, embed_dim). | None |
| query_pos_emb | Tensor | The query positional embedding of shape (n_query, batch_size, embed_dim). | None |
Returns:
| Type | Description |
|---|---|
| Tensor | The output tensor of shape (n_query, batch_size, embed_dim). |
Source code in dreem/models/transformer.py
def forward(
    self,
    decoder_queries: torch.Tensor,
    encoder_features: torch.Tensor,
    ref_pos_emb: torch.Tensor = None,
    query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
    """Execute forward pass of decoder layer.

    Args:
        decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim).
        encoder_features: Output from the encoder that the decoder uses to attend to
            relevant parts of the input sequence (total_instances, batch_size, embed_dim).
        ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
        query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim).

    Returns:
        The output tensor of shape (n_query, batch_size, embed_dim).
    """
    if query_pos_emb is None:
        query_pos_emb = torch.zeros_like(decoder_queries)
    if ref_pos_emb is None:
        ref_pos_emb = torch.zeros_like(encoder_features)

    decoder_queries = decoder_queries + query_pos_emb
    encoder_features = encoder_features + ref_pos_emb

    if self.decoder_self_attn:
        self_attn_features = self.self_attn(
            query=decoder_queries, key=decoder_queries, value=decoder_queries
        )[0]
        decoder_queries = decoder_queries + self.dropout1(self_attn_features)
        decoder_queries = self.norm1(decoder_queries)

    x_attn_features = self.multihead_attn(
        query=decoder_queries,  # (n_query, batch_size, embed_dim)
        key=encoder_features,  # (total_instances, batch_size, embed_dim)
        value=encoder_features,  # (total_instances, batch_size, embed_dim)
    )[0]  # (n_query, batch_size, embed_dim)

    decoder_queries = decoder_queries + self.dropout2(x_attn_features)  # (n_query, batch_size, embed_dim)
    decoder_queries = self.norm2(decoder_queries)  # (n_query, batch_size, embed_dim)

    projection = self.linear2(
        self.dropout(self.activation(self.linear1(decoder_queries)))
    )  # (n_query, batch_size, embed_dim)

    decoder_queries = decoder_queries + self.dropout3(projection)  # (n_query, batch_size, embed_dim)
    decoder_features = self.norm3(decoder_queries)

    return decoder_features  # (n_query, batch_size, embed_dim)
dreem.models.attention_head.ATTWeightHead
¶
Bases: Module
Single attention head.
Source code in dreem/models/attention_head.py
class ATTWeightHead(torch.nn.Module):
    """Single attention head."""

    def __init__(
        self,
        feature_dim: int,
        num_layers: int,
        dropout: float,
    ):
        """Initialize an instance of ATTWeightHead.

        Args:
            feature_dim: The dimensionality of input features.
            num_layers: The number of hidden layers in the MLP.
            dropout: Dropout probability.
        """
        super().__init__()
        self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
        self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the attention weights of a query tensor using the key tensor.

        Args:
            query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
            key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

        Returns:
            Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
        """
        k = self.k_proj(key)
        q = self.q_proj(query)
        attn_weights = torch.bmm(q, k.transpose(1, 2))

        return attn_weights  # (B, N_t, N)
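A brief usage sketch (shapes and sizes are illustrative assumptions): the head projects queries and keys through separate MLPs and returns raw, unnormalized association scores.

```python
import torch
from dreem.models.attention_head import ATTWeightHead

# Illustrative sizes: 4 instances in the current frame scored against 16 instances in the window.
batch_size, num_frame_instances, num_window_instances, feature_dim = 2, 4, 16, 256

head = ATTWeightHead(feature_dim=feature_dim, num_layers=2, dropout=0.1)

query = torch.randn(batch_size, num_frame_instances, feature_dim)
key = torch.randn(batch_size, num_window_instances, feature_dim)

scores = head(query, key)
print(scores.shape)  # torch.Size([2, 4, 16]) -> (batch_size, num_frame_instances, num_window_instances)
```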
__init__(feature_dim, num_layers, dropout)
¶
Initialize an instance of ATTWeightHead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int | The dimensionality of input features. | required |
| num_layers | int | The number of hidden layers in the MLP. | required |
| dropout | float | Dropout probability. | required |
Source code in dreem/models/attention_head.py
def __init__(
    self,
    feature_dim: int,
    num_layers: int,
    dropout: float,
):
    """Initialize an instance of ATTWeightHead.

    Args:
        feature_dim: The dimensionality of input features.
        num_layers: The number of hidden layers in the MLP.
        dropout: Dropout probability.
    """
    super().__init__()
    self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
    self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
forward(query, key)
¶
Compute the attention weights of a query tensor using the key tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query | Tensor | Input tensor of shape (batch_size, num_frame_instances, feature_dim). | required |
| key | Tensor | Input tensor of shape (batch_size, num_window_instances, feature_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch_size, num_frame_instances, num_window_instances). |
Source code in dreem/models/attention_head.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
) -> torch.Tensor:
    """Compute the attention weights of a query tensor using the key tensor.

    Args:
        query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
        key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

    Returns:
        Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
    """
    k = self.k_proj(key)
    q = self.q_proj(query)
    attn_weights = torch.bmm(q, k.transpose(1, 2))

    return attn_weights  # (B, N_t, N)
dreem.models.mlp.MLP
¶
Bases: Module
Multi-Layer Perceptron (MLP) module.
Source code in dreem/models/mlp.py
class MLP(torch.nn.Module):
    """Multi-Layer Perceptron (MLP) module."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
    ):
        """Initialize MLP.

        Args:
            input_dim: Dimensionality of the input features.
            hidden_dim: Number of units in the hidden layers.
            output_dim: Dimensionality of the output features.
            num_layers: Number of hidden layers.
            dropout: Dropout probability.
        """
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout

        if self.num_layers > 0:
            h = [hidden_dim] * (num_layers - 1)
            self.layers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(n, k)
                    for n, k in zip([input_dim] + h, h + [output_dim])
                ]
            )
            if self.dropout > 0.0:
                self.dropouts = torch.nn.ModuleList(
                    [torch.nn.Dropout(dropout) for _ in range(self.num_layers - 1)]
                )
        else:
            self.layers = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP.

        Args:
            x: Input tensor of shape (batch_size, num_instances, input_dim).

        Returns:
            Output tensor of shape (batch_size, num_instances, output_dim).
        """
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
            if i < self.num_layers - 1 and self.dropout > 0.0:
                x = self.dropouts[i](x)
        return x
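A small usage sketch (dimensions are arbitrary assumptions). Note that with num_layers=0 the module holds no layers, so forward returns its input unchanged.

```python
import torch
from dreem.models.mlp import MLP

mlp = MLP(input_dim=64, hidden_dim=128, output_dim=32, num_layers=3, dropout=0.1)

x = torch.randn(2, 10, 64)          # (batch_size, num_instances, input_dim)
y = mlp(x)
print(y.shape)                      # torch.Size([2, 10, 32]) -> (batch_size, num_instances, output_dim)

identity = MLP(64, 128, 32, num_layers=0)
print(torch.equal(identity(x), x))  # True: no layers, input passes through untouched
```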
__init__(input_dim, hidden_dim, output_dim, num_layers, dropout=0.0)
¶
Initialize MLP.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Dimensionality of the input features. | required |
| hidden_dim | int | Number of units in the hidden layers. | required |
| output_dim | int | Dimensionality of the output features. | required |
| num_layers | int | Number of hidden layers. | required |
| dropout | float | Dropout probability. | 0.0 |
Source code in dreem/models/mlp.py
def __init__(
    self,
    input_dim: int,
    hidden_dim: int,
    output_dim: int,
    num_layers: int,
    dropout: float = 0.0,
):
    """Initialize MLP.

    Args:
        input_dim: Dimensionality of the input features.
        hidden_dim: Number of units in the hidden layers.
        output_dim: Dimensionality of the output features.
        num_layers: Number of hidden layers.
        dropout: Dropout probability.
    """
    super().__init__()
    self.num_layers = num_layers
    self.dropout = dropout

    if self.num_layers > 0:
        h = [hidden_dim] * (num_layers - 1)
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(n, k)
                for n, k in zip([input_dim] + h, h + [output_dim])
            ]
        )
        if self.dropout > 0.0:
            self.dropouts = torch.nn.ModuleList(
                [torch.nn.Dropout(dropout) for _ in range(self.num_layers - 1)]
            )
    else:
        self.layers = []
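To make the layer-sizing logic concrete: the constructor zips [input_dim] + h against h + [output_dim], so num_layers counts the Linear layers in total, with ReLU (and dropout, if enabled) applied after every layer except the last. A quick check with arbitrary sizes:

```python
from dreem.models.mlp import MLP

mlp = MLP(input_dim=64, hidden_dim=128, output_dim=32, num_layers=3)

# zip([64, 128, 128], [128, 128, 32]) -> Linear(64, 128), Linear(128, 128), Linear(128, 32)
for layer in mlp.layers:
    print(layer.in_features, "->", layer.out_features)
# 64 -> 128
# 128 -> 128
# 128 -> 32
```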
forward(x)
¶
Forward pass of the MLP.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, num_instances, input_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch_size, num_instances, output_dim). |
Source code in dreem/models/mlp.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass of the MLP.

    Args:
        x: Input tensor of shape (batch_size, num_instances, input_dim).

    Returns:
        Output tensor of shape (batch_size, num_instances, output_dim).
    """
    for i, layer in enumerate(self.layers):
        x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if i < self.num_layers - 1 and self.dropout > 0.0:
            x = self.dropouts[i](x)
    return x