sleap_dataset

dreem.datasets.sleap_dataset

Module containing logic for loading sleap datasets.

Classes:

SleapDataset: Dataset for loading animal behavior data from sleap.

SleapDataset

Bases: BaseDataset

Dataset for loading animal behavior data from sleap.

Methods:

__del__: Handle file closing before garbage collection.
__init__: Initialize SleapDataset.
get_indices: Retrieve label and frame indices given batch index.
get_instances: Get an element of the dataset.

Source code in dreem/datasets/sleap_dataset.py
class SleapDataset(BaseDataset):
    """Dataset for loading animal behavior data from sleap."""

    def __init__(
        self,
        slp_files: list[str],
        video_files: list[str],
        data_dirs: Optional[list[str]] = None,
        padding: int = 5,
        crop_size: Union[int, list[int]] = 128,
        anchors: int | list[str] | str = "",
        chunk: bool = True,
        clip_length: int = 16,
        mode: str = "train",
        handle_missing: str = "centroid",
        augmentations: dict | None = None,
        n_chunks: int | float = 1.0,
        seed: int | None = None,
        verbose: bool = False,
        normalize_image: bool = True,
        max_batching_gap: int = 15,
        use_tight_bbox: bool = False,
        **kwargs,
    ):
        """Initialize SleapDataset.

        Args:
            slp_files: a list of .slp files storing tracking annotations
            video_files: a list of paths to video files
            data_dirs: a path, or a list of paths to data directories. If provided, crop_size should be a list of integers
                with the same length as data_dirs.
            padding: amount of padding around object crops
            crop_size: the size of the object crops. Can be either:
                - An integer specifying a single crop size for all objects
                - A list of integers specifying different crop sizes for different data directories
            anchors: One of:
                        * a string indicating a single node to center crops around
                        * a list of skeleton node names to be used as the center of crops
                        * an int indicating the number of anchors to randomly select
                    If unavailable then crop around the midpoint between all visible anchors.
            chunk: whether or not to chunk the dataset into batches
            clip_length: the number of frames in each chunk
            mode: `train`, `val`, or `test`. Determines whether this dataset is used for
                training, validation, or testing/inference.
            handle_missing: how to handle missing single nodes. One of `["drop", "ignore", "centroid"]`.
                            if "drop" then we don't include instances which are missing the `anchor`.
                            if "ignore" then we use a mask instead of a crop and nan centroids/bboxes.
                            if "centroid" then we default to the pose centroid as the node to crop around.
            augmentations: An optional dict mapping augmentations to parameters. The keys
                should map directly to augmentation classes in albumentations. Example:
                    augmentations = {
                        'Rotate': {'limit': [-90, 90], 'p': 0.5},
                        'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0, 'p': 0.2},
                        'RandomContrast': {'limit': 0.2, 'p': 0.6}
                    }
            n_chunks: Number of chunks to subsample from.
                Can either be a fraction of the dataset (i.e. (0, 1.0]) or a number of chunks
            seed: set a seed for reproducibility
            verbose: whether to print verbose output
            normalize_image: whether to normalize the image to [0, 1]
            max_batching_gap: the max number of frames that can be unlabelled before starting a new batch
            use_tight_bbox: whether to use tight bounding box (around keypoints) instead of the default square bounding box
        """
        super().__init__(
            slp_files,
            video_files,
            padding,
            crop_size,
            chunk,
            clip_length,
            mode,
            augmentations,
            n_chunks,
            seed,
        )

        self.slp_files = slp_files
        self.data_dirs = data_dirs
        self.video_files = video_files
        self.padding = padding
        self.crop_size = crop_size
        self.chunk = chunk
        self.clip_length = clip_length
        self.mode = mode.lower()
        self.handle_missing = handle_missing.lower()
        self.n_chunks = n_chunks
        self.seed = seed
        self.normalize_image = normalize_image
        self.max_batching_gap = max_batching_gap
        self.use_tight_bbox = use_tight_bbox

        if isinstance(anchors, int):
            self.anchors = anchors
        elif isinstance(anchors, str):
            self.anchors = [anchors]
        else:
            self.anchors = anchors

        if not isinstance(self.data_dirs, list):
            self.data_dirs = [self.data_dirs]

        if not isinstance(self.crop_size, list):
            # make a list so it's handled consistently if multiple crops are used
            if len(self.data_dirs) > 0:  # for test mode, data_dirs is []
                self.crop_size = [self.crop_size] * len(self.data_dirs)
            else:
                self.crop_size = [self.crop_size]

        if len(self.data_dirs) > 0 and len(self.crop_size) != len(self.data_dirs):
            raise ValueError(
                f"If a list of crop sizes or data directories are given,"
                f"they must have the same length but got {len(self.crop_size)} "
                f"and {len(self.data_dirs)}"
            )

        if (
            isinstance(self.anchors, list) and len(self.anchors) == 0
        ) or self.anchors == 0:
            raise ValueError(f"Must provide at least one anchor but got {self.anchors}")

        self.verbose = verbose

        # if self.seed is not None:
        #     np.random.seed(self.seed)

        # load_slp is a wrapper around sio.load_slp for frame gap checks
        self.labels = []
        self.annotated_segments = {}
        for slp_file in self.slp_files:
            labels, annotated_segments = data_utils.load_slp(slp_file)
            self.labels.append(labels)
            self.annotated_segments[slp_file] = annotated_segments

        self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files]
        # do we need this? would need to update with sleap-io

        # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
        # used in call to get_instances()
        self.create_chunks_slp()

    def get_indices(self, idx: int) -> tuple:
        """Retrieve label and frame indices given batch index.

        Args:
            idx: the index of the batch.
        """
        return self.label_idx[idx], self.chunked_frame_idx[idx]

    def get_instances(
        self, label_idx: list[int], frame_idx: torch.Tensor
    ) -> list[Frame]:
        """Get an element of the dataset.

        Args:
            label_idx: index of the labels
            frame_idx: indices of the frames to load into the batch

        Returns:
            A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip.

        """
        sleap_labels_obj = self.labels[label_idx]
        video_name = self.video_files[label_idx]

        # get the correct crop size based on the video
        video_par_path = Path(video_name).parent
        if len(self.data_dirs) > 0:
            crop_size = self.crop_size[0]
            for j, data_dir in enumerate(self.data_dirs):
                if Path(data_dir) == video_par_path:
                    crop_size = self.crop_size[j]
                    break
        else:
            crop_size = self.crop_size[0]

        vid_reader = self.videos[label_idx]

        skeleton = sleap_labels_obj.skeletons[-1]

        frames = []
        max_crop_h, max_crop_w = 0, 0
        for i, frame_ind in enumerate(frame_idx):
            (
                instances,
                gt_track_ids,
                poses,
                shown_poses,
                point_scores,
                instance_score,
            ) = ([], [], [], [], [], [])

            frame_ind = int(frame_ind)

            # sleap-io method for indexing a Labels() object based on the frame's index
            lf = sleap_labels_obj[(sleap_labels_obj.video, frame_ind)]
            if frame_ind != lf.frame_idx:
                logger.warning(f"Frame index mismatch: {frame_ind} != {lf.frame_idx}")

            try:
                img = vid_reader.get_data(int(frame_ind))
            except IndexError as e:
                logger.warning(
                    f"Could not read frame {frame_ind} from {video_name} due to {e}"
                )
                continue

            if len(img.shape) == 2:
                img = np.expand_dims(img, -1)
            h, w, c = img.shape

            if c == 1:
                img = np.concatenate(
                    [img, img, img], axis=-1
                )  # convert grayscale to rgb

            if np.issubdtype(img.dtype, np.integer):  # convert int to float
                img = img.astype(np.float32)
                if self.normalize_image:
                    img = img / 255

            n_instances_dropped = 0

            gt_instances = []
            # don't load instances that have been 'greyed out' i.e. all nans for keypoints
            for inst in lf.instances:
                pts = np.array([p for p in inst.numpy()])
                if np.isnan(pts).all():
                    continue
                else:
                    gt_instances.append(inst)

            dict_instances = {}
            no_track_instances = []
            for instance in gt_instances:
                if instance.track is not None:
                    gt_track_id = sleap_labels_obj.tracks.index(instance.track)
                    if gt_track_id not in dict_instances:
                        dict_instances[gt_track_id] = instance
                    else:
                        existing_instance = dict_instances[gt_track_id]
                        # if existing is PredictedInstance and current is not, then current is a UserInstance and should be used
                        if isinstance(
                            existing_instance, sio.PredictedInstance
                        ) and not isinstance(instance, sio.PredictedInstance):
                            dict_instances[gt_track_id] = instance
                else:
                    no_track_instances.append(instance)

            gt_instances = list(dict_instances.values()) + no_track_instances

            if self.mode == "train":
                np.random.shuffle(gt_instances)

            for instance in gt_instances:
                if (
                    np.random.uniform() < self.instance_dropout["p"]
                    and n_instances_dropped < self.instance_dropout["n"]
                ):
                    n_instances_dropped += 1
                    continue

                if instance.track is not None:
                    gt_track_id = sleap_labels_obj.tracks.index(instance.track)
                else:
                    gt_track_id = -1
                gt_track_ids.append(gt_track_id)

                poses.append(
                    dict(
                        zip(
                            [n.name for n in instance.skeleton.nodes],
                            [p for p in instance.numpy()],
                        )
                    )
                )

                shown_poses = [
                    {
                        key: val
                        for key, val in instance.items()
                        if not np.isnan(val).any()
                    }
                    for instance in poses
                ]

                point_scores.append(
                    np.array(
                        [
                            (
                                1.0  # point scores not reliably available in sleap io PredictedPointsArray
                                # point.score
                                # if isinstance(point, sio.PredictedPoint)
                                # else 1.0
                            )
                            for point in instance.numpy()
                        ]
                    )
                )
                if isinstance(instance, sio.PredictedInstance):
                    instance_score.append(instance.score)
                else:
                    instance_score.append(1.0)
            # augmentations
            if self.augmentations is not None:
                for transform in self.augmentations:
                    if isinstance(transform, A.CoarseDropout):
                        transform.fill_value = random.randint(0, 255)

                if shown_poses:
                    keypoints = np.vstack([list(s.values()) for s in shown_poses])

                else:
                    keypoints = []

                augmented = self.augmentations(image=img, keypoints=keypoints)

                img, aug_poses = augmented["image"], augmented["keypoints"]

                aug_poses = [
                    arr
                    for arr in np.split(
                        np.array(aug_poses),
                        np.array([len(s) for s in shown_poses]).cumsum(),
                    )
                    if arr.size != 0
                ]

                aug_poses = [
                    dict(zip(list(pose_dict.keys()), aug_pose_arr.tolist()))
                    for aug_pose_arr, pose_dict in zip(aug_poses, shown_poses)
                ]

                _ = [
                    pose.update(aug_pose)
                    for pose, aug_pose in zip(shown_poses, aug_poses)
                ]

            img = tvf.to_tensor(img)

            for j in range(len(gt_track_ids)):
                pose = shown_poses[j]

                """Check for anchor"""
                crops = []
                boxes = []
                centroids = {}

                if isinstance(self.anchors, int):
                    anchors_to_choose = list(pose.keys()) + ["midpoint"]
                    anchors = np.random.choice(anchors_to_choose, self.anchors)
                else:
                    anchors = self.anchors

                dropped_anchors = self.node_dropout(anchors)

                for anchor in anchors:
                    if anchor in dropped_anchors:
                        centroid = np.array([np.nan, np.nan])

                    elif anchor == "midpoint" or anchor == "centroid":
                        centroid = np.nanmean(np.array(list(pose.values())), axis=0)

                    elif anchor in pose:
                        centroid = np.array(pose[anchor])
                        if np.isnan(centroid).any():
                            centroid = np.array([np.nan, np.nan])

                    elif (
                        anchor not in pose
                        and len(anchors) == 1
                        and self.handle_missing == "centroid"
                    ):
                        anchor = "midpoint"
                        centroid = np.nanmean(np.array(list(pose.values())), axis=0)

                    else:
                        centroid = np.array([np.nan, np.nan])

                    if np.isnan(centroid).all():
                        bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])

                    else:
                        if self.use_tight_bbox and len(pose) > 1:
                            # tight bbox
                            # dont allow this for centroid-only poses!
                            arr_pose = np.array(list(pose.values()))
                            # note bbox will be a different size for each instance; padded at the end of the loop
                            bbox = data_utils.get_tight_bbox(arr_pose)

                        else:
                            bbox = data_utils.pad_bbox(
                                data_utils.get_bbox(centroid, crop_size),
                                padding=self.padding,
                            )

                    if bbox.isnan().all():
                        crop = torch.zeros(
                            c,
                            crop_size + 2 * self.padding,
                            crop_size + 2 * self.padding,
                            dtype=img.dtype,
                        )
                    else:
                        crop = data_utils.crop_bbox(img, bbox)

                    crops.append(crop)
                    # get max h,w for padding for tight bboxes
                    c, h, w = crop.shape
                    if h > max_crop_h:
                        max_crop_h = h
                    if w > max_crop_w:
                        max_crop_w = w

                    centroids[anchor] = centroid
                    boxes.append(bbox)

                if len(crops) > 0:
                    crops = torch.concat(crops, dim=0)

                if len(boxes) > 0:
                    boxes = torch.stack(boxes, dim=0)

                if self.handle_missing == "drop" and boxes.isnan().any():
                    continue

                instance = Instance(
                    gt_track_id=gt_track_ids[j],
                    pred_track_id=-1,
                    crop=crops,
                    centroid=centroids,
                    bbox=boxes,
                    skeleton=skeleton,
                    pose=poses[j],
                    point_scores=point_scores[j],
                    instance_score=instance_score[j],
                )

                instances.append(instance)

            frame = Frame(
                video_id=label_idx,
                frame_id=frame_ind,
                vid_file=video_name,
                img_shape=img.shape,
                instances=instances,
            )
            frames.append(frame)

        # pad bbox to max size
        if self.use_tight_bbox:
            # bound the max crop size to the user defined crop size
            max_crop_h = crop_size if max_crop_h == 0 else min(max_crop_h, crop_size)
            max_crop_w = crop_size if max_crop_w == 0 else min(max_crop_w, crop_size)
            # gather all the crops
            for frame in frames:
                for instance in frame.instances:
                    data_utils.pad_variable_size_crops(
                        instance, (max_crop_h, max_crop_w)
                    )
        return frames

    def __del__(self):
        """Handle file closing before garbage collection."""
        for reader in self.videos:
            reader.close()
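
A minimal usage sketch, assuming the annotation/video paths below exist; the file names and the "thorax" node name are placeholders, not values defined by the library:

from dreem.datasets.sleap_dataset import SleapDataset

dataset = SleapDataset(
    slp_files=["labels.slp"],      # hypothetical .slp annotation file
    video_files=["video.mp4"],     # matching video file
    crop_size=128,
    anchors="thorax",              # center crops on a single skeleton node
    clip_length=16,
    mode="train",
)
# chunk bookkeeping (label_idx, chunked_frame_idx) is built in __init__ via create_chunks_slp()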

__del__()

Handle file closing before garbage collection.

Source code in dreem/datasets/sleap_dataset.py
def __del__(self):
    """Handle file closing before garbage collection."""
    for reader in self.videos:
        reader.close()

__init__(slp_files, video_files, data_dirs=None, padding=5, crop_size=128, anchors='', chunk=True, clip_length=16, mode='train', handle_missing='centroid', augmentations=None, n_chunks=1.0, seed=None, verbose=False, normalize_image=True, max_batching_gap=15, use_tight_bbox=False, **kwargs)

Initialize SleapDataset.

Parameters:

slp_files (list[str], required): a list of .slp files storing tracking annotations
video_files (list[str], required): a list of paths to video files
data_dirs (Optional[list[str]], default None): a path, or a list of paths to data directories. If provided, crop_size should be a list of integers with the same length as data_dirs.
padding (int, default 5): amount of padding around object crops
crop_size (Union[int, list[int]], default 128): the size of the object crops. Can be either an integer specifying a single crop size for all objects, or a list of integers specifying different crop sizes for different data directories.
anchors (int | list[str] | str, default ''): One of: a string indicating a single node to center crops around; a list of skeleton node names to be used as the center of crops; or an int indicating the number of anchors to randomly select. If unavailable then crop around the midpoint between all visible anchors.
chunk (bool, default True): whether or not to chunk the dataset into batches
clip_length (int, default 16): the number of frames in each chunk
mode (str, default 'train'): train, val, or test. Determines whether this dataset is used for training, validation, or testing/inference.
handle_missing (str, default 'centroid'): how to handle missing single nodes. One of ["drop", "ignore", "centroid"]. If "drop" then we don't include instances which are missing the anchor. If "ignore" then we use a mask instead of a crop and nan centroids/bboxes. If "centroid" then we default to the pose centroid as the node to crop around.
augmentations (dict | None, default None): An optional dict mapping augmentations to parameters. The keys should map directly to augmentation classes in albumentations. Example: {'Rotate': {'limit': [-90, 90], 'p': 0.5}, 'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0, 'p': 0.2}, 'RandomContrast': {'limit': 0.2, 'p': 0.6}}
n_chunks (int | float, default 1.0): Number of chunks to subsample from. Can either be a fraction of the dataset (i.e. (0, 1.0]) or a number of chunks.
seed (int | None, default None): set a seed for reproducibility
verbose (bool, default False): whether to print verbose output
normalize_image (bool, default True): whether to normalize the image to [0, 1]
max_batching_gap (int, default 15): the max number of frames that can be unlabelled before starting a new batch
use_tight_bbox (bool, default False): whether to use tight bounding box (around keypoints) instead of the default square bounding box
Source code in dreem/datasets/sleap_dataset.py
def __init__(
    self,
    slp_files: list[str],
    video_files: list[str],
    data_dirs: Optional[list[str]] = None,
    padding: int = 5,
    crop_size: Union[int, list[int]] = 128,
    anchors: int | list[str] | str = "",
    chunk: bool = True,
    clip_length: int = 16,
    mode: str = "train",
    handle_missing: str = "centroid",
    augmentations: dict | None = None,
    n_chunks: int | float = 1.0,
    seed: int | None = None,
    verbose: bool = False,
    normalize_image: bool = True,
    max_batching_gap: int = 15,
    use_tight_bbox: bool = False,
    **kwargs,
):
    """Initialize SleapDataset.

    Args:
        slp_files: a list of .slp files storing tracking annotations
        video_files: a list of paths to video files
        data_dirs: a path, or a list of paths to data directories. If provided, crop_size should be a list of integers
            with the same length as data_dirs.
        padding: amount of padding around object crops
        crop_size: the size of the object crops. Can be either:
            - An integer specifying a single crop size for all objects
            - A list of integers specifying different crop sizes for different data directories
        anchors: One of:
                    * a string indicating a single node to center crops around
                    * a list of skeleton node names to be used as the center of crops
                    * an int indicating the number of anchors to randomly select
                If unavailable then crop around the midpoint between all visible anchors.
        chunk: whether or not to chunk the dataset into batches
        clip_length: the number of frames in each chunk
        mode: `train`, `val`, or `test`. Determines whether this dataset is used for
            training, validation, or testing/inference.
        handle_missing: how to handle missing single nodes. One of `["drop", "ignore", "centroid"]`.
                        if "drop" then we don't include instances which are missing the `anchor`.
                        if "ignore" then we use a mask instead of a crop and nan centroids/bboxes.
                        if "centroid" then we default to the pose centroid as the node to crop around.
        augmentations: An optional dict mapping augmentations to parameters. The keys
            should map directly to augmentation classes in albumentations. Example:
                augmentations = {
                    'Rotate': {'limit': [-90, 90], 'p': 0.5},
                    'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0, 'p': 0.2},
                    'RandomContrast': {'limit': 0.2, 'p': 0.6}
                }
        n_chunks: Number of chunks to subsample from.
            Can either be a fraction of the dataset (i.e. (0, 1.0]) or a number of chunks
        seed: set a seed for reproducibility
        verbose: whether to print verbose output
        normalize_image: whether to normalize the image to [0, 1]
        max_batching_gap: the max number of frames that can be unlabelled before starting a new batch
        use_tight_bbox: whether to use tight bounding box (around keypoints) instead of the default square bounding box
    """
    super().__init__(
        slp_files,
        video_files,
        padding,
        crop_size,
        chunk,
        clip_length,
        mode,
        augmentations,
        n_chunks,
        seed,
    )

    self.slp_files = slp_files
    self.data_dirs = data_dirs
    self.video_files = video_files
    self.padding = padding
    self.crop_size = crop_size
    self.chunk = chunk
    self.clip_length = clip_length
    self.mode = mode.lower()
    self.handle_missing = handle_missing.lower()
    self.n_chunks = n_chunks
    self.seed = seed
    self.normalize_image = normalize_image
    self.max_batching_gap = max_batching_gap
    self.use_tight_bbox = use_tight_bbox

    if isinstance(anchors, int):
        self.anchors = anchors
    elif isinstance(anchors, str):
        self.anchors = [anchors]
    else:
        self.anchors = anchors

    if not isinstance(self.data_dirs, list):
        self.data_dirs = [self.data_dirs]

    if not isinstance(self.crop_size, list):
        # make a list so it's handled consistently if multiple crops are used
        if len(self.data_dirs) > 0:  # for test mode, data_dirs is []
            self.crop_size = [self.crop_size] * len(self.data_dirs)
        else:
            self.crop_size = [self.crop_size]

    if len(self.data_dirs) > 0 and len(self.crop_size) != len(self.data_dirs):
        raise ValueError(
            f"If a list of crop sizes or data directories are given,"
            f"they must have the same length but got {len(self.crop_size)} "
            f"and {len(self.data_dirs)}"
        )

    if (
        isinstance(self.anchors, list) and len(self.anchors) == 0
    ) or self.anchors == 0:
        raise ValueError(f"Must provide at least one anchor but got {self.anchors}")

    self.verbose = verbose

    # if self.seed is not None:
    #     np.random.seed(self.seed)

    # load_slp is a wrapper around sio.load_slp for frame gap checks
    self.labels = []
    self.annotated_segments = {}
    for slp_file in self.slp_files:
        labels, annotated_segments = data_utils.load_slp(slp_file)
        self.labels.append(labels)
        self.annotated_segments[slp_file] = annotated_segments

    self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files]
    # do we need this? would need to update with sleap-io

    # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
    # used in call to get_instances()
    self.create_chunks_slp()
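
The anchors, augmentations, and crop_size/data_dirs arguments accept several shapes, as described in the parameter list above. A hedged sketch of those forms; node names and directory paths are placeholders that depend on your skeleton and data layout:

# anchors: a single node name, a list of node names, or a count of randomly sampled anchors
anchors_single = "thorax"             # crop centered on one node
anchors_multi = ["head", "thorax"]    # one crop per listed node
anchors_random = 2                    # sample 2 anchors from the pose nodes plus "midpoint"

# augmentations: keys are albumentations class names, values are their keyword arguments
augmentations = {
    "Rotate": {"limit": [-90, 90], "p": 0.5},
    "GaussianBlur": {"blur_limit": (3, 7), "sigma_limit": 0, "p": 0.2},
    "RandomContrast": {"limit": 0.2, "p": 0.6},
}

# with multiple data directories, crop_size must be a list of the same length
data_dirs = ["/data/experiment_a", "/data/experiment_b"]
crop_size = [128, 96]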

get_indices(idx)

Retrieve label and frame indices given batch index.

Parameters:

idx (int, required): the index of the batch.
Source code in dreem/datasets/sleap_dataset.py
def get_indices(self, idx: int) -> tuple:
    """Retrieve label and frame indices given batch index.

    Args:
        idx: the index of the batch.
    """
    return self.label_idx[idx], self.chunked_frame_idx[idx]
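
A short sketch of how get_indices feeds get_instances, assuming dataset is an already-constructed SleapDataset:

label_idx, frame_idx = dataset.get_indices(0)   # 0 is the batch/chunk index
# label_idx picks the .slp/video pair; frame_idx lists the frames in that clip
frames = dataset.get_instances(label_idx, frame_idx)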

get_instances(label_idx, frame_idx)

Get an element of the dataset.

Parameters:

label_idx (list[int], required): index of the labels
frame_idx (Tensor, required): indices of the frames to load into the batch

Returns:

list[Frame]: A list of dreem.io.Frame objects containing metadata and instance data for the batch/clip.

Source code in dreem/datasets/sleap_dataset.py
def get_instances(
    self, label_idx: list[int], frame_idx: torch.Tensor
) -> list[Frame]:
    """Get an element of the dataset.

    Args:
        label_idx: index of the labels
        frame_idx: indices of the frames to load into the batch

    Returns:
        A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip.

    """
    sleap_labels_obj = self.labels[label_idx]
    video_name = self.video_files[label_idx]

    # get the correct crop size based on the video
    video_par_path = Path(video_name).parent
    if len(self.data_dirs) > 0:
        crop_size = self.crop_size[0]
        for j, data_dir in enumerate(self.data_dirs):
            if Path(data_dir) == video_par_path:
                crop_size = self.crop_size[j]
                break
    else:
        crop_size = self.crop_size[0]

    vid_reader = self.videos[label_idx]

    skeleton = sleap_labels_obj.skeletons[-1]

    frames = []
    max_crop_h, max_crop_w = 0, 0
    for i, frame_ind in enumerate(frame_idx):
        (
            instances,
            gt_track_ids,
            poses,
            shown_poses,
            point_scores,
            instance_score,
        ) = ([], [], [], [], [], [])

        frame_ind = int(frame_ind)

        # sleap-io method for indexing a Labels() object based on the frame's index
        lf = sleap_labels_obj[(sleap_labels_obj.video, frame_ind)]
        if frame_ind != lf.frame_idx:
            logger.warning(f"Frame index mismatch: {frame_ind} != {lf.frame_idx}")

        try:
            img = vid_reader.get_data(int(frame_ind))
        except IndexError as e:
            logger.warning(
                f"Could not read frame {frame_ind} from {video_name} due to {e}"
            )
            continue

        if len(img.shape) == 2:
            img = np.expand_dims(img, -1)
        h, w, c = img.shape

        if c == 1:
            img = np.concatenate(
                [img, img, img], axis=-1
            )  # convert grayscale to rgb

        if np.issubdtype(img.dtype, np.integer):  # convert int to float
            img = img.astype(np.float32)
            if self.normalize_image:
                img = img / 255

        n_instances_dropped = 0

        gt_instances = []
        # don't load instances that have been 'greyed out' i.e. all nans for keypoints
        for inst in lf.instances:
            pts = np.array([p for p in inst.numpy()])
            if np.isnan(pts).all():
                continue
            else:
                gt_instances.append(inst)

        dict_instances = {}
        no_track_instances = []
        for instance in gt_instances:
            if instance.track is not None:
                gt_track_id = sleap_labels_obj.tracks.index(instance.track)
                if gt_track_id not in dict_instances:
                    dict_instances[gt_track_id] = instance
                else:
                    existing_instance = dict_instances[gt_track_id]
                    # if existing is PredictedInstance and current is not, then current is a UserInstance and should be used
                    if isinstance(
                        existing_instance, sio.PredictedInstance
                    ) and not isinstance(instance, sio.PredictedInstance):
                        dict_instances[gt_track_id] = instance
            else:
                no_track_instances.append(instance)

        gt_instances = list(dict_instances.values()) + no_track_instances

        if self.mode == "train":
            np.random.shuffle(gt_instances)

        for instance in gt_instances:
            if (
                np.random.uniform() < self.instance_dropout["p"]
                and n_instances_dropped < self.instance_dropout["n"]
            ):
                n_instances_dropped += 1
                continue

            if instance.track is not None:
                gt_track_id = sleap_labels_obj.tracks.index(instance.track)
            else:
                gt_track_id = -1
            gt_track_ids.append(gt_track_id)

            poses.append(
                dict(
                    zip(
                        [n.name for n in instance.skeleton.nodes],
                        [p for p in instance.numpy()],
                    )
                )
            )

            shown_poses = [
                {
                    key: val
                    for key, val in instance.items()
                    if not np.isnan(val).any()
                }
                for instance in poses
            ]

            point_scores.append(
                np.array(
                    [
                        (
                            1.0  # point scores not reliably available in sleap io PredictedPointsArray
                            # point.score
                            # if isinstance(point, sio.PredictedPoint)
                            # else 1.0
                        )
                        for point in instance.numpy()
                    ]
                )
            )
            if isinstance(instance, sio.PredictedInstance):
                instance_score.append(instance.score)
            else:
                instance_score.append(1.0)
        # augmentations
        if self.augmentations is not None:
            for transform in self.augmentations:
                if isinstance(transform, A.CoarseDropout):
                    transform.fill_value = random.randint(0, 255)

            if shown_poses:
                keypoints = np.vstack([list(s.values()) for s in shown_poses])

            else:
                keypoints = []

            augmented = self.augmentations(image=img, keypoints=keypoints)

            img, aug_poses = augmented["image"], augmented["keypoints"]

            aug_poses = [
                arr
                for arr in np.split(
                    np.array(aug_poses),
                    np.array([len(s) for s in shown_poses]).cumsum(),
                )
                if arr.size != 0
            ]

            aug_poses = [
                dict(zip(list(pose_dict.keys()), aug_pose_arr.tolist()))
                for aug_pose_arr, pose_dict in zip(aug_poses, shown_poses)
            ]

            _ = [
                pose.update(aug_pose)
                for pose, aug_pose in zip(shown_poses, aug_poses)
            ]

        img = tvf.to_tensor(img)

        for j in range(len(gt_track_ids)):
            pose = shown_poses[j]

            """Check for anchor"""
            crops = []
            boxes = []
            centroids = {}

            if isinstance(self.anchors, int):
                anchors_to_choose = list(pose.keys()) + ["midpoint"]
                anchors = np.random.choice(anchors_to_choose, self.anchors)
            else:
                anchors = self.anchors

            dropped_anchors = self.node_dropout(anchors)

            for anchor in anchors:
                if anchor in dropped_anchors:
                    centroid = np.array([np.nan, np.nan])

                elif anchor == "midpoint" or anchor == "centroid":
                    centroid = np.nanmean(np.array(list(pose.values())), axis=0)

                elif anchor in pose:
                    centroid = np.array(pose[anchor])
                    if np.isnan(centroid).any():
                        centroid = np.array([np.nan, np.nan])

                elif (
                    anchor not in pose
                    and len(anchors) == 1
                    and self.handle_missing == "centroid"
                ):
                    anchor = "midpoint"
                    centroid = np.nanmean(np.array(list(pose.values())), axis=0)

                else:
                    centroid = np.array([np.nan, np.nan])

                if np.isnan(centroid).all():
                    bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])

                else:
                    if self.use_tight_bbox and len(pose) > 1:
                        # tight bbox
                        # dont allow this for centroid-only poses!
                        arr_pose = np.array(list(pose.values()))
                        # note bbox will be a different size for each instance; padded at the end of the loop
                        bbox = data_utils.get_tight_bbox(arr_pose)

                    else:
                        bbox = data_utils.pad_bbox(
                            data_utils.get_bbox(centroid, crop_size),
                            padding=self.padding,
                        )

                if bbox.isnan().all():
                    crop = torch.zeros(
                        c,
                        crop_size + 2 * self.padding,
                        crop_size + 2 * self.padding,
                        dtype=img.dtype,
                    )
                else:
                    crop = data_utils.crop_bbox(img, bbox)

                crops.append(crop)
                # get max h,w for padding for tight bboxes
                c, h, w = crop.shape
                if h > max_crop_h:
                    max_crop_h = h
                if w > max_crop_w:
                    max_crop_w = w

                centroids[anchor] = centroid
                boxes.append(bbox)

            if len(crops) > 0:
                crops = torch.concat(crops, dim=0)

            if len(boxes) > 0:
                boxes = torch.stack(boxes, dim=0)

            if self.handle_missing == "drop" and boxes.isnan().any():
                continue

            instance = Instance(
                gt_track_id=gt_track_ids[j],
                pred_track_id=-1,
                crop=crops,
                centroid=centroids,
                bbox=boxes,
                skeleton=skeleton,
                pose=poses[j],
                point_scores=point_scores[j],
                instance_score=instance_score[j],
            )

            instances.append(instance)

        frame = Frame(
            video_id=label_idx,
            frame_id=frame_ind,
            vid_file=video_name,
            img_shape=img.shape,
            instances=instances,
        )
        frames.append(frame)

    # pad bbox to max size
    if self.use_tight_bbox:
        # bound the max crop size to the user defined crop size
        max_crop_h = crop_size if max_crop_h == 0 else min(max_crop_h, crop_size)
        max_crop_w = crop_size if max_crop_w == 0 else min(max_crop_w, crop_size)
        # gather all the crops
        for frame in frames:
            for instance in frame.instances:
                data_utils.pad_variable_size_crops(
                    instance, (max_crop_h, max_crop_w)
                )
    return frames
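
Continuing the sketch above, a way to inspect the returned batch; it assumes the Frame and Instance objects expose the fields passed to their constructors (instances, gt_track_id, crop, bbox) as attributes:

frames = dataset.get_instances(label_idx, frame_idx)
for frame in frames:
    for instance in frame.instances:       # assumed attribute access
        print(
            instance.gt_track_id,          # -1 when the instance has no track
            tuple(instance.crop.shape),    # crops concatenated over anchors
            tuple(instance.bbox.shape),    # one 4-element box per anchor
        )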