# DREEM Models
## User-facing models
There are two main model APIs users should interact with:

1. `GlobalTrackingTransformer` is the underlying model architecture we use for tracking. It is made up of a `VisualEncoder` and a `Transformer` Encoder-Decoder. Only more advanced users who are familiar with Python and PyTorch should interact with this model directly; everyone else should use the `GTRRunner` described below.
2. `GTRRunner` is a `pytorch_lightning` wrapper around the `GlobalTrackingTransformer`. It implements the basic routines you need for training, validation, and testing. Most users will interact with this model.
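To make the wrapper relationship concrete, here is a minimal sketch of the `pytorch_lightning` pattern that `GTRRunner` follows. The class name, constructor arguments, and loss below are hypothetical placeholders rather than the real `GTRRunner` interface; consult the API reference for the actual signatures.

```python
import torch
import pytorch_lightning as pl


class RunnerSketch(pl.LightningModule):
    """Hypothetical LightningModule wrapping an inner torch model,
    illustrating the wrapper pattern GTRRunner uses (not the real class)."""

    def __init__(self, inner_model: torch.nn.Module):
        super().__init__()
        self.model = inner_model  # e.g. a GlobalTrackingTransformer

    def forward(self, x):
        # Delegate the forward pass to the wrapped model.
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)  # placeholder loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)
```

Because the runner is a standard `LightningModule`, it plugs directly into a `pytorch_lightning.Trainer`, which is what provides the training, validation, and testing routines out of the box.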
## Model Parts
For advanced users who are interested in extending our model, we have modularized each component so that it's easy to compose into your own custom model (a compositional sketch follows the list). The model parts are:

- `VisualEncoder`: A CNN backbone used for feature extraction.
- `Transformer`, which is composed of:
    - a `SpatioTemporalEmbedding`, which computes the spatial and temporal embedding of each detection;
    - a `TransformerEncoder`: a stack of `TransformerEncoderLayer`s;
    - a `TransformerDecoder`: a stack of `TransformerDecoderLayer`s.
- An `AttentionHead`, which computes the association matrix from the transformer output.
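To illustrate how these parts compose, here is a self-contained sketch that wires a CNN backbone, a spatio-temporal embedding, a transformer encoder-decoder, and an attention head into one module using only standard PyTorch layers. The class, layer sizes, and input shapes are hypothetical simplifications, not the DREEM implementations.

```python
import torch
import torch.nn as nn


class TrackerSketch(nn.Module):
    """Hypothetical composition of the four model parts (not the real DREEM code)."""

    def __init__(self, d_model: int = 256, nhead: int = 8):
        super().__init__()
        # VisualEncoder stand-in: a tiny CNN mapping image crops to d_model features.
        self.visual_encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(64, d_model),
        )
        # SpatioTemporalEmbedding stand-in: project (x, y, t) into feature space.
        self.spatiotemporal_embedding = nn.Linear(3, d_model)
        # Encoder/decoder: stacks of standard transformer layers.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers=2
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, batch_first=True), num_layers=2
        )
        # AttentionHead stand-in: project queries/keys, then take dot-product scores.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)

    def forward(self, crops: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        # crops: (n_detections, 3, H, W) image patches, one per detection.
        # positions: (n_detections, 3) normalized (x, y, t) per detection.
        feats = self.visual_encoder(crops) + self.spatiotemporal_embedding(positions)
        feats = feats.unsqueeze(0)                 # add batch dim: (1, n, d_model)
        memory = self.encoder(feats)
        decoded = self.decoder(feats, memory)
        # Association matrix: pairwise similarity scores between all detections.
        q, k = self.q_proj(decoded), self.k_proj(memory)
        return (q @ k.transpose(-2, -1)).squeeze(0)  # (n, n) association scores


# Example usage with random inputs: 5 detections, 64x64 crops.
assoc = TrackerSketch()(torch.randn(5, 3, 64, 64), torch.rand(5, 3))
print(assoc.shape)  # torch.Size([5, 5])
```

The key structural point the sketch preserves is that visual features and spatio-temporal embeddings are fused before the transformer, and the final output is a matrix of pairwise scores between detections, matching the description of the `AttentionHead` above.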