embedding
dreem.models.embedding
¶
Module containing different position and temporal embeddings.
Embedding
¶
Bases: Module
Class that wraps around different embedding types.
Used for both learned and fixed embeddings.
Source code in dreem/models/embedding.py
class Embedding(torch.nn.Module):
"""Class that wraps around different embedding types.
Used for both learned and fixed embeddings.
"""
EMB_TYPES = {
"temp": {},
"pos": {"over_boxes"},
"off": {},
None: {},
} # dict of valid args:keyword params
EMB_MODES = {
"fixed": {"temperature", "scale", "normalize"},
"learned": {"emb_num"},
"off": {},
} # dict of valid args:keyword params
def __init__(
self,
emb_type: str,
mode: str,
features: int,
n_points: Optional[int] = 1,
emb_num: Optional[int] = 16,
over_boxes: Optional[bool] = True,
temperature: Optional[int] = 10000,
normalize: Optional[bool] = False,
scale: Optional[float] = None,
mlp_cfg: dict = None,
):
"""Initialize embeddings.
Args:
emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
mode: The mode or function used to map positions to vector embeddings.
Must be one of `{"fixed", "learned", "off"}`
features: The embedding dimensions. Must match the dimension of the
input vectors for the transformer model.
n_points: the number of points that will be embedded.
emb_num: the number of embeddings in the `self.lookup` table (Only used in learned embeddings).
over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh).
temperature: the temperature constant to be used when computing the sinusoidal position embedding
normalize: whether or not to normalize the positions (Only used in fixed embeddings).
scale: factor by which to scale the positions after normalizing (Only used in fixed embeddings).
mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space.
Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
"""
self._check_init_args(emb_type, mode)
super().__init__()
self.emb_type = emb_type
self.mode = mode
self.features = features
self.emb_num = emb_num
self.over_boxes = over_boxes
self.temperature = temperature
self.normalize = normalize
self.scale = scale
self.n_points = n_points
if self.normalize and self.scale is None:
self.scale = 2 * math.pi
if self.emb_type == "pos" and mlp_cfg is not None and mlp_cfg["num_layers"] > 0:
if self.mode == "fixed":
self.mlp = MLP(
input_dim=n_points * self.features,
output_dim=self.features,
**mlp_cfg,
)
else:
in_dim = (self.features // (4 * n_points)) * (4 * n_points)
self.mlp = MLP(
input_dim=in_dim,
output_dim=self.features,
**mlp_cfg,
)
else:
self.mlp = torch.nn.Identity()
self._emb_func = lambda tensor: torch.zeros(
(tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device
) # turn off embedding by returning zeros
self.lookup = None
if self.mode == "learned":
if self.emb_type == "pos":
self.lookup = torch.nn.Embedding(
self.emb_num * 4 * self.n_points, self.features // (4 * n_points)
)
self._emb_func = self._learned_pos_embedding
elif self.emb_type == "temp":
self.lookup = torch.nn.Embedding(self.emb_num, self.features)
self._emb_func = self._learned_temp_embedding
elif self.mode == "fixed":
if self.emb_type == "pos":
self._emb_func = self._sine_box_embedding
elif self.emb_type == "temp":
self._emb_func = self._sine_temp_embedding
def _check_init_args(self, emb_type: str, mode: str):
"""Check whether the correct arguments were passed to initialization.
Args:
emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", ""}`
mode: The mode or function used to map positions to vector embeddings.
Must be one of `{"fixed", "learned"}`
Raises:
ValueError:
* if the incorrect `emb_type` or `mode` string are passed
NotImplementedError: if `emb_type` is `temp` and `mode` is `fixed`.
"""
if emb_type.lower() not in self.EMB_TYPES:
raise ValueError(
f"Embedding `emb_type` must be one of {self.EMB_TYPES} not {emb_type}"
)
if mode.lower() not in self.EMB_MODES:
raise ValueError(
f"Embedding `mode` must be one of {self.EMB_MODES} not {mode}"
)
def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
"""Get the sequence positional embeddings.
Args:
seq_positions:
* An (`N`, 1) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence.
* An (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence.
Returns:
An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
"""
emb = self._emb_func(seq_positions)
if emb.shape[-1] != self.features:
raise RuntimeError(
(
f"Output embedding dimension is {emb.shape[-1]} but requested {self.features} dimensions! \n"
f"hint: Try turning the MLP on by passing `mlp_cfg` to the constructor to project to the correct embedding dimensions."
)
)
return emb
def _torch_int_div(
self, tensor1: torch.Tensor, tensor2: torch.Tensor
) -> torch.Tensor:
"""Perform integer division of two tensors.
Args:
tensor1: dividend tensor.
tensor2: divisor tensor.
Returns:
torch.Tensor, resulting tensor.
"""
return torch.div(tensor1, tensor2, rounding_mode="floor")
def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
"""Compute sine positional embeddings for boxes using given parameters.
Args:
boxes: the input boxes of shape N, n_anchors, 4 or B, N, n_anchors, 4
where the last dimension is the bbox coords in [y1, x1, y2, x2].
(Note currently `B=batch_size=1`).
Returns:
torch.Tensor, the sine positional embeddings
(embedding[:, 4i] = sin(x)
embedding[:, 4i+1] = cos(x)
embedding[:, 4i+2] = sin(y)
embedding[:, 4i+3] = cos(y)
)
"""
if self.scale is not None and self.normalize is False:
raise ValueError("normalize should be True if scale is passed")
if len(boxes.size()) == 3:
boxes = boxes.unsqueeze(0)
if self.normalize:
boxes = boxes / (boxes[:, :, -1:] + 1e-6) * self.scale
dim_t = torch.arange(self.features // 4, dtype=torch.float32)
dim_t = self.temperature ** (
2 * self._torch_int_div(dim_t, 2) / (self.features // 4)
)
# (b, n_t, n_anchors, 4, D//4)
pos_emb = boxes[:, :, :, :, None] / dim_t.to(boxes.device)
pos_emb = torch.stack(
(pos_emb[:, :, :, :, 0::2].sin(), pos_emb[:, :, :, :, 1::2].cos()), dim=4
)
pos_emb = pos_emb.flatten(2).squeeze(0) # (N_t, n_anchors * D)
pos_emb = self.mlp(pos_emb)
pos_emb = pos_emb.view(boxes.shape[1], self.features)
return pos_emb
def _sine_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
"""Compute fixed sine temporal embeddings.
Args:
times: the input times of shape (N,) or (N,1) where N = (sum(instances_per_frame))
which is the frame index of the instance relative
to the batch size
(e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2,..., B, B, ...B])`).
Returns:
an n_instances x D embedding representing the temporal embedding.
"""
T = times.int().max().item() + 1
d = self.features
n = self.temperature
positions = torch.arange(0, T).unsqueeze(1)
temp_lookup = torch.zeros(T, d, device=times.device)
denominators = torch.pow(
n, 2 * torch.arange(0, d // 2) / d
) # 10000^(2i/d_model), i is the index of embedding
temp_lookup[:, 0::2] = torch.sin(
positions / denominators
) # sin(pos/10000^(2i/d_model))
temp_lookup[:, 1::2] = torch.cos(
positions / denominators
) # cos(pos/10000^(2i/d_model))
temp_emb = temp_lookup[times.int()]
return temp_emb # .view(len(times), self.features)
def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
"""Compute learned positional embeddings for boxes using given parameters.
Args:
boxes: the input boxes of shape N x 4 or B x N x 4
where the last dimension is the bbox coords in [y1, x1, y2, x2].
(Note currently `B=batch_size=1`).
Returns:
torch.Tensor, the learned positional embeddings.
"""
pos_lookup = self.lookup
N, n_anchors, _ = boxes.shape
boxes = boxes.view(N, n_anchors, 4)
if self.over_boxes:
xywh = boxes
else:
xywh = torch.cat(
[
(boxes[:, :, 2:] + boxes[:, :, :2]) / 2,
(boxes[:, :, 2:] - boxes[:, :, :2]),
],
dim=1,
)
left_ind, right_ind, left_weight, right_weight = self._compute_weights(xywh)
f = pos_lookup.weight.shape[1] # self.features // 4
try:
pos_emb_table = pos_lookup.weight.view(
self.emb_num, n_anchors, 4, f
) # T x 4 x (D * 4)
except RuntimeError as e:
print(f"Hint: `n_points` ({self.n_points}) may be set incorrectly!")
raise (e)
left_emb = pos_emb_table.gather(
0,
left_ind[:, :, :, None].to(pos_emb_table.device).expand(N, n_anchors, 4, f),
) # N x 4 x d
right_emb = pos_emb_table.gather(
0,
right_ind[:, :, :, None]
.to(pos_emb_table.device)
.expand(N, n_anchors, 4, f),
) # N x 4 x d
pos_emb = left_weight[:, :, :, None] * right_emb.to(
left_weight.device
) + right_weight[:, :, :, None] * left_emb.to(right_weight.device)
pos_emb = pos_emb.flatten(1)
pos_emb = self.mlp(pos_emb)
return pos_emb.view(N, self.features)
def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
"""Compute learned temporal embeddings for times using given parameters.
Args:
times: the input times of shape (N,) or (N,1) where N = (sum(instances_per_frame))
which is the frame index of the instance relative
to the batch size
(e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2,..., B, B, ...B])`).
Returns:
torch.Tensor, the learned temporal embeddings.
"""
temp_lookup = self.lookup
N = times.shape[0]
left_ind, right_ind, left_weight, right_weight = self._compute_weights(times)
left_emb = temp_lookup.weight[
left_ind.to(temp_lookup.weight.device)
] # T x D --> N x D
right_emb = temp_lookup.weight[right_ind.to(temp_lookup.weight.device)]
temp_emb = left_weight[:, None] * right_emb.to(
left_weight.device
) + right_weight[:, None] * left_emb.to(right_weight.device)
return temp_emb.view(N, self.features)
def _compute_weights(self, data: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Compute left and right learned embedding weights.
Args:
data: the input data (e.g boxes or times).
Returns:
A torch.Tensor for each of the left/right indices and weights, respectively
"""
data = data * self.emb_num
left_ind = data.clamp(min=0, max=self.emb_num - 1).long() # N x 4
right_ind = (left_ind + 1).clamp(min=0, max=self.emb_num - 1).long() # N x 4
left_weight = data - left_ind.float() # N x 4
right_weight = 1.0 - left_weight
return left_ind, right_ind, left_weight, right_weight
__init__(emb_type, mode, features, n_points=1, emb_num=16, over_boxes=True, temperature=10000, normalize=False, scale=None, mlp_cfg=None)
¶
Initialize embeddings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
emb_type |
str
|
The type of embedding to compute. Must be one of |
required |
mode |
str
|
The mode or function used to map positions to vector embeddings.
Must be one of |
required |
features |
int
|
The embedding dimensions. Must match the dimension of the input vectors for the transformer model. |
required |
n_points |
Optional[int]
|
the number of points that will be embedded. |
1
|
emb_num |
Optional[int]
|
the number of embeddings in the |
16
|
over_boxes |
Optional[bool]
|
Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh). |
True
|
temperature |
Optional[int]
|
the temperature constant to be used when computing the sinusoidal position embedding |
10000
|
normalize |
Optional[bool]
|
whether or not to normalize the positions (Only used in fixed embeddings). |
False
|
scale |
Optional[float]
|
factor by which to scale the positions after normalizing (Only used in fixed embeddings). |
None
|
mlp_cfg |
dict
|
A dictionary of mlp hyperparameters for projecting embedding to correct space. Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3} |
None
|
Source code in dreem/models/embedding.py
def __init__(
self,
emb_type: str,
mode: str,
features: int,
n_points: Optional[int] = 1,
emb_num: Optional[int] = 16,
over_boxes: Optional[bool] = True,
temperature: Optional[int] = 10000,
normalize: Optional[bool] = False,
scale: Optional[float] = None,
mlp_cfg: dict = None,
):
"""Initialize embeddings.
Args:
emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
mode: The mode or function used to map positions to vector embeddings.
Must be one of `{"fixed", "learned", "off"}`
features: The embedding dimensions. Must match the dimension of the
input vectors for the transformer model.
n_points: the number of points that will be embedded.
emb_num: the number of embeddings in the `self.lookup` table (Only used in learned embeddings).
over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh).
temperature: the temperature constant to be used when computing the sinusoidal position embedding
normalize: whether or not to normalize the positions (Only used in fixed embeddings).
scale: factor by which to scale the positions after normalizing (Only used in fixed embeddings).
mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space.
Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
"""
self._check_init_args(emb_type, mode)
super().__init__()
self.emb_type = emb_type
self.mode = mode
self.features = features
self.emb_num = emb_num
self.over_boxes = over_boxes
self.temperature = temperature
self.normalize = normalize
self.scale = scale
self.n_points = n_points
if self.normalize and self.scale is None:
self.scale = 2 * math.pi
if self.emb_type == "pos" and mlp_cfg is not None and mlp_cfg["num_layers"] > 0:
if self.mode == "fixed":
self.mlp = MLP(
input_dim=n_points * self.features,
output_dim=self.features,
**mlp_cfg,
)
else:
in_dim = (self.features // (4 * n_points)) * (4 * n_points)
self.mlp = MLP(
input_dim=in_dim,
output_dim=self.features,
**mlp_cfg,
)
else:
self.mlp = torch.nn.Identity()
self._emb_func = lambda tensor: torch.zeros(
(tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device
) # turn off embedding by returning zeros
self.lookup = None
if self.mode == "learned":
if self.emb_type == "pos":
self.lookup = torch.nn.Embedding(
self.emb_num * 4 * self.n_points, self.features // (4 * n_points)
)
self._emb_func = self._learned_pos_embedding
elif self.emb_type == "temp":
self.lookup = torch.nn.Embedding(self.emb_num, self.features)
self._emb_func = self._learned_temp_embedding
elif self.mode == "fixed":
if self.emb_type == "pos":
self._emb_func = self._sine_box_embedding
elif self.emb_type == "temp":
self._emb_func = self._sine_temp_embedding
forward(seq_positions)
¶
Get the sequence positional embeddings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seq_positions |
Tensor
|
|
required |
Returns:
Type | Description |
---|---|
Tensor
|
An |
Source code in dreem/models/embedding.py
def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
"""Get the sequence positional embeddings.
Args:
seq_positions:
* An (`N`, 1) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence.
* An (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence.
Returns:
An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
"""
emb = self._emb_func(seq_positions)
if emb.shape[-1] != self.features:
raise RuntimeError(
(
f"Output embedding dimension is {emb.shape[-1]} but requested {self.features} dimensions! \n"
f"hint: Try turning the MLP on by passing `mlp_cfg` to the constructor to project to the correct embedding dimensions."
)
)
return emb