models.fusion.stn package
Submodules
models.fusion.stn.model module
Spatial transformer PyTorch modules.
- class models.fusion.stn.model.SpatialTransformer(in_channels: int, input_is_3d: bool = False)
Bases: Module
Spatial transformer module.
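The entry above gives only the constructor signature. As a rough illustration of what a spatial transformer of this shape does (a localisation network predicts an affine matrix, which is then used to resample the input), here is a minimal sketch. It assumes that `input_is_3d` switches between 2-D images (a 2×3 matrix) and 3-D volumes (a 3×4 matrix). The class name `SketchSpatialTransformer` and the localiser layout are hypothetical and are not this package's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchSpatialTransformer(nn.Module):
    """Hypothetical classic STN: localiser -> affine theta -> resampled input."""

    def __init__(self, in_channels: int, input_is_3d: bool = False):
        super().__init__()
        # Assumption: input_is_3d selects volumetric (N, C, D, H, W) inputs.
        conv = nn.Conv3d if input_is_3d else nn.Conv2d
        pool = nn.AdaptiveAvgPool3d if input_is_3d else nn.AdaptiveAvgPool2d
        # Affine matrix is 3x4 for volumes and 2x3 for images.
        self.theta_shape = (3, 4) if input_is_3d else (2, 3)
        n_theta = self.theta_shape[0] * self.theta_shape[1]
        self.localiser = nn.Sequential(
            conv(in_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            pool(1),
            nn.Flatten(),
            nn.Linear(8, n_theta),
        )
        # Start from the identity transform so the untrained module is a no-op.
        last = self.localiser[-1]
        nn.init.zeros_(last.weight)
        with torch.no_grad():
            last.bias.copy_(torch.eye(*self.theta_shape).flatten())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        theta = self.localiser(x).view(-1, *self.theta_shape)
        grid = F.affine_grid(theta, list(x.shape), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)
```

The identity initialisation of the last layer is a common design choice. It lets training start from "no transformation" rather than from a random warp.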
- class models.fusion.stn.model.STN(transformer_type: SpatialTransformerType, num_param: int, N: int, xdim: int | None = None, **kwargs)
Bases: Module
Base class for a Spatial Transformer Network (STN).
Adapted from https://github.com/FrederikWarburg/pSTN-baselines. Subclass it for specific tasks; each subclass must implement the init_localiser method.
- init_localiser(**kwargs)
Initialise task-specific localiser.
- init_model_weights(**kwargs)
Initialise model weights.
- forward(x: Tensor, x_high_res: Tensor | None = None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
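The base class leaves the localiser to task-specific subclasses. The sketch below shows that subclassing pattern under several assumptions that the reference above does not confirm. It assumes the base constructor calls `init_localiser` and `init_model_weights`. It assumes the localiser outputs `N × num_param` values per sample. It also assumes that `forward` applies the predicted parameters to `x_high_res` when one is given. `SketchSTN`, `TinyLocaliserSTN` and `PassThrough` are hypothetical names and are not part of this package.

```python
import torch
import torch.nn as nn


class SketchSTN(nn.Module):
    """Hypothetical STN base: subclasses supply the task-specific localiser."""

    def __init__(self, transformer: nn.Module, num_param: int, N: int, **kwargs):
        super().__init__()
        self.transformer = transformer
        self.num_param = num_param
        self.N = N  # number of parallel tracks
        # Assumption: the hooks are invoked from the constructor.
        self.init_localiser(**kwargs)
        self.init_model_weights(**kwargs)

    def init_localiser(self, **kwargs):
        raise NotImplementedError("subclasses must create self.localiser")

    def init_model_weights(self, **kwargs):
        pass  # optional hook; keeps PyTorch's default initialisation

    def forward(self, x, x_high_res=None):
        # One row of num_param transformation parameters per sample and track.
        params = self.localiser(x).view(x.shape[0] * self.N, self.num_param)
        # Assumption: parameters predicted from x are applied to x_high_res if given.
        source = x if x_high_res is None else x_high_res
        source = source.repeat_interleave(self.N, dim=0)
        return self.transformer(source, params), params


class PassThrough(nn.Module):
    """Hypothetical stand-in transformer that ignores its parameters."""

    def forward(self, x, params):
        return x


class TinyLocaliserSTN(SketchSTN):
    """Hypothetical task-specific subclass with a single linear localiser."""

    def init_localiser(self, in_features: int = 16, **kwargs):
        self.localiser = nn.Linear(in_features, self.N * self.num_param)
```

With `N = 2`, a batch of B inputs yields 2B transformed samples. The two tracks for sample b sit in rows 2b and 2b + 1, because `repeat_interleave` and `view` use the same ordering.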
models.fusion.stn.utils module
Utilities for the feature fusion modules.
- models.fusion.stn.utils.init_transformer(transformer_type: SpatialTransformerType, N: int, num_param: int, xdim: int | None = None)
Initialise the spatial transformer.
- Parameters:
transformer_type – Type of spatial transformer.
N – Number of parallel tracks.
num_param – Number of transformation parameters; determines whether an affine (s, r, tx, ty) or a crop (0.5, 1, tx, ty) transformation is used.
xdim – Dimensionality indicator for the data: 1 for time-series datasets, otherwise 2.
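The dispatch that `init_transformer` performs on `transformer_type` could look roughly like the sketch below. This is a hypothetical reconstruction. It mirrors the `SpatialTransformerType` values documented in this module and uses stand-in classes (`AffineStub`, `DiffeomorphicStub`) that only record their configuration. It also assumes `xdim` is mandatory for the diffeomorphic case, since `DiffeomorphicTransformer` takes it as a required `int`.

```python
from enum import Enum


class SpatialTransformerType(Enum):
    """Mirrors the documented enum values (AFFINE = 1, DIFFEOMORPHIC = 2)."""
    AFFINE = 1
    DIFFEOMORPHIC = 2


class AffineStub:
    """Hypothetical stand-in for AffineTransformer."""

    def __init__(self):
        self.kind = "affine"


class DiffeomorphicStub:
    """Hypothetical stand-in for DiffeomorphicTransformer(N, num_param, xdim)."""

    def __init__(self, N: int, num_param: int, xdim: int):
        self.kind = "diffeomorphic"
        self.N, self.num_param, self.xdim = N, num_param, xdim


def init_transformer_sketch(transformer_type, N, num_param, xdim=None):
    """Hypothetical factory: pick a transformer class from the enum value."""
    if transformer_type is SpatialTransformerType.AFFINE:
        return AffineStub()
    if transformer_type is SpatialTransformerType.DIFFEOMORPHIC:
        # Assumption: xdim (1 for time series, otherwise 2) is required here.
        if xdim is None:
            raise ValueError("xdim is required for the diffeomorphic transformer")
        return DiffeomorphicStub(N, num_param, xdim)
    raise ValueError(f"Unknown transformer type: {transformer_type}")
```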
- class models.fusion.stn.utils.SpatialTransformerType(*values)
Bases: Enum
Enum for spatial transformer type.
- AFFINE = 1
Affine transformer.
- DIFFEOMORPHIC = 2
Diffeomorphic transformer.
- class models.fusion.stn.utils.AffineTransformer(*args, **kwargs)
Bases: Module
Affine spatial transformer.
- forward(x: Tensor, params: Tensor, small_image_shape: int | Tuple[int, int])
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
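The `init_transformer` entry describes affine parameters as (s, r, tx, ty). The sketch below shows one way such parameters could drive `F.affine_grid` and `F.grid_sample`. It assumes `params` has shape (B, 4), r is a rotation angle in radians, translations are in normalised [-1, 1] coordinates, and `small_image_shape` is the output height and width. The class name is hypothetical, and the crop variant is not covered.

```python
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchAffineTransformer(nn.Module):
    """Hypothetical affine transformer using an (s, r, tx, ty) parameterisation."""

    def forward(self, x: torch.Tensor, params: torch.Tensor,
                small_image_shape: Union[int, Tuple[int, int]]) -> torch.Tensor:
        if isinstance(small_image_shape, int):
            small_image_shape = (small_image_shape, small_image_shape)
        s, r, tx, ty = params.unbind(dim=1)  # each of shape (B,)
        cos, sin = torch.cos(r), torch.sin(r)
        # Rotation by r, isotropic scale s, translation (tx, ty) in [-1, 1] coordinates.
        theta = torch.stack([
            torch.stack([s * cos, -s * sin, tx], dim=1),
            torch.stack([s * sin, s * cos, ty], dim=1),
        ], dim=1)  # (B, 2, 3)
        # Output keeps batch and channels, with the requested spatial size.
        out_size = [x.shape[0], x.shape[1], *small_image_shape]
        grid = F.affine_grid(theta, out_size, align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)
```

For example, with s = 0.5, r = 0 and zero translation, resampling an 8 × 8 input to a 4 × 4 output reproduces the input's central 4 × 4 window (with `align_corners=False`). A scale below 1 therefore acts as a zoomed-in crop.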
- class models.fusion.stn.utils.DiffeomorphicTransformer(N: int, num_param: int, xdim: int)
Bases: Module
Diffeomorphic spatial transformer.
- forward(x: Tensor, params: Tensor)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
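The entry lists only the signature. Diffeomorphic warps are often obtained by integrating a smooth velocity field, and the sketch below uses a simplified version of that idea for 1-D signals (`xdim = 1`). It treats the `num_param` values as velocities at evenly spaced knots and integrates them with forward Euler. The class name, the `steps` argument and the integrator are hypothetical choices; this package may use a different parameterisation. The resulting warp is only guaranteed to be invertible when the step size is small relative to the slope of the velocity field.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDiffeomorphic1D(nn.Module):
    """Hypothetical 1-D warp: integrate a piecewise-linear velocity field."""

    def __init__(self, num_param: int, steps: int = 32):
        super().__init__()
        if num_param < 2:
            raise ValueError("need at least two velocity knots")
        self.num_param = num_param
        self.steps = steps

    def velocity(self, pos: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        # Linear interpolation of the knot velocities at positions in [-1, 1].
        k = self.num_param
        t = (pos + 1) / 2 * (k - 1)
        i0 = t.floor().long().clamp(0, k - 2)
        w = t - i0.to(t.dtype)
        v = (1 - w) * params.gather(1, i0) + w * params.gather(1, i0 + 1)
        # Vanish at the boundary so the interval [-1, 1] is mapped onto itself.
        return v * (1 - pos ** 2)

    def forward(self, x: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        # x: (B, C, L) signal; params: (B, num_param) knot velocities.
        b, c, length = x.shape
        pos = torch.linspace(-1, 1, length, device=x.device, dtype=x.dtype)
        pos = pos.expand(b, length).clone()
        h = 1.0 / self.steps
        for _ in range(self.steps):  # forward Euler integration from t = 0 to t = 1
            pos = (pos + h * self.velocity(pos, params)).clamp(-1, 1)
        # grid_sample needs a 2-D grid: width coordinate = pos, height coordinate = 0.
        grid = torch.stack([pos, torch.zeros_like(pos)], dim=-1).unsqueeze(1)
        out = F.grid_sample(x.unsqueeze(2), grid, align_corners=True)
        return out.squeeze(2)
```

Zero parameters give the identity warp. Positive interior velocities shift the sampling positions towards +1 while keeping the endpoints fixed.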
Module contents
Spatial Transformer Network implementation for fusion models.