sleap_dataset
dreem.datasets.sleap_dataset
¶
Module containing logic for loading sleap datasets.
SleapDataset
¶
Bases: BaseDataset
Dataset for loading animal behavior data from sleap.
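A minimal usage sketch (hedged; the file paths are placeholders and any constructor arguments not shown keep their defaults):

```python
from dreem.datasets.sleap_dataset import SleapDataset

# Hypothetical .slp/video pair; substitute your own SLEAP annotations and videos.
train_ds = SleapDataset(
    slp_files=["labels.slp"],
    video_files=["video.mp4"],
    crop_size=128,              # side length of each object crop
    padding=5,                  # extra pixels added around every crop
    anchors="centroid",         # center crops on the pose centroid
    chunk=True,                 # split each video into fixed-length clips
    clip_length=32,             # frames per clip
    mode="train",
    handle_missing="centroid",
)
```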
Source code in dreem/datasets/sleap_dataset.py
class SleapDataset(BaseDataset):
"""Dataset for loading animal behavior data from sleap."""
def __init__(
self,
slp_files: list[str],
video_files: list[str],
padding: int = 5,
crop_size: int = 128,
anchors: Union[int, list[str], str] = "",
chunk: bool = True,
clip_length: int = 500,
mode: str = "train",
handle_missing: str = "centroid",
augmentations: dict = None,
n_chunks: Union[int, float] = 1.0,
seed: int = None,
verbose: bool = False,
):
"""Initialize SleapDataset.
Args:
slp_files: a list of .slp files storing tracking annotations
video_files: a list of paths to video files
padding: amount of padding around object crops
crop_size: the size of the object crops
anchors: One of:
* a string indicating a single node to center crops around
* a list of skeleton node names to be used as the center of crops
* an int indicating the number of anchors to randomly select
If unavailable then crop around the midpoint between all visible anchors.
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train`, `val`, or `test`. Determines whether this dataset is used for
training, validation, or inference.
handle_missing: how to handle missing single nodes. one of `["drop", "ignore", "centroid"]`.
if "drop" then we dont include instances which are missing the `anchor`.
if "ignore" then we use a mask instead of a crop and nan centroids/bboxes.
if "centroid" then we default to the pose centroid as the node to crop around.
augmentations: An optional dict mapping augmentations to parameters. The keys
should map directly to augmentation classes in albumentations. Example:
augmentations = {
'Rotate': {'limit': [-90, 90], 'p': 0.5},
'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0, 'p': 0.2},
'RandomContrast': {'limit': 0.2, 'p': 0.6}
}
n_chunks: Number of chunks to subsample from.
Can either be a fraction of the dataset (i.e. (0, 1.0]) or a number of chunks
seed: set a seed for reproducibility
verbose: whether to print verbose output
"""
super().__init__(
slp_files,
video_files,
padding,
crop_size,
chunk,
clip_length,
mode,
augmentations,
n_chunks,
seed,
)
self.slp_files = slp_files
self.video_files = video_files
self.padding = padding
self.crop_size = crop_size
self.chunk = chunk
self.clip_length = clip_length
self.mode = mode.lower()
self.handle_missing = handle_missing.lower()
self.n_chunks = n_chunks
self.seed = seed
if isinstance(anchors, int):
self.anchors = anchors
elif isinstance(anchors, str):
self.anchors = [anchors]
else:
self.anchors = anchors
if (
isinstance(self.anchors, list) and len(self.anchors) == 0
) or self.anchors == 0:
raise ValueError(f"Must provide at least one anchor but got {self.anchors}")
self.verbose = verbose
# if self.seed is not None:
# np.random.seed(self.seed)
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files]
# do we need this? would need to update with sleap-io
# for label in self.labels:
# label.remove_empty_instances(keep_empty_frames=False)
self.frame_idx = [torch.arange(len(labels)) for labels in self.labels]
# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
# used in call to get_instances()
self.create_chunks()
def get_indices(self, idx: int) -> tuple:
"""Retrieve label and frame indices given batch index.
Args:
idx: the index of the batch.
"""
return self.label_idx[idx], self.chunked_frame_idx[idx]
def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[Frame]:
"""Get an element of the dataset.
Args:
label_idx: index of the labels
frame_idx: index of the frames
Returns:
A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip.
"""
video = self.labels[label_idx]
video_name = self.video_files[label_idx]
vid_reader = imageio.get_reader(video_name, "ffmpeg")
img = vid_reader.get_data(0)
skeleton = video.skeletons[-1]
frames = []
for i, frame_ind in enumerate(frame_idx):
(
instances,
gt_track_ids,
poses,
shown_poses,
point_scores,
instance_score,
) = ([], [], [], [], [], [])
frame_ind = int(frame_ind)
lf = video[frame_ind]
try:
img = vid_reader.get_data(frame_ind)
if len(img.shape) == 2:
img = np.expand_dims(img, 0)
h, w, c = img.shape
except IndexError as e:
print(f"Could not read frame {frame_ind} from {video_name} due to {e}")
continue
if len(img.shape) == 2:
img = np.expand_dims(img, -1)
h, w, c = img.shape
if c == 1:
img = np.concatenate(
[img, img, img], axis=-1
) # convert grayscale to rgb
if np.issubdtype(img.dtype, np.integer): # convert int to float
img = img.astype(np.float32) / 255
n_instances_dropped = 0
gt_instances = lf.instances
if self.mode == "train":
np.random.shuffle(gt_instances)
for instance in gt_instances:
if (
np.random.uniform() < self.instance_dropout["p"]
and n_instances_dropped < self.instance_dropout["n"]
):
n_instances_dropped += 1
continue
if instance.track is not None:
gt_track_id = video.tracks.index(instance.track)
else:
gt_track_id = -1
gt_track_ids.append(gt_track_id)
poses.append(
dict(
zip(
[n.name for n in instance.skeleton.nodes],
[[p.x, p.y] for p in instance.points.values()],
)
)
)
shown_poses = [
{
key: val
for key, val in instance.items()
if not np.isnan(val).any()
}
for instance in poses
]
point_scores.append(
np.array(
[
(
point.score
if isinstance(point, sio.PredictedPoint)
else 1.0
)
for point in instance.points.values()
]
)
)
if isinstance(instance, sio.PredictedInstance):
instance_score.append(instance.score)
else:
instance_score.append(1.0)
# augmentations
if self.augmentations is not None:
for transform in self.augmentations:
if isinstance(transform, A.CoarseDropout):
transform.fill_value = random.randint(0, 255)
if shown_poses:
keypoints = np.vstack([list(s.values()) for s in shown_poses])
else:
keypoints = []
augmented = self.augmentations(image=img, keypoints=keypoints)
img, aug_poses = augmented["image"], augmented["keypoints"]
aug_poses = [
arr
for arr in np.split(
np.array(aug_poses),
np.array([len(s) for s in shown_poses]).cumsum(),
)
if arr.size != 0
]
aug_poses = [
dict(zip(list(pose_dict.keys()), aug_pose_arr.tolist()))
for aug_pose_arr, pose_dict in zip(aug_poses, shown_poses)
]
_ = [
pose.update(aug_pose)
for pose, aug_pose in zip(shown_poses, aug_poses)
]
img = tvf.to_tensor(img)
for j in range(len(gt_track_ids)):
pose = shown_poses[j]
"""Check for anchor"""
crops = []
boxes = []
centroids = {}
if isinstance(self.anchors, int):
anchors_to_choose = list(pose.keys()) + ["midpoint"]
anchors = np.random.choice(anchors_to_choose, self.anchors)
else:
anchors = self.anchors
dropped_anchors = self.node_dropout(anchors)
for anchor in anchors:
if anchor in dropped_anchors:
centroid = np.array([np.nan, np.nan])
elif anchor == "midpoint" or anchor == "centroid":
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
elif anchor in pose:
centroid = np.array(pose[anchor])
if np.isnan(centroid).any():
centroid = np.array([np.nan, np.nan])
elif (
anchor not in pose
and len(anchors) == 1
and self.handle_missing == "centroid"
):
anchor = "midpoint"
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
else:
centroid = np.array([np.nan, np.nan])
if np.isnan(centroid).all():
bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])
else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
padding=self.padding,
)
if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
crop = data_utils.crop_bbox(img, bbox)
crops.append(crop)
centroids[anchor] = centroid
boxes.append(bbox)
if len(crops) > 0:
crops = torch.concat(crops, dim=0)
if len(boxes) > 0:
boxes = torch.stack(boxes, dim=0)
if self.handle_missing == "drop" and boxes.isnan().any():
continue
instance = Instance(
gt_track_id=gt_track_ids[j],
pred_track_id=-1,
crop=crops,
centroid=centroids,
bbox=boxes,
skeleton=skeleton,
pose=poses[j],
point_scores=point_scores[j],
instance_score=instance_score[j],
)
instances.append(instance)
frame = Frame(
video_id=label_idx,
frame_id=frame_ind,
vid_file=video_name,
img_shape=img.shape,
instances=instances,
)
frames.append(frame)
return frames
__init__(slp_files, video_files, padding=5, crop_size=128, anchors='', chunk=True, clip_length=500, mode='train', handle_missing='centroid', augmentations=None, n_chunks=1.0, seed=None, verbose=False)
¶
Initialize SleapDataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `slp_files` | `list[str]` | a list of .slp files storing tracking annotations | *required* |
| `video_files` | `list[str]` | a list of paths to video files | *required* |
| `padding` | `int` | amount of padding around object crops | `5` |
| `crop_size` | `int` | the size of the object crops | `128` |
| `anchors` | `Union[int, list[str], str]` | One of: a string naming a single node to center crops around; a list of skeleton node names to use as crop centers; or an int giving the number of anchors to randomly select. If an anchor is unavailable, the crop is centered on the midpoint of all visible nodes (see the sketch below). | `''` |
| `chunk` | `bool` | whether or not to chunk the dataset into batches | `True` |
| `clip_length` | `int` | the number of frames in each chunk | `500` |
| `mode` | `str` | one of `train`, `val`, or `test`; determines whether this dataset is used for training, validation, or inference | `'train'` |
| `handle_missing` | `str` | how to handle missing single nodes; one of `'drop'` (skip instances missing the anchor), `'ignore'` (use a mask instead of a crop, with NaN centroids/bboxes), or `'centroid'` (fall back to the pose centroid as the crop center) | `'centroid'` |
| `augmentations` | `dict` | an optional dict mapping augmentation names to parameters; keys should map directly to augmentation classes in albumentations, e.g. `{'Rotate': {'limit': [-90, 90], 'p': 0.5}}` (see the augmentation sketch after the source listing below) | `None` |
| `n_chunks` | `Union[int, float]` | number of chunks to subsample; either a fraction of the dataset (i.e. in (0, 1.0]) or a number of chunks | `1.0` |
| `seed` | `int` | set a seed for reproducibility | `None` |
| `verbose` | `bool` | whether to print verbose output | `False` |
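For reference, `anchors` accepts three forms, which are normalized in `__init__` (a string becomes a one-element list; an int is kept as a per-instance sample count). A hedged sketch; the node names are hypothetical and must exist in your SLEAP skeleton:

```python
from dreem.datasets.sleap_dataset import SleapDataset

slp_files, video_files = ["labels.slp"], ["video.mp4"]  # hypothetical paths

ds_single = SleapDataset(slp_files, video_files, anchors="thorax")           # crop around one named node
ds_multi = SleapDataset(slp_files, video_files, anchors=["head", "thorax"])  # one crop per listed node
ds_random = SleapDataset(slp_files, video_files, anchors=3)                  # sample 3 anchors (nodes or the midpoint) per instance
```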
Source code in dreem/datasets/sleap_dataset.py
def __init__(
self,
slp_files: list[str],
video_files: list[str],
padding: int = 5,
crop_size: int = 128,
anchors: Union[int, list[str], str] = "",
chunk: bool = True,
clip_length: int = 500,
mode: str = "train",
handle_missing: str = "centroid",
augmentations: dict = None,
n_chunks: Union[int, float] = 1.0,
seed: int = None,
verbose: bool = False,
):
"""Initialize SleapDataset.
Args:
slp_files: a list of .slp files storing tracking annotations
video_files: a list of paths to video files
padding: amount of padding around object crops
crop_size: the size of the object crops
anchors: One of:
* a string indicating a single node to center crops around
* a list of skeleton node names to be used as the center of crops
* an int indicating the number of anchors to randomly select
If unavailable then crop around the midpoint between all visible anchors.
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train`, `val`, or `test`. Determines whether this dataset is used for
training, validation, or inference.
handle_missing: how to handle missing single nodes. one of `["drop", "ignore", "centroid"]`.
if "drop" then we dont include instances which are missing the `anchor`.
if "ignore" then we use a mask instead of a crop and nan centroids/bboxes.
if "centroid" then we default to the pose centroid as the node to crop around.
augmentations: An optional dict mapping augmentations to parameters. The keys
should map directly to augmentation classes in albumentations. Example:
augmentations = {
'Rotate': {'limit': [-90, 90], 'p': 0.5},
'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0, 'p': 0.2},
'RandomContrast': {'limit': 0.2, 'p': 0.6}
}
n_chunks: Number of chunks to subsample from.
Can either be a fraction of the dataset (i.e. (0, 1.0]) or a number of chunks
seed: set a seed for reproducibility
verbose: whether to print verbose output
"""
super().__init__(
slp_files,
video_files,
padding,
crop_size,
chunk,
clip_length,
mode,
augmentations,
n_chunks,
seed,
)
self.slp_files = slp_files
self.video_files = video_files
self.padding = padding
self.crop_size = crop_size
self.chunk = chunk
self.clip_length = clip_length
self.mode = mode.lower()
self.handle_missing = handle_missing.lower()
self.n_chunks = n_chunks
self.seed = seed
if isinstance(anchors, int):
self.anchors = anchors
elif isinstance(anchors, str):
self.anchors = [anchors]
else:
self.anchors = anchors
if (
isinstance(self.anchors, list) and len(self.anchors) == 0
) or self.anchors == 0:
raise ValueError(f"Must provide at least one anchor but got {self.anchors}")
self.verbose = verbose
# if self.seed is not None:
# np.random.seed(self.seed)
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files]
# do we need this? would need to update with sleap-io
# for label in self.labels:
# label.remove_empty_instances(keep_empty_frames=False)
self.frame_idx = [torch.arange(len(labels)) for labels in self.labels]
# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
# used in call to get_instances()
self.create_chunks()
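As an illustration of how an `augmentations` dict maps onto albumentations, the sketch below builds an equivalent pipeline by hand. dreem constructs this internally, so treat it only as a guide to the naming convention; the keypoint settings are assumptions:

```python
import albumentations as A

aug_cfg = {
    "Rotate": {"limit": [-90, 90], "p": 0.5},
    "GaussianBlur": {"blur_limit": (3, 7), "sigma_limit": 0, "p": 0.2},
    "RandomContrast": {"limit": 0.2, "p": 0.6},  # RandomBrightnessContrast in newer albumentations releases
}

# Each key names an albumentations class; its value is unpacked as keyword arguments.
transforms = A.Compose(
    [getattr(A, name)(**params) for name, params in aug_cfg.items()],
    keypoint_params=A.KeypointParams(format="xy"),  # assumed keypoint format
)
```

The same `aug_cfg` dict can be passed directly as the `augmentations` argument of `SleapDataset`.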
get_indices(idx)
¶
Retrieve label and frame indices given batch index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `idx` | `int` | the index of the batch | *required* |
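In other words, `get_indices` maps a batch index to the video (label) index and the chunk of frame indices that make up that batch. A hedged sketch, reusing the hypothetical dataset constructed earlier:

```python
label_i, frame_ids = train_ds.get_indices(0)

print(label_i)        # which entry of slp_files / video_files batch 0 comes from
print(frame_ids[:5])  # first few frame indices in this chunk (a torch tensor)
```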
get_instances(label_idx, frame_idx)
¶
Get an element of the dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `label_idx` | `List[int]` | index of the labels | *required* |
| `frame_idx` | `List[int]` | index of the frames | *required* |
Returns:
| Type | Description |
|---|---|
| `list[Frame]` | A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip. |
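Putting the two methods together, a hedged sketch of loading one clip; the `Frame` and `Instance` attribute names mirror the constructor arguments in the source listing below and are assumed to be exposed as attributes:

```python
label_i, frame_ids = train_ds.get_indices(0)
frames = train_ds.get_instances(label_i, frame_ids)

first = frames[0]
print(first.frame_id, first.img_shape)       # frame metadata
for inst in first.instances:
    # ground-truth track id and the concatenated (n_anchors * C, H, W) crop tensor
    print(inst.gt_track_id, inst.crop.shape)
```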
Source code in dreem/datasets/sleap_dataset.py
def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[Frame]:
"""Get an element of the dataset.
Args:
label_idx: index of the labels
frame_idx: index of the frames
Returns:
A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip.
"""
video = self.labels[label_idx]
video_name = self.video_files[label_idx]
vid_reader = imageio.get_reader(video_name, "ffmpeg")
img = vid_reader.get_data(0)
skeleton = video.skeletons[-1]
frames = []
for i, frame_ind in enumerate(frame_idx):
(
instances,
gt_track_ids,
poses,
shown_poses,
point_scores,
instance_score,
) = ([], [], [], [], [], [])
frame_ind = int(frame_ind)
lf = video[frame_ind]
try:
img = vid_reader.get_data(frame_ind)
if len(img.shape) == 2:
img = np.expand_dims(img, 0)
h, w, c = img.shape
except IndexError as e:
print(f"Could not read frame {frame_ind} from {video_name} due to {e}")
continue
if len(img.shape) == 2:
img = np.expand_dims(img, -1)
h, w, c = img.shape
if c == 1:
img = np.concatenate(
[img, img, img], axis=-1
) # convert grayscale to rgb
if np.issubdtype(img.dtype, np.integer): # convert int to float
img = img.astype(np.float32) / 255
n_instances_dropped = 0
gt_instances = lf.instances
if self.mode == "train":
np.random.shuffle(gt_instances)
for instance in gt_instances:
if (
np.random.uniform() < self.instance_dropout["p"]
and n_instances_dropped < self.instance_dropout["n"]
):
n_instances_dropped += 1
continue
if instance.track is not None:
gt_track_id = video.tracks.index(instance.track)
else:
gt_track_id = -1
gt_track_ids.append(gt_track_id)
poses.append(
dict(
zip(
[n.name for n in instance.skeleton.nodes],
[[p.x, p.y] for p in instance.points.values()],
)
)
)
shown_poses = [
{
key: val
for key, val in instance.items()
if not np.isnan(val).any()
}
for instance in poses
]
point_scores.append(
np.array(
[
(
point.score
if isinstance(point, sio.PredictedPoint)
else 1.0
)
for point in instance.points.values()
]
)
)
if isinstance(instance, sio.PredictedInstance):
instance_score.append(instance.score)
else:
instance_score.append(1.0)
# augmentations
if self.augmentations is not None:
for transform in self.augmentations:
if isinstance(transform, A.CoarseDropout):
transform.fill_value = random.randint(0, 255)
if shown_poses:
keypoints = np.vstack([list(s.values()) for s in shown_poses])
else:
keypoints = []
augmented = self.augmentations(image=img, keypoints=keypoints)
img, aug_poses = augmented["image"], augmented["keypoints"]
aug_poses = [
arr
for arr in np.split(
np.array(aug_poses),
np.array([len(s) for s in shown_poses]).cumsum(),
)
if arr.size != 0
]
aug_poses = [
dict(zip(list(pose_dict.keys()), aug_pose_arr.tolist()))
for aug_pose_arr, pose_dict in zip(aug_poses, shown_poses)
]
_ = [
pose.update(aug_pose)
for pose, aug_pose in zip(shown_poses, aug_poses)
]
img = tvf.to_tensor(img)
for j in range(len(gt_track_ids)):
pose = shown_poses[j]
"""Check for anchor"""
crops = []
boxes = []
centroids = {}
if isinstance(self.anchors, int):
anchors_to_choose = list(pose.keys()) + ["midpoint"]
anchors = np.random.choice(anchors_to_choose, self.anchors)
else:
anchors = self.anchors
dropped_anchors = self.node_dropout(anchors)
for anchor in anchors:
if anchor in dropped_anchors:
centroid = np.array([np.nan, np.nan])
elif anchor == "midpoint" or anchor == "centroid":
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
elif anchor in pose:
centroid = np.array(pose[anchor])
if np.isnan(centroid).any():
centroid = np.array([np.nan, np.nan])
elif (
anchor not in pose
and len(anchors) == 1
and self.handle_missing == "centroid"
):
anchor = "midpoint"
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
else:
centroid = np.array([np.nan, np.nan])
if np.isnan(centroid).all():
bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])
else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
padding=self.padding,
)
if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
crop = data_utils.crop_bbox(img, bbox)
crops.append(crop)
centroids[anchor] = centroid
boxes.append(bbox)
if len(crops) > 0:
crops = torch.concat(crops, dim=0)
if len(boxes) > 0:
boxes = torch.stack(boxes, dim=0)
if self.handle_missing == "drop" and boxes.isnan().any():
continue
instance = Instance(
gt_track_id=gt_track_ids[j],
pred_track_id=-1,
crop=crops,
centroid=centroids,
bbox=boxes,
skeleton=skeleton,
pose=poses[j],
point_scores=point_scores[j],
instance_score=instance_score[j],
)
instances.append(instance)
frame = Frame(
video_id=label_idx,
frame_id=frame_ind,
vid_file=video_name,
img_shape=img.shape,
instances=instances,
)
frames.append(frame)
return frames