instance
dreem.io.instance
¶
Module containing data class for storing detections.
Instance
¶
Class representing a single instance to be tracked.
Attributes:
Name | Type | Description |
---|---|---|
gt_track_id |
Tensor
|
Ground truth track id - only used for train/eval. |
pred_track_id |
Tensor
|
Predicted track id. Untracked instance is represented by -1. |
bbox |
Tensor
|
The bounding box coordinate of the instance. Defaults to an empty tensor. |
crop |
Tensor
|
The crop of the instance. |
centroid |
dict[str, ArrayLike]
|
The centroid around which the bbox was cropped. |
features |
Tensor
|
The reid features extracted from the CNN backbone used in the transformer. |
track_score |
float
|
The track score output from the association matrix. |
point_scores |
ArrayLike
|
The point scores from sleap. |
instance_score |
float
|
The instance score from sleap. |
skeleton |
Skeleton
|
The sleap skeleton used for the instance. |
pose |
dict[str, ArrayLike]
|
A dictionary containing the node name and corresponding point. |
device |
str
|
String representation of the device the instance should be on. |
Source code in dreem/io/instance.py
@attrs.define(eq=False)
class Instance:
    """Class representing a single instance to be tracked.

    Attributes:
        gt_track_id: Ground truth track id - only used for train/eval.
        pred_track_id: Predicted track id. Untracked instance is represented by -1.
        bbox: The bounding box coordinate of the instance. Defaults to an empty tensor.
        crop: The crop of the instance.
        centroid: The centroid around which the bbox was cropped.
        features: The reid features extracted from the CNN backbone used in the transformer.
        embeddings: A dict mapping an embedding type key to its embedding tensor
            (populated via `add_embedding`).
        track_score: The track score output from the association matrix.
        point_scores: The point scores from sleap (one per pose node).
        instance_score: The instance score from sleap.
        skeleton: The sleap skeleton used for the instance.
        pose: A dictionary containing the node name and corresponding point.
        device: String representation of the device the instance should be on.
    """

    _gt_track_id: int = attrs.field(
        alias="gt_track_id", default=-1, converter=_to_tensor
    )
    _pred_track_id: int = attrs.field(
        alias="pred_track_id", default=-1, converter=_to_tensor
    )
    _bbox: ArrayLike = attrs.field(alias="bbox", factory=list, converter=_to_tensor)
    _crop: ArrayLike = attrs.field(alias="crop", factory=list, converter=_to_tensor)
    _centroid: dict[str, ArrayLike] = attrs.field(alias="centroid", factory=dict)
    _features: ArrayLike = attrs.field(
        alias="features", factory=list, converter=_to_tensor
    )
    _embeddings: dict = attrs.field(alias="embeddings", factory=dict)
    _track_score: float = attrs.field(alias="track_score", default=-1.0)
    _instance_score: float = attrs.field(alias="instance_score", default=-1.0)
    _point_scores: ArrayLike = attrs.field(alias="point_scores", default=None)
    _skeleton: sio.Skeleton = attrs.field(alias="skeleton", default=None)
    _pose: dict[str, ArrayLike] = attrs.field(alias="pose", factory=dict)
    _device: str = attrs.field(alias="device", default=None)
    _frame: "Frame" = None

    def __attrs_post_init__(self) -> None:
        """Handle dimensionality and more intricate default initializations post-init."""
        # Bring the tensor fields up to the ranks the rest of the class indexes into.
        self.bbox = _expand_to_rank(self.bbox, 3)
        self.crop = _expand_to_rank(self.crop, 4)
        self.features = _expand_to_rank(self.features, 2)

        if self.skeleton is None:
            self.skeleton = sio.Skeleton(["centroid"])

        # Canonical "no bbox" representation; `has_bbox` checks dim 1.
        if self.bbox.shape[-1] == 0:
            self.bbox = torch.empty([1, 0, 4])

        # Without a crop, derive the centroid from the bbox midpoint (y1, x1, y2, x2 order).
        if self.crop.shape[-1] == 0 and self.bbox.shape[1] != 0:
            y1, x1, y2, x2 = self.bbox.squeeze(dim=0).nanmean(dim=0)
            self.centroid = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}

        # Without a pose, fall back to a single "centroid" node at the bbox midpoint.
        if len(self.pose) == 0 and self.bbox.shape[1]:
            y1, x1, y2, x2 = self.bbox.squeeze(dim=0).mean(dim=0)
            self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}

        if self.point_scores is None and len(self.pose) != 0:
            # One score per pose node, matching the `point_scores` contract
            # (a vector of shape n) expected by `to_slp`.
            self._point_scores = np.zeros(len(self.pose))

        self.to(self.device)

    def __repr__(self) -> str:
        """Return string representation of the Instance."""

        def _fmt_id(track_id: torch.Tensor):
            # `.item()` only works on single-element tensors; an unset id is empty.
            return track_id.item() if track_id.numel() == 1 else track_id.tolist()

        return (
            "Instance("
            f"gt_track_id={_fmt_id(self._gt_track_id)}, "
            f"pred_track_id={_fmt_id(self._pred_track_id)}, "
            f"bbox={self._bbox}, "
            f"centroid={self._centroid}, "
            f"crop={self._crop.shape}, "
            f"features={self._features.shape}, "
            f"device={self._device}"
            ")"
        )

    def to(self, map_location: Union[str, torch.device]) -> "Instance":
        """Move instance to different device or change dtype. (See `torch.to` for more info).

        Args:
            map_location: Either the device or dtype for the instance to be moved.
                `None` or an empty string is a no-op.

        Returns:
            self: reference to the instance moved to correct device/dtype.
        """
        if map_location is not None and map_location != "":
            self._gt_track_id = self._gt_track_id.to(map_location)
            self._pred_track_id = self._pred_track_id.to(map_location)
            self._bbox = self._bbox.to(map_location)
            self._crop = self._crop.to(map_location)
            self._features = self._features.to(map_location)
            # Only devices (not dtypes) update the recorded device.
            if isinstance(map_location, (str, torch.device)):
                self.device = map_location
        return self

    @classmethod
    def from_slp(
        cls,
        slp_instance: Union[sio.PredictedInstance, sio.Instance],
        bbox_size: Union[int, tuple] = 64,
        crop: ArrayLike = None,
        device: str = None,
    ) -> "Instance":
        """Convert a slp instance to a dreem instance.

        Args:
            slp_instance: A `sleap_io.Instance` object representing a detection.
            bbox_size: Offset added on each side of the pose centroid to form the
                bbox; an int is used for both axes.
            crop: The corresponding crop of the bbox.
            device: Which device to keep the instance on.

        Returns:
            A dreem.Instance object with a pose-centered bbox and the given crop (if any).
        """
        track = slp_instance.track
        if track is None:
            # Untracked detections keep the default "no identity" id.
            track_id = -1
        else:
            try:
                track_id = int(track.name)
            except ValueError:
                # Non-numeric track names are encoded as concatenated character codes.
                track_id = int("".join(str(ord(c)) for c in track.name))

        if isinstance(bbox_size, int):
            bbox_size = (bbox_size, bbox_size)

        track_score = -1.0
        point_scores = np.full(len(slp_instance.points), -1)
        instance_score = -1
        if isinstance(slp_instance, sio.PredictedInstance):
            track_score = slp_instance.tracking_score
            # `numpy()` only includes the score column when explicitly requested.
            point_scores = slp_instance.numpy(scores=True)[:, -1]
            instance_score = slp_instance.score

        # Average across nodes (axis 0) to obtain a single (x, y) centroid,
        # ignoring missing (NaN) nodes and any columns beyond x, y.
        centroid = np.nanmean(slp_instance.numpy()[:, :2], axis=0)
        # bbox is ordered (y1, x1, y2, x2).
        bbox = [
            centroid[1] - bbox_size[1],
            centroid[0] - bbox_size[0],
            centroid[1] + bbox_size[1],
            centroid[0] + bbox_size[0],
        ]
        return cls(
            gt_track_id=track_id,
            bbox=bbox,
            crop=crop,
            centroid={"centroid": centroid},
            track_score=track_score,
            point_scores=point_scores,
            instance_score=instance_score,
            skeleton=slp_instance.skeleton,
            pose={
                node.name: point.numpy() for node, point in slp_instance.points.items()
            },
            device=device,
        )

    def to_slp(
        self, track_lookup: Union[dict[int, sio.Track], None] = None
    ) -> tuple[sio.PredictedInstance, dict[int, sio.Track]]:
        """Convert instance to sleap_io.PredictedInstance object.

        Args:
            track_lookup: A track look up dictionary containing track_id:sio.Track.
                A fresh dict is created when omitted (never shared between calls).

        Returns: A sleap_io.PredictedInstance with necessary metadata
            and a track_lookup dictionary to persist tracks.

        Raises:
            RuntimeError: if the conversion fails; the original error is chained.
        """
        if track_lookup is None:
            track_lookup = {}
        try:
            track_id = self.pred_track_id.item()
            if track_id not in track_lookup:
                track_lookup[track_id] = sio.Track(name=track_id)
            track = track_lookup[track_id]
            return (
                sio.PredictedInstance.from_numpy(
                    points=np.array(list(self.pose.values())),
                    skeleton=self.skeleton,
                    point_scores=self.point_scores,
                    instance_score=self.instance_score,
                    tracking_score=self.track_score,
                    track=track,
                ),
                track_lookup,
            )
        except Exception as e:
            # `np.shape` tolerates a missing (None) point_scores, so this
            # diagnostic cannot mask the original error.
            pose_shape = np.array(list(self.pose.values())).shape
            print(f"Pose: {pose_shape}, Pose score shape {np.shape(self.point_scores)}")
            raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}") from e

    @property
    def device(self) -> str:
        """The device the instance is on.

        Returns:
            The str representation of the device the instance is on.
        """
        return self._device

    @device.setter
    def device(self, device) -> None:
        """Set for the device property.

        Args:
            device: The str representation of the device.
        """
        self._device = device

    @property
    def gt_track_id(self) -> torch.Tensor:
        """The ground truth track id of the instance.

        Returns:
            A tensor containing the ground truth track id
        """
        return self._gt_track_id

    @gt_track_id.setter
    def gt_track_id(self, track: int):
        """Set the instance ground-truth track id.

        Args:
            track: An int representing the ground-truth track id; `None` clears it.
        """
        if track is not None:
            self._gt_track_id = torch.tensor([track])
        else:
            self._gt_track_id = torch.tensor([])

    def has_gt_track_id(self) -> bool:
        """Determine if instance has a gt track assignment.

        Returns:
            True if the gt track id is set, otherwise False.
        """
        return self._gt_track_id.numel() != 0

    @property
    def pred_track_id(self) -> torch.Tensor:
        """The track id predicted by the tracker using asso_output from model.

        Returns:
            A tensor containing the predicted track id.
        """
        return self._pred_track_id

    @pred_track_id.setter
    def pred_track_id(self, track: int) -> None:
        """Set predicted track id.

        Args:
            track: an int representing the predicted track id; `None` clears it.
        """
        if track is not None:
            self._pred_track_id = torch.tensor([track])
        else:
            self._pred_track_id = torch.tensor([])

    def has_pred_track_id(self) -> bool:
        """Determine whether instance has predicted track id.

        Returns:
            True if instance has a pred track id, False otherwise
            (an empty id or the untracked sentinel -1 both count as "no id").
        """
        # Check emptiness first: `.item()` raises on an empty tensor.
        if self._pred_track_id.numel() == 0:
            return False
        return self._pred_track_id.item() != -1

    @property
    def bbox(self) -> torch.Tensor:
        """The bounding box coordinates of the instance in the original frame.

        Returns:
            A tensor containing the bounding box coordinates
            (an empty `(1, 0, 4)` tensor means no bbox).
        """
        return self._bbox

    @bbox.setter
    def bbox(self, bbox: ArrayLike) -> None:
        """Set the instance bounding box.

        Args:
            bbox: an arraylike object containing the bounding box coordinates.
                `None` or an empty sequence clears the bbox.
        """
        if bbox is None or len(bbox) == 0:
            # Same empty representation as `__attrs_post_init__`, so that
            # `has_bbox` (which inspects dim 1) reports False.
            self._bbox = torch.empty((1, 0, 4))
        else:
            if not isinstance(bbox, torch.Tensor):
                self._bbox = torch.tensor(bbox)
            else:
                self._bbox = bbox

        # Promote (4,) -> (1, 4) -> (1, 1, 4).
        if self._bbox.shape[0] and len(self._bbox.shape) == 1:
            self._bbox = self._bbox.unsqueeze(0)
        if self._bbox.shape[1] and len(self._bbox.shape) == 2:
            self._bbox = self._bbox.unsqueeze(0)

    def has_bbox(self) -> bool:
        """Determine if the instance has a bbox.

        Returns:
            True if the instance has a bounding box, false otherwise.
        """
        return self._bbox.shape[1] != 0

    @property
    def centroid(self) -> dict[str, ArrayLike]:
        """The centroid around which the crop was formed.

        Returns:
            A dict containing the anchor name and the x, y bbox midpoint.
        """
        return self._centroid

    @centroid.setter
    def centroid(self, centroid: dict[str, ArrayLike]) -> None:
        """Set the centroid of the instance.

        Args:
            centroid: A dict containing the anchor name and points.
        """
        self._centroid = centroid

    @property
    def anchor(self) -> list[str]:
        """The anchor node name around which the crop was formed.

        Returns:
            The list of anchors around which each crop was formed, or an empty
            string if no centroid is set.
        """
        if self.centroid:
            return list(self.centroid.keys())
        return ""

    @property
    def crop(self) -> torch.Tensor:
        """The crop of the instance.

        Returns:
            A (1, c, h , w) tensor containing the cropped image centered around the instance.
        """
        return self._crop

    @crop.setter
    def crop(self, crop: ArrayLike) -> None:
        """Set the crop of the instance.

        Args:
            crop: an arraylike object containing the cropped image of the centered instance.
        """
        if crop is None or len(crop) == 0:
            self._crop = torch.tensor([])
        else:
            if not isinstance(crop, torch.Tensor):
                self._crop = torch.tensor(crop)
            else:
                self._crop = crop

            # Promote (h, w) -> (1, h, w) -> (1, 1, h, w).
            if len(self._crop.shape) == 2:
                self._crop = self._crop.unsqueeze(0)
            if len(self._crop.shape) == 3:
                self._crop = self._crop.unsqueeze(0)

    def has_crop(self) -> bool:
        """Determine if the instance has a crop.

        Returns:
            True if the instance has an image otherwise False.
        """
        return self._crop.shape[-1] != 0

    @property
    def features(self) -> torch.Tensor:
        """Re-ID feature vector from backbone model to be used as input to transformer.

        Returns:
            a (1, d) tensor containing the reid feature vector.
        """
        return self._features

    @features.setter
    def features(self, features: ArrayLike) -> None:
        """Set the reid feature vector of the instance.

        Args:
            features: a (1,d) array like object containing the reid features for the instance.
        """
        if features is None or len(features) == 0:
            self._features = torch.tensor([])
        elif not isinstance(features, torch.Tensor):
            self._features = torch.tensor(features)
        else:
            self._features = features

        if self._features.shape[0] and len(self._features.shape) == 1:
            self._features = self._features.unsqueeze(0)

    def has_features(self) -> bool:
        """Determine if the instance has computed reid features.

        Returns:
            True if the instance has reid features, False otherwise.
        """
        return self._features.shape[-1] != 0

    def has_embedding(self, emb_type: str = None) -> bool:
        """Determine if the instance has embedding type requested.

        Args:
            emb_type: The key to check in the embedding dictionary.

        Returns:
            True if `emb_type` in embedding_dict else false
        """
        return emb_type in self._embeddings

    def get_embedding(
        self, emb_type: str = "all"
    ) -> Union[dict[str, torch.Tensor], torch.Tensor, None]:
        """Retrieve instance's spatial/temporal embedding.

        Args:
            emb_type: The string key of the embedding to retrieve. Should be "pos", "temp",
                or "all" (case-insensitive) for the whole embedding dict.

        Returns:
            * The whole embedding dict if `emb_type` is "all".
            * A torch tensor representing the spatial/temporal location of the instance.
            * None if the embedding is not stored
        """
        if emb_type.lower() == "all":
            return self._embeddings
        else:
            try:
                return self._embeddings[emb_type]
            except KeyError:
                print(
                    f"{emb_type} not saved! Only {list(self._embeddings.keys())} are available"
                )
        return None

    def add_embedding(self, emb_type: str, embedding: torch.Tensor) -> None:
        """Save embedding to instance embedding dictionary.

        Args:
            emb_type: Key/embedding type to be saved to dictionary
            embedding: The actual torch tensor embedding.
        """
        embedding = _expand_to_rank(embedding, 2)
        self._embeddings[emb_type] = embedding

    @property
    def frame(self) -> "Frame":
        """Get the frame the instance belongs to.

        Returns:
            The back reference to the `Frame` that this `Instance` belongs to.
        """
        return self._frame

    @frame.setter
    def frame(self, frame: "Frame") -> None:
        """Set the back reference to the `Frame` that this `Instance` belongs to.

        This field is set when instances are added to `Frame` object.

        Args:
            frame: A `Frame` object containing the metadata for the frame that the instance belongs to
        """
        self._frame = frame

    @property
    def pose(self) -> dict[str, ArrayLike]:
        """Get the pose of the instance.

        Returns:
            A dictionary containing the node and corresponding x,y points
        """
        return self._pose

    @pose.setter
    def pose(self, pose: dict[str, ArrayLike]) -> None:
        """Set the pose of the instance.

        Args:
            pose: A dict mapping node names to coordinates. `None` falls back to a
                single "centroid" node at the bbox midpoint (or an empty pose).
        """
        if pose is not None:
            self._pose = pose
        elif self.bbox.shape[1]:
            # Same midpoint computation as `__attrs_post_init__`; dim 1 indexes boxes.
            y1, x1, y2, x2 = self.bbox.squeeze(dim=0).mean(dim=0)
            self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}
        else:
            self._pose = {}

    def has_pose(self) -> bool:
        """Check if the instance has a pose.

        Returns True if the instance has a pose.
        """
        return len(self.pose) > 0

    @property
    def shown_pose(self) -> dict[str, ArrayLike]:
        """Get the pose with shown nodes only.

        Returns: A dictionary filtered by nodes that are shown (points are not nan).
        """
        pose = self.pose
        return {node: point for node, point in pose.items() if not np.isnan(point).any()}

    @property
    def skeleton(self) -> sio.Skeleton:
        """Get the skeleton associated with the instance.

        Returns: The sio.Skeleton associated with the instance.
        """
        return self._skeleton

    @skeleton.setter
    def skeleton(self, skeleton: sio.Skeleton) -> None:
        """Set the skeleton associated with the instance.

        Args:
            skeleton: The sio.Skeleton associated with the instance.
        """
        self._skeleton = skeleton

    @property
    def point_scores(self) -> ArrayLike:
        """Get the point scores associated with the pose prediction.

        Returns: a vector of shape n containing the point scores output by sleap for the pose predictions.
        """
        return self._point_scores

    @point_scores.setter
    def point_scores(self, point_scores: ArrayLike) -> None:
        """Set the point scores associated with the pose prediction.

        Args:
            point_scores: a vector of shape n containing the point scores
                output by sleap for the pose predictions.
        """
        self._point_scores = point_scores

    @property
    def instance_score(self) -> float:
        """Get the pose prediction score associated with the instance.

        Returns: a float from 0-1 representing an instance_score.
        """
        return self._instance_score

    @instance_score.setter
    def instance_score(self, instance_score: float) -> None:
        """Set the pose prediction score associated with the instance.

        Args:
            instance_score: a float from 0-1 representing an instance_score.
        """
        self._instance_score = instance_score

    @property
    def track_score(self) -> float:
        """Get the track_score of the instance.

        Returns: A float from 0-1 representing the output used in the tracker for assignment.
        """
        return self._track_score

    @track_score.setter
    def track_score(self, track_score: float) -> None:
        """Set the track_score of the instance.

        Args:
            track_score: A float from 0-1 representing the output used in the tracker for assignment.
        """
        self._track_score = track_score
anchor: list[str]
property
¶
The anchor node name around which the crop was formed.
Returns:
Type | Description |
---|---|
list[str]
|
The list of anchors around which each crop was formed. |
bbox: torch.Tensor
property
writable
¶
The bounding box coordinates of the instance in the original frame.
Returns:
Type | Description |
---|---|
Tensor
|
A (1,4) tensor containing the bounding box coordinates. |
centroid: dict[str, ArrayLike]
property
writable
¶
The centroid around which the crop was formed.
Returns:
Type | Description |
---|---|
dict[str, ArrayLike]
|
A dict containing the anchor name and the x, y bbox midpoint. |
crop: torch.Tensor
property
writable
¶
The crop of the instance.
Returns:
Type | Description |
---|---|
Tensor
|
A (1, c, h , w) tensor containing the cropped image centered around the instance. |
device: str
property
writable
¶
The device the instance is on.
Returns:
Type | Description |
---|---|
str
|
The str representation of the device the instance is on. |
features: torch.Tensor
property
writable
¶
Re-ID feature vector from backbone model to be used as input to transformer.
Returns:
Type | Description |
---|---|
Tensor
|
a (1, d) tensor containing the reid feature vector. |
frame: Frame
property
writable
¶
Get the frame the instance belongs to.
Returns:
Type | Description |
---|---|
Frame
|
The back reference to the `Frame` that this `Instance` belongs to. |
gt_track_id: torch.Tensor
property
writable
¶
The ground truth track id of the instance.
Returns:
Type | Description |
---|---|
Tensor
|
A tensor containing the ground truth track id |
instance_score: float
property
writable
¶
Get the pose prediction score associated with the instance.
Returns: a float from 0-1 representing an instance_score.
point_scores: ArrayLike
property
writable
¶
Get the point scores associated with the pose prediction.
Returns: a vector of shape n containing the point scores output by sleap for the pose predictions.
pose: dict[str, ArrayLike]
property
writable
¶
Get the pose of the instance.
Returns:
Type | Description |
---|---|
dict[str, ArrayLike]
|
A dictionary containing the node and corresponding x,y points |
pred_track_id: torch.Tensor
property
writable
¶
The track id predicted by the tracker using asso_output from model.
Returns:
Type | Description |
---|---|
Tensor
|
A tensor containing the predicted track id. |
shown_pose: dict[str, ArrayLike]
property
¶
Get the pose with shown nodes only.
Returns: A dictionary filtered by nodes that are shown (points are not nan).
skeleton: sio.Skeleton
property
writable
¶
Get the skeleton associated with the instance.
Returns: The sio.Skeleton associated with the instance.
track_score: float
property
writable
¶
Get the track_score of the instance.
Returns: A float from 0-1 representing the output used in the tracker for assignment.
__attrs_post_init__()
¶
Handle dimensionality and more intricate default initializations post-init.
Source code in dreem/io/instance.py
def __attrs_post_init__(self) -> None:
    """Handle dimensionality and more intricate default initializations post-init."""
    # Bring the tensor fields up to the ranks the rest of the class indexes into.
    self.bbox = _expand_to_rank(self.bbox, 3)
    self.crop = _expand_to_rank(self.crop, 4)
    self.features = _expand_to_rank(self.features, 2)

    if self.skeleton is None:
        self.skeleton = sio.Skeleton(["centroid"])

    # Canonical "no bbox" representation; `has_bbox` checks dim 1.
    if self.bbox.shape[-1] == 0:
        self.bbox = torch.empty([1, 0, 4])

    # Without a crop, derive the centroid from the bbox midpoint (y1, x1, y2, x2 order).
    if self.crop.shape[-1] == 0 and self.bbox.shape[1] != 0:
        y1, x1, y2, x2 = self.bbox.squeeze(dim=0).nanmean(dim=0)
        self.centroid = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}

    # Without a pose, fall back to a single "centroid" node at the bbox midpoint.
    if len(self.pose) == 0 and self.bbox.shape[1]:
        y1, x1, y2, x2 = self.bbox.squeeze(dim=0).mean(dim=0)
        self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}

    if self.point_scores is None and len(self.pose) != 0:
        # One score per pose node, matching the documented `point_scores`
        # contract (a vector of shape n) expected by `to_slp`.
        self._point_scores = np.zeros(len(self.pose))

    self.to(self.device)
__repr__()
¶
Return string representation of the Instance.
Source code in dreem/io/instance.py
def __repr__(self) -> str:
    """Return string representation of the Instance.

    Track ids that are unset (empty tensors) are shown as lists instead of
    raising, so the repr is always safe to call while debugging.
    """

    def _fmt_id(track_id: torch.Tensor):
        # `.item()` only works on single-element tensors; an unset id is empty.
        return track_id.item() if track_id.numel() == 1 else track_id.tolist()

    return (
        "Instance("
        f"gt_track_id={_fmt_id(self._gt_track_id)}, "
        f"pred_track_id={_fmt_id(self._pred_track_id)}, "
        f"bbox={self._bbox}, "
        f"centroid={self._centroid}, "
        f"crop={self._crop.shape}, "
        f"features={self._features.shape}, "
        f"device={self._device}"
        ")"
    )
add_embedding(emb_type, embedding)
¶
Save embedding to instance embedding dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
emb_type |
str
|
Key/embedding type to be saved to dictionary |
required |
embedding |
Tensor
|
The actual torch tensor embedding. |
required |
Source code in dreem/io/instance.py
def add_embedding(self, emb_type: str, embedding: torch.Tensor) -> None:
    """Store an embedding under the given key in the instance's embedding dict.

    Args:
        emb_type: Key/embedding type under which the embedding is stored.
        embedding: The embedding tensor; it is expanded to rank 2 before storing.
    """
    self._embeddings[emb_type] = _expand_to_rank(embedding, 2)
from_slp(slp_instance, bbox_size=64, crop=None, device=None)
classmethod
¶
Convert a slp instance to a dreem instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
slp_instance |
Union[PredictedInstance, Instance]
|
A `sleap_io.Instance` object representing a detection |
required |
bbox_size |
Union[int, tuple]
|
size of the pose-centered bbox to form. |
64
|
crop |
ArrayLike
|
The corresponding crop of the bbox |
None
|
device |
str
|
which device to keep the instance on |
None
|
Returns: A dreem.Instance object with a pose-centered bbox and the given crop (if any).
Source code in dreem/io/instance.py
@classmethod
def from_slp(
    cls,
    slp_instance: Union[sio.PredictedInstance, sio.Instance],
    bbox_size: Union[int, tuple] = 64,
    crop: ArrayLike = None,
    device: str = None,
) -> "Instance":
    """Convert a slp instance to a dreem instance.

    Args:
        slp_instance: A `sleap_io.Instance` object representing a detection.
        bbox_size: Offset added on each side of the pose centroid to form the
            bbox; an int is used for both axes.
        crop: The corresponding crop of the bbox.
        device: Which device to keep the instance on.

    Returns:
        A dreem.Instance object with a pose-centered bbox and the given crop (if any).
    """
    track = slp_instance.track
    if track is None:
        # Untracked detections keep the default "no identity" id.
        track_id = -1
    else:
        try:
            track_id = int(track.name)
        except ValueError:
            # Non-numeric track names are encoded as concatenated character codes.
            track_id = int("".join(str(ord(c)) for c in track.name))

    if isinstance(bbox_size, int):
        bbox_size = (bbox_size, bbox_size)

    track_score = -1.0
    point_scores = np.full(len(slp_instance.points), -1)
    instance_score = -1
    if isinstance(slp_instance, sio.PredictedInstance):
        track_score = slp_instance.tracking_score
        # `numpy()` only includes the score column when explicitly requested.
        point_scores = slp_instance.numpy(scores=True)[:, -1]
        instance_score = slp_instance.score

    # Average across nodes (axis 0) to obtain a single (x, y) centroid,
    # ignoring missing (NaN) nodes and any columns beyond x, y.
    centroid = np.nanmean(slp_instance.numpy()[:, :2], axis=0)
    # bbox is ordered (y1, x1, y2, x2).
    bbox = [
        centroid[1] - bbox_size[1],
        centroid[0] - bbox_size[0],
        centroid[1] + bbox_size[1],
        centroid[0] + bbox_size[0],
    ]
    return cls(
        gt_track_id=track_id,
        bbox=bbox,
        crop=crop,
        centroid={"centroid": centroid},
        track_score=track_score,
        point_scores=point_scores,
        instance_score=instance_score,
        skeleton=slp_instance.skeleton,
        pose={
            node.name: point.numpy() for node, point in slp_instance.points.items()
        },
        device=device,
    )
get_embedding(emb_type='all')
¶
Retrieve instance's spatial/temporal embedding.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
emb_type |
str
|
The string key of the embedding to retrieve. Should be "pos", "temp" |
'all'
|
Returns:
Type | Description |
---|---|
Union[dict[str, Tensor], Tensor, None]
|
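A dict of all stored embeddings if `emb_type` is "all"; otherwise the requested embedding tensor, or None if it is not stored. |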
Source code in dreem/io/instance.py
def get_embedding(
    self, emb_type: str = "all"
) -> Union[dict[str, torch.Tensor], torch.Tensor, None]:
    """Look up one stored embedding, or all of them.

    Args:
        emb_type: Embedding key to fetch (e.g. "pos", "temp"); the
            case-insensitive value "all" returns the whole embedding dict.

    Returns:
        * The embedding dict itself when `emb_type` is "all".
        * The stored tensor for `emb_type` when present.
        * None (after printing the available keys) when it is not stored.
    """
    if emb_type.lower() == "all":
        return self._embeddings

    if emb_type in self._embeddings:
        return self._embeddings[emb_type]

    available = list(self._embeddings.keys())
    print(f"{emb_type} not saved! Only {available} are available")
    return None
has_bbox()
¶
Determine if the instance has a bbox.
Returns:
Type | Description |
---|---|
bool
|
True if the instance has a bounding box, false otherwise. |
has_crop()
¶
Determine if the instance has a crop.
Returns:
Type | Description |
---|---|
bool
|
True if the instance has an image otherwise False. |
has_embedding(emb_type=None)
¶
Determine if the instance has embedding type requested.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
emb_type |
str
|
The key to check in the embedding dictionary. |
None
|
Returns:
Type | Description |
---|---|
bool
|
True if `emb_type` is a key in the embedding dictionary, otherwise False. |
Source code in dreem/io/instance.py
has_features()
¶
Determine if the instance has computed reid features.
Returns:
Type | Description |
---|---|
bool
|
True if the instance has reid features, False otherwise. |
has_gt_track_id()
¶
Determine if instance has a gt track assignment.
Returns:
Type | Description |
---|---|
bool
|
True if the gt track id is set, otherwise False. |
has_pose()
¶
Check if the instance has a pose.
Returns:
True if the instance has a pose, otherwise False.
has_pred_track_id()
¶
Determine whether instance has predicted track id.
Returns:
Type | Description |
---|---|
bool
|
True if instance has a pred track id, False otherwise. |
to(map_location)
¶
Move instance to different device or change dtype. (See torch.to
for more info).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
map_location |
Union[str, device]
|
Either the device or dtype for the instance to be moved. |
required |
Returns:
Name | Type | Description |
---|---|---|
self |
Instance
|
reference to the instance moved to correct device/dtype. |
Source code in dreem/io/instance.py
def to(self, map_location: Union[str, torch.device]) -> "Instance":
    """Move the instance's tensors to a device or cast them to a dtype.

    Mirrors `torch.Tensor.to`. `None` or an empty string leaves the instance
    untouched.

    Args:
        map_location: Target device or dtype.

    Returns:
        self: the same instance, for chaining.
    """
    if map_location is None or map_location == "":
        return self

    # Same tensors, same order as before: ids, bbox, crop, features.
    for attr_name in (
        "_gt_track_id",
        "_pred_track_id",
        "_bbox",
        "_crop",
        "_features",
    ):
        setattr(self, attr_name, getattr(self, attr_name).to(map_location))

    # A dtype cast does not change which device the instance lives on.
    if isinstance(map_location, (str, torch.device)):
        self.device = map_location
    return self
to_slp(track_lookup={})
¶
Convert instance to sleap_io.PredictedInstance object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
track_lookup |
dict[int, Track]
|
A track look up dictionary containing track_id:sio.Track. |
{}
|
Returns: A sleap_io.PredictedInstance with necessary metadata and a track_lookup dictionary to persist tracks.
Source code in dreem/io/instance.py
def to_slp(
    self, track_lookup: Union[dict[int, sio.Track], None] = None
) -> tuple[sio.PredictedInstance, dict[int, sio.Track]]:
    """Convert instance to sleap_io.PredictedInstance object.

    Args:
        track_lookup: A track look up dictionary containing track_id:sio.Track.
            A fresh dict is created when omitted (never shared between calls).

    Returns: A sleap_io.PredictedInstance with necessary metadata
        and a track_lookup dictionary to persist tracks.

    Raises:
        RuntimeError: if the conversion fails; the original error is chained.
    """
    if track_lookup is None:
        track_lookup = {}
    try:
        track_id = self.pred_track_id.item()
        if track_id not in track_lookup:
            track_lookup[track_id] = sio.Track(name=track_id)
        track = track_lookup[track_id]
        return (
            sio.PredictedInstance.from_numpy(
                points=np.array(list(self.pose.values())),
                skeleton=self.skeleton,
                point_scores=self.point_scores,
                instance_score=self.instance_score,
                tracking_score=self.track_score,
                track=track,
            ),
            track_lookup,
        )
    except Exception as e:
        # `np.shape` tolerates a missing (None) point_scores, so this
        # diagnostic cannot mask the original error.
        pose_shape = np.array(list(self.pose.values())).shape
        print(f"Pose: {pose_shape}, Pose score shape {np.shape(self.point_scores)}")
        raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}") from e