GlobalTrackingTransformer
¶
dreem.models.GlobalTrackingTransformer
¶
Bases: Module
Modular GTR model composed of visual encoder + transformer used for tracking.
Source code in dreem/models/global_tracking_transformer.py
class GlobalTrackingTransformer(torch.nn.Module):
    """Modular GTR model composed of visual encoder + transformer used for tracking."""

    def __init__(
        self,
        encoder_cfg: dict = None,
        d_model: int = 1024,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dropout: float = 0.1,
        activation: str = "relu",
        return_intermediate_dec: bool = False,
        norm: bool = False,
        num_layers_attn_head: int = 2,
        dropout_attn_head: float = 0.1,
        embedding_meta: dict = None,
        return_embedding: bool = False,
        decoder_self_attn: bool = False,
    ):
        """Initialize GTR.

        Args:
            encoder_cfg: Dictionary of keyword arguments passed (together with
                `d_model`) to the `VisualEncoder` constructor, e.g.
                `cfg = {"model_name": "resnet18", "pretrained": False, "in_chans": 3}`.
                If None, the encoder is built from `d_model` alone.
            d_model: The number of features in the encoder/decoder inputs.
            nhead: The number of heads in the transformer encoder/decoder.
            num_encoder_layers: The number of encoder-layers in the encoder.
            num_decoder_layers: The number of decoder-layers in the decoder.
            dropout: Dropout value applied to the output of transformer layers.
            activation: Activation function to use.
            return_intermediate_dec: Return intermediate layers from decoder.
            norm: If True, normalize output of encoder and decoder.
            num_layers_attn_head: The number of layers in the attention head.
            dropout_attn_head: Dropout value for the attention_head.
            embedding_meta: Metadata for positional embeddings. See below.
            return_embedding: Whether to return the positional embeddings.
            decoder_self_attn: If True, use decoder self attention.

        More details on `embedding_meta`:
            By default this is None, which indicates that no positional
            embeddings should be used. To use positional embeddings, pass a
            dictionary containing a "pos" and "temp" key with sub-dictionaries
            of parameters, e.g.:
            `{"pos": {"mode": "learned", "emb_num": 16, "over_boxes": True},
            "temp": {"mode": "learned", "emb_num": 16}}`
            (see `dreem.models.embeddings.Embedding.EMB_TYPES` and
            `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters).
        """
        super().__init__()

        if encoder_cfg is not None:
            self.visual_encoder = VisualEncoder(d_model=d_model, **encoder_cfg)
        else:
            self.visual_encoder = VisualEncoder(d_model=d_model)

        self.transformer = Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            activation=activation,
            return_intermediate_dec=return_intermediate_dec,
            norm=norm,
            num_layers_attn_head=num_layers_attn_head,
            dropout_attn_head=dropout_attn_head,
            embedding_meta=embedding_meta,
            return_embedding=return_embedding,
            decoder_self_attn=decoder_self_attn,
        )

    def forward(
        self, ref_instances: list["Instance"], query_instances: list["Instance"] = None
    ) -> list["AssociationMatrix"]:
        """Execute forward pass of GTR Model to get asso matrix.

        Args:
            ref_instances: List of instances from chunk containing crops of objects + gt label info.
            query_instances: List of instances used as query in decoder (optional).

        Returns:
            The association predictions produced by `self.transformer`
            (annotated as a list of `AssociationMatrix`).
        """
        # Extract feature representations with the visual encoder; features are
        # stored on the instances themselves before the transformer reads them.
        self.extract_features(ref_instances)
        if query_instances:
            self.extract_features(query_instances)

        return self.transformer(ref_instances, query_instances)

    def extract_features(
        self, instances: list["Instance"], force_recompute: bool = False
    ) -> None:
        """Extract features from instances using visual encoder backbone.

        Features are written in place to each instance's `features` attribute.
        Instances without a crop are never encoded.

        Args:
            instances: A list of instances to compute features for.
            force_recompute: If True, recompute features for every instance that
                has a crop, even if it already has features. Otherwise only
                instances that have a crop but no features are computed.
        """
        # Only instances with a crop can be fed to the encoder; this list is also
        # the fallback pool below, so crop-less instances never reach torch.cat.
        with_crops = [instance for instance in instances if instance.has_crop()]

        if force_recompute:
            instances_to_compute = with_crops
        else:
            instances_to_compute = [
                instance for instance in with_crops if not instance.has_features()
            ]

        if not instances_to_compute:
            return
        if len(instances_to_compute) == 1:
            # Handle batch norm error when B=1: encode every cropped instance
            # together (this also recomputes features some instances already had).
            instances_to_compute = with_crops

        crops = torch.cat([instance.crop for instance in instances_to_compute])
        features = self.visual_encoder(crops)
        for instance, feature in zip(instances_to_compute, features):
            instance.features = feature
__init__(encoder_cfg=None, d_model=1024, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dropout=0.1, activation='relu', return_intermediate_dec=False, norm=False, num_layers_attn_head=2, dropout_attn_head=0.1, embedding_meta=None, return_embedding=False, decoder_self_attn=False)
¶
Initialize GTR.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
encoder_cfg |
dict
|
Dictionary of arguments to pass to the CNN constructor,
e.g: `{"model_name": "resnet18", "pretrained": False, "in_chans": 3}` |
None
|
d_model |
int
|
The number of features in the encoder/decoder inputs. |
1024
|
nhead |
int
|
The number of heads in the transformer encoder/decoder. |
8
|
num_encoder_layers |
int
|
The number of encoder-layers in the encoder. |
6
|
num_decoder_layers |
int
|
The number of decoder-layers in the decoder. |
6
|
dropout |
float
|
Dropout value applied to the output of transformer layers. |
0.1
|
activation |
str
|
Activation function to use. |
'relu'
|
return_intermediate_dec |
bool
|
Return intermediate layers from decoder. |
False
|
norm |
bool
|
If True, normalize output of encoder and decoder. |
False
|
num_layers_attn_head |
int
|
The number of layers in the attention head. |
2
|
dropout_attn_head |
float
|
Dropout value for the attention_head. |
0.1
|
embedding_meta |
dict
|
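Metadata for positional embeddings. None (the default) means no positional embeddings are used; otherwise pass a dictionary containing a "pos" and "temp" key with sub-dictionaries of parameters, e.g. `{"pos": {"mode": "learned", "emb_num": 16, "over_boxes": True}, "temp": {"mode": "learned", "emb_num": 16}}` (see `dreem.models.embeddings.Embedding.EMB_TYPES` and `dreem.models.embeddings.Embedding.EMB_MODES`). |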
None
|
return_embedding |
bool
|
Whether to return the positional embeddings |
False
|
decoder_self_attn |
bool
|
If True, use decoder self attention. |
False
|
Source code in dreem/models/global_tracking_transformer.py
def __init__(
    self,
    encoder_cfg: dict = None,
    d_model: int = 1024,
    nhead: int = 8,
    num_encoder_layers: int = 6,
    num_decoder_layers: int = 6,
    dropout: float = 0.1,
    activation: str = "relu",
    return_intermediate_dec: bool = False,
    norm: bool = False,
    num_layers_attn_head: int = 2,
    dropout_attn_head: float = 0.1,
    embedding_meta: dict = None,
    return_embedding: bool = False,
    decoder_self_attn: bool = False,
):
    """Initialize GTR.

    Args:
        encoder_cfg: Dictionary of keyword arguments passed (together with
            `d_model`) to the `VisualEncoder` constructor, e.g.
            `cfg = {"model_name": "resnet18", "pretrained": False, "in_chans": 3}`.
            If None, the encoder is built from `d_model` alone.
        d_model: The number of features in the encoder/decoder inputs.
        nhead: The number of heads in the transformer encoder/decoder.
        num_encoder_layers: The number of encoder-layers in the encoder.
        num_decoder_layers: The number of decoder-layers in the decoder.
        dropout: Dropout value applied to the output of transformer layers.
        activation: Activation function to use.
        return_intermediate_dec: Return intermediate layers from decoder.
        norm: If True, normalize output of encoder and decoder.
        num_layers_attn_head: The number of layers in the attention head.
        dropout_attn_head: Dropout value for the attention_head.
        embedding_meta: Metadata for positional embeddings. See below.
        return_embedding: Whether to return the positional embeddings.
        decoder_self_attn: If True, use decoder self attention.

    More details on `embedding_meta`:
        By default this is None, which indicates that no positional
        embeddings should be used. To use positional embeddings, pass a
        dictionary containing a "pos" and "temp" key with sub-dictionaries
        of parameters, e.g.:
        `{"pos": {"mode": "learned", "emb_num": 16, "over_boxes": True},
        "temp": {"mode": "learned", "emb_num": 16}}`
        (see `dreem.models.embeddings.Embedding.EMB_TYPES` and
        `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters).
    """
    super().__init__()

    if encoder_cfg is not None:
        self.visual_encoder = VisualEncoder(d_model=d_model, **encoder_cfg)
    else:
        self.visual_encoder = VisualEncoder(d_model=d_model)

    self.transformer = Transformer(
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dropout=dropout,
        activation=activation,
        return_intermediate_dec=return_intermediate_dec,
        norm=norm,
        num_layers_attn_head=num_layers_attn_head,
        dropout_attn_head=dropout_attn_head,
        embedding_meta=embedding_meta,
        return_embedding=return_embedding,
        decoder_self_attn=decoder_self_attn,
    )
extract_features(instances, force_recompute=False)
¶
Extract features from instances using visual encoder backbone.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
instances |
list[Instance]
|
A list of instances to compute features for |
required |
force_recompute |
bool
|
If True, compute features for all instances (that have a crop) regardless of whether they already have features |
False
|
Source code in dreem/models/global_tracking_transformer.py
def extract_features(
    self, instances: list["Instance"], force_recompute: bool = False
) -> None:
    """Extract features from instances using visual encoder backbone.

    Features are written in place to each instance's `features` attribute.
    Instances without a crop are never encoded.

    Args:
        instances: A list of instances to compute features for.
        force_recompute: If True, recompute features for every instance that
            has a crop, even if it already has features. Otherwise only
            instances that have a crop but no features are computed.
    """
    # Only instances with a crop can be fed to the encoder; this list is also
    # the fallback pool below, so crop-less instances never reach torch.cat.
    with_crops = [instance for instance in instances if instance.has_crop()]

    if force_recompute:
        instances_to_compute = with_crops
    else:
        instances_to_compute = [
            instance for instance in with_crops if not instance.has_features()
        ]

    if not instances_to_compute:
        return
    if len(instances_to_compute) == 1:
        # Handle batch norm error when B=1: encode every cropped instance
        # together (this also recomputes features some instances already had).
        instances_to_compute = with_crops

    crops = torch.cat([instance.crop for instance in instances_to_compute])
    features = self.visual_encoder(crops)
    for instance, feature in zip(instances_to_compute, features):
        instance.features = feature
forward(ref_instances, query_instances=None)
¶
Execute forward pass of GTR Model to get asso matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ref_instances |
list[Instance]
|
List of instances from chunk containing crops of objects + gt label info |
required |
query_instances |
list[Instance]
|
list of instances used as query in decoder. |
None
|
Returns:
Type | Description |
---|---|
list[AssociationMatrix]
|
An N_T x N association matrix |
Source code in dreem/models/global_tracking_transformer.py
def forward(
    self, ref_instances: list["Instance"], query_instances: list["Instance"] = None
) -> list["AssociationMatrix"]:
    """Run the GTR model on reference (and optional query) instances.

    Args:
        ref_instances: Instances from the chunk, carrying object crops and
            ground-truth label info.
        query_instances: Optional instances used as queries in the decoder.
            Features are only extracted for them when the list is non-empty.

    Returns:
        Whatever `self.transformer` returns for the given instances
        (annotated as a list of `AssociationMatrix`).
    """
    # Reference instances are always encoded; query instances only if provided.
    batches = [ref_instances] + ([query_instances] if query_instances else [])
    for batch in batches:
        self.extract_features(batch)

    return self.transformer(ref_instances, query_instances)