models.attention package
Subpackages
- models.attention.urr package
  - Submodules
    - models.attention.urr.lightning_module module
      - URRResidualAttentionLightningModule
        - URRResidualAttentionLightningModule.batch_size
        - URRResidualAttentionLightningModule.in_channels
        - URRResidualAttentionLightningModule.classes
        - URRResidualAttentionLightningModule.num_frames
        - URRResidualAttentionLightningModule.dump_memory_snapshot
        - URRResidualAttentionLightningModule.dummy_predict
        - URRResidualAttentionLightningModule.residual_mode
        - URRResidualAttentionLightningModule.optimizer
        - URRResidualAttentionLightningModule.optimizer_kwargs
        - URRResidualAttentionLightningModule.scheduler
        - URRResidualAttentionLightningModule.scheduler_kwargs
        - URRResidualAttentionLightningModule.loading_mode
        - URRResidualAttentionLightningModule.multiplier
        - URRResidualAttentionLightningModule.total_epochs
        - URRResidualAttentionLightningModule.alpha
        - URRResidualAttentionLightningModule.beta
        - URRResidualAttentionLightningModule.learning_rate
        - URRResidualAttentionLightningModule.urr_source
        - URRResidualAttentionLightningModule.uncertainty_mode
        - URRResidualAttentionLightningModule.weights_from_ckpt_path
        - URRResidualAttentionLightningModule.forward()
        - URRResidualAttentionLightningModule.training_step()
        - URRResidualAttentionLightningModule.predict_step()
    - models.attention.urr.model module
    - models.attention.urr.segmentation_model module
    - models.attention.urr.utils module
  - Module contents
    - URRResidualAttentionUnet
    - URRResidualAttentionUnetPlusPlus
    - URRResidualAttentionLightningModule
      - URRResidualAttentionLightningModule.forward()
      - URRResidualAttentionLightningModule.predict_step()
      - URRResidualAttentionLightningModule.training_step()
      - URRResidualAttentionLightningModule.dl_classification_mode
      - URRResidualAttentionLightningModule.eval_classification_mode
      - URRResidualAttentionLightningModule.dice_metrics
      - URRResidualAttentionLightningModule.other_metrics
      - URRResidualAttentionLightningModule.hausdorff_metrics
      - URRResidualAttentionLightningModule.infarct_metrics
      - URRResidualAttentionLightningModule.model
      - URRResidualAttentionLightningModule.model_type
      - URRResidualAttentionLightningModule.de_transform
      - URRResidualAttentionLightningModule.classes
      - URRResidualAttentionLightningModule.optimizer
      - URRResidualAttentionLightningModule.optimizer_kwargs
      - URRResidualAttentionLightningModule.total_epochs
      - URRResidualAttentionLightningModule.scheduler
      - URRResidualAttentionLightningModule.scheduler_kwargs
      - URRResidualAttentionLightningModule.prepare_data_per_node
      - URRResidualAttentionLightningModule.allow_zero_length_dataloader_with_multiple_devices
      - URRResidualAttentionLightningModule.training
      - URRResidualAttentionLightningModule.batch_size
      - URRResidualAttentionLightningModule.in_channels
      - URRResidualAttentionLightningModule.num_frames
      - URRResidualAttentionLightningModule.dump_memory_snapshot
      - URRResidualAttentionLightningModule.dummy_predict
      - URRResidualAttentionLightningModule.residual_mode
      - URRResidualAttentionLightningModule.loading_mode
      - URRResidualAttentionLightningModule.multiplier
      - URRResidualAttentionLightningModule.alpha
      - URRResidualAttentionLightningModule.beta
      - URRResidualAttentionLightningModule.learning_rate
      - URRResidualAttentionLightningModule.urr_source
      - URRResidualAttentionLightningModule.uncertainty_mode
      - URRResidualAttentionLightningModule.weights_from_ckpt_path
Submodules
models.attention.lightning_module module
LightningModule wrappers for the U-Net with an attention mechanism on residual frames.
- class models.attention.lightning_module.ResidualAttentionLightningModule(batch_size: int, metric: Metric | None = None, loss: Module | str | None = None, model_type: ModelType = ModelType.UNET, encoder_name: str = 'resnet34', encoder_depth: int = 5, encoder_weights: str | None = 'imagenet', in_channels: int = 3, classes: int = 1, num_frames: Literal[5, 10, 15, 20, 30] = 5, weights_from_ckpt_path: str | None = None, optimizer: Optimizer | str = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: LRScheduler | str = 'gradual_warmup_scheduler', scheduler_kwargs: dict[str, Any] | None = None, multiplier: int = 2, total_epochs: int = 50, alpha: float = 1.0, _beta: float = 0.0, learning_rate: float = 0.0001, dl_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, eval_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, residual_mode: ResidualMode = ResidualMode.SUBTRACT_NEXT_FRAME, loading_mode: LoadingMode = LoadingMode.RGB, dump_memory_snapshot: bool = False, flat_conv: bool = False, unet_activation: str | None = None, attention_reduction: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum', attention_only: bool = False, dummy_predict: DummyPredictMode = DummyPredictMode.NONE, temporal_conv_type: TemporalConvolutionalType = TemporalConvolutionalType.ORIGINAL, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, metric_div_zero: float = 1.0, single_attention_instance: bool = False)
Bases: CommonModelMixin
Attention mechanism-based U-Net.
- batch_size
Batch size of dataloader.
- in_channels
Number of image channels.
- num_frames
Number of frames used.
- dump_memory_snapshot
Whether to dump a memory snapshot.
- dummy_predict
Whether to simply return the ground truth for visualisation.
- residual_mode
Residual frames generation mode.
- loading_mode
Image loading mode.
- multiplier
Learning rate multiplier.
- alpha
Loss scaling factor.
- learning_rate
Learning rate for training.
- weights_from_ckpt_path
Model checkpoint path to load weights from.
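Residual frames are central to this module. As an illustration only, the sketch below shows one plausible reading of ResidualMode.SUBTRACT_NEXT_FRAME: each frame minus its successor, with the last frame wrapping around to the first so the clip keeps num_frames entries. The function name, the (num_frames, C, H, W) layout, and the wrap-around are assumptions, not the package's confirmed behaviour.

```python
import torch


def subtract_next_frame(frames: torch.Tensor) -> torch.Tensor:
    """Hypothetical residual computation for a (num_frames, C, H, W) clip.

    Assumes residual[i] = frames[i] - frames[i + 1], wrapping the last
    frame around to the first so the output keeps num_frames entries.
    """
    next_frames = torch.roll(frames, shifts=-1, dims=0)  # frames[i + 1], cyclic
    return frames - next_frames


# 5 frames (the default num_frames), 3 channels (RGB loading), 4x4 pixels;
# frame i is filled with the constant value i.
clip = torch.arange(5, dtype=torch.float32).view(5, 1, 1, 1).expand(5, 3, 4, 4)
residuals = subtract_next_frame(clip)
print(residuals.shape)  # torch.Size([5, 3, 4, 4])
print(residuals[:, 0, 0, 0].tolist())  # [-1.0, -1.0, -1.0, -1.0, 4.0]
```

The wrap-around keeps the number of residual frames equal to num_frames; a real implementation might instead drop the last frame or pad it.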
- log_metrics(prefix) → None
Implements shared metric logging at the end of an epoch.
Note: This is implemented here to prevent circular imports with the logging module.
- training_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int) → Tensor
Forward pass for the model with dataloader batches.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- Returns:
Training loss.
- Raises:
AssertionError – The prediction and ground-truth mask shapes differ.
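A minimal sketch of how a training step can consume the documented batch (frames, residual frames, masks, filenames) and enforce the shape check listed under Raises. ToyTwoStreamModel is a hypothetical stand-in for the real network, and the (B, num_frames, C, H, W) frame layout, binary (B, 1, H, W) masks, and BCE loss are assumptions chosen for illustration.

```python
import torch
from torch import nn


class ToyTwoStreamModel(nn.Module):
    """Hypothetical stand-in that, like ResidualAttentionUnet.forward, takes
    regular and residual frames and returns mask logits."""

    def __init__(self, num_frames: int, in_channels: int, classes: int) -> None:
        super().__init__()
        # Both streams are stacked along the channel axis, then mapped to logits.
        self.head = nn.Conv2d(2 * num_frames * in_channels, classes, kernel_size=1)

    def forward(self, regular: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        b, f, c, h, w = regular.shape
        x = torch.cat([regular, residual], dim=1).reshape(b, 2 * f * c, h, w)
        return self.head(x)


def toy_training_step(model: nn.Module, batch, loss_fn) -> torch.Tensor:
    frames, residual_frames, masks, _filenames = batch  # documented batch order
    logits = model(frames, residual_frames)
    # Mirrors the documented AssertionError on a prediction/mask shape mismatch.
    assert logits.shape == masks.shape, f"{logits.shape} != {masks.shape}"
    return loss_fn(logits, masks)


model = ToyTwoStreamModel(num_frames=5, in_channels=3, classes=1)
batch = (
    torch.randn(2, 5, 3, 8, 8),  # regular frames (assumed layout B, F, C, H, W)
    torch.randn(2, 5, 3, 8, 8),  # residual frames
    torch.randint(0, 2, (2, 1, 8, 8)).float(),  # binary masks
    ["a.png", "b.png"],  # filenames
)
loss = toy_training_step(model, batch, nn.BCEWithLogitsLoss())
print(loss.ndim)  # 0 (a scalar loss)
```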
- validation_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)
Forward pass for the model for one minibatch of a validation epoch.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- test_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)
Forward pass for the model for one minibatch of a test epoch.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- predict_step(batch: tuple[Tensor, Tensor, Tensor, str | list[str]], batch_idx: int, dataloader_idx: int = 0) → tuple[Tensor, Tensor, str | list[str]]
Forward pass for the model for one minibatch of a prediction run.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
dataloader_idx – Index of the dataloader.
- Returns:
Mask predictions, original images, and filename(s).
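predict_step returns mask predictions rather than raw logits. A minimal sketch of how logits could be turned into masks depending on the classification mode, assuming multiclass mode takes the per-pixel argmax over the class channel and binary or multilabel mode thresholds a sigmoid at 0.5; logits_to_mask and the threshold are hypothetical, not part of the package.

```python
import torch


def logits_to_mask(logits: torch.Tensor, multiclass: bool, threshold: float = 0.5) -> torch.Tensor:
    """Convert (B, classes, H, W) logits into a label mask (hypothetical helper)."""
    if multiclass:
        # Index of the highest-scoring class per pixel -> (B, H, W).
        return logits.argmax(dim=1)
    # Binary/multilabel: independent sigmoid per channel -> (B, classes, H, W).
    return (torch.sigmoid(logits) > threshold).long()


logits = torch.tensor([[[[2.0]], [[0.5]], [[-1.0]]]])  # B=1, classes=3, H=W=1
print(logits_to_mask(logits, multiclass=True).tolist())   # [[[0]]]
print(logits_to_mask(logits, multiclass=False).tolist())  # [[[[1]], [[1]], [[0]]]]
```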
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of these options:
  - A single optimizer.
  - A list or tuple of optimizers.
  - Two lists: the first contains multiple optimizers, the second multiple LR schedulers (or multiple lr_scheduler_config).
  - A dictionary with an "optimizer" key and, optionally, an "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  - None: fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified in 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
    # The ReduceLROnPlateau scheduler requires a monitor
    def configure_optimizers(self):
        optimizer = Adam(...)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, ...),
                "monitor": "metric_to_track",
                "frequency": "indicates how often the metric is updated",
                # If "monitor" references validation metrics, then "frequency" should be set to a
                # multiple of "trainer.check_val_every_n_epoch".
            },
        }

    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
    def configure_optimizers(self):
        optimizer1 = Adam(...)
        optimizer2 = SGD(...)
        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
        scheduler2 = LambdaLR(optimizer2, ...)
        return (
            {
                "optimizer": optimizer1,
                "lr_scheduler": {
                    "scheduler": scheduler1,
                    "monitor": "metric_to_track",
                },
            },
            {"optimizer": optimizer2, "lr_scheduler": scheduler2},
        )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
  - Lightning calls .backward() and .step() automatically in case of automatic optimization.
  - If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
  - If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  - If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  - If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.
  - If you need to control how often the optimizer steps, override the optimizer_step() hook.
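The default scheduler is 'gradual_warmup_scheduler' together with multiplier=2. Assuming it follows the scheme of the widely used open-source GradualWarmupScheduler (the learning rate is ramped linearly from the base value to base × multiplier over a warm-up period, with multiplier=1 meaning a ramp from 0 to the base value), the schedule can be sketched in plain Python. The function name, the warm-up length, and holding the rate constant afterwards are assumptions; whether total_epochs denotes the warm-up length in this module is not stated here.

```python
def gradual_warmup_lr(base_lr: float, epoch: int, multiplier: float, warmup_epochs: int) -> float:
    """Learning rate at `epoch` under a linear warm-up (hypothetical sketch).

    After warm-up the rate is simply held at base_lr * multiplier; a real
    setup would typically hand over to a follow-up scheduler instead.
    """
    if multiplier < 1.0:
        raise ValueError("multiplier should be >= 1")
    if epoch >= warmup_epochs:
        return base_lr * multiplier
    if multiplier == 1.0:
        # Special case of this scheme: ramp from 0 up to base_lr.
        return base_lr * epoch / warmup_epochs
    # Linear interpolation between base_lr (epoch 0) and base_lr * multiplier.
    return base_lr * ((multiplier - 1.0) * epoch / warmup_epochs + 1.0)


# Defaults from the signature: learning_rate=1e-4, multiplier=2; 5 warm-up epochs assumed.
for epoch in (0, 1, 5, 10):
    print(epoch, gradual_warmup_lr(1e-4, epoch, multiplier=2, warmup_epochs=5))
```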
models.attention.model module
Implementation of residual frames-based attention layers.
- class models.attention.model.AttentionLayer(embed_dim: int, num_frames: int, num_heads: int = 1, key_embed_dim: int | None = None, value_embed_dim: int | None = None, need_weights: bool = False, reduce: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum', one_instance: bool = False)
Bases: Module
Attention mechanism between spatio-temporal and spatial embeddings.
Because the spatial dimensions of the image are treated as the sequence to be processed, the channel dimension serves as the embedding dimension for each of the Q, K, and V tensors.
- forward(q: Tensor, ks: Tensor, vs: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
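AttentionLayer treats the spatial positions of a feature map as the sequence and the channels as the embedding dimension. The sketch below shows that reshaping with torch.nn.MultiheadAttention, assuming query, key, and value maps of identical shape (B, C, H, W); it ignores num_frames, reduce, and one_instance, and the helper spatial_attention is hypothetical. The kdim and vdim arguments of nn.MultiheadAttention would be the natural counterparts of key_embed_dim and value_embed_dim, though that mapping is also an assumption.

```python
import torch
from torch import nn


def spatial_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mha: nn.MultiheadAttention
) -> torch.Tensor:
    """Attend over the H*W positions of (B, C, H, W) feature maps (illustrative)."""
    b, c, h, w = q.shape

    def to_seq(x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, H*W, C): each pixel becomes one token of size C.
        return x.flatten(2).transpose(1, 2)

    out, _ = mha(to_seq(q), to_seq(k), to_seq(v), need_weights=False)
    # (B, H*W, C) -> (B, C, H, W)
    return out.transpose(1, 2).reshape(b, c, h, w)


embed_dim = 16  # channel count of the feature maps = embedding dimension
mha = nn.MultiheadAttention(embed_dim, num_heads=1, batch_first=True)
q = torch.randn(2, embed_dim, 8, 8)   # e.g. spatio-temporal embeddings
kv = torch.randn(2, embed_dim, 8, 8)  # e.g. residual-frame embeddings
print(spatial_attention(q, kv, kv, mha).shape)  # torch.Size([2, 16, 8, 8])
```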
- class models.attention.model.SpatialAttentionBlock(temporal_conv: OneD | DilatedOneD | Temporal3DConv | None, attention: AttentionLayer, num_frames: int, reduce: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'], reduce_dim: int = 0, one_instance: bool = False, _attention_only: bool = False)
Bases: Module
Residual block with an attention mechanism between spatio-temporal embeddings and residual-frame (spatial) embeddings.
- forward(st_embeddings: Tensor, res_embeddings: Tensor, return_o1: bool = False) → Tensor | tuple[Tensor, Tensor]
Forward pass of the residual block.
- Parameters:
st_embeddings – Spatio-temporal embeddings from raw frames.
res_embeddings – Spatial embeddings from residual frames.
return_o1 – Whether to also return o1; if True, a tuple of two tensors is returned.
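Both AttentionLayer and SpatialAttentionBlock accept a reduce option ('sum', 'prod', 'cat', 'weighted', 'weighted_learnable'), and SpatialAttentionBlock defaults to reduce_dim=0. One hypothetical reading is that per-frame attention outputs are stacked along dimension 0 and then combined; the FrameReducer class below sketches that interpretation only. Using uniform fixed weights for 'weighted' and a trainable nn.Parameter for 'weighted_learnable' is an assumption.

```python
import torch
from torch import nn


class FrameReducer(nn.Module):
    """Hypothetical reducer over a (num_frames, B, C, H, W) stack of outputs."""

    def __init__(self, num_frames: int, reduce: str = "sum") -> None:
        super().__init__()
        self.reduce = reduce
        weights = torch.full((num_frames,), 1.0 / num_frames)
        if reduce == "weighted_learnable":
            self.weights = nn.Parameter(weights)  # trained with the rest of the model
        else:
            self.register_buffer("weights", weights)  # fixed uniform weights

    def forward(self, stacked: torch.Tensor) -> torch.Tensor:
        if self.reduce == "sum":
            return stacked.sum(dim=0)
        if self.reduce == "prod":
            return stacked.prod(dim=0)
        if self.reduce == "cat":
            # Frames become extra channels: (B, num_frames * C, H, W).
            return torch.cat(list(stacked.unbind(dim=0)), dim=1)
        if self.reduce in ("weighted", "weighted_learnable"):
            # Broadcast one weight per frame over the remaining dimensions.
            w = self.weights.view(-1, *([1] * (stacked.dim() - 1)))
            return (w * stacked).sum(dim=0)
        raise ValueError(f"unknown reduce mode: {self.reduce}")


stacked = torch.ones(5, 2, 4, 8, 8)  # 5 frames, batch 2, 4 channels, 8x8
print(FrameReducer(5, "sum")(stacked)[0, 0, 0, 0].item())  # 5.0
print(FrameReducer(5, "cat")(stacked).shape)  # torch.Size([2, 20, 8, 8])
```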
models.attention.segmentation_model module
Implementation of residual frames-based U-Net and U-Net++ architectures.
- class models.attention.segmentation_model.ResidualAttentionUnet(*args, **kwargs)
Bases: SegmentationModel
U-Net with Attention mechanism on residual frames.
- check_input_shape(x)
Check if the input shape is divisible by the output stride. If not, raise a RuntimeError.
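The check exists because encoder-decoder networks downsample the input by a fixed factor, so the spatial sizes must divide evenly. A minimal stand-alone sketch, assuming an output stride of 32 (2**5, matching the default encoder_depth=5 with a ResNet-style encoder that halves the resolution at each stage); the signature and error message here are hypothetical.

```python
def check_input_shape(height: int, width: int, output_stride: int = 32) -> None:
    """Raise RuntimeError unless both spatial sizes are divisible by output_stride."""
    if height % output_stride or width % output_stride:
        # Suggest the nearest larger sizes that would be accepted (ceiling division).
        new_h = -(-height // output_stride) * output_stride
        new_w = -(-width // output_stride) * output_stride
        raise RuntimeError(
            f"Input size ({height}, {width}) is not divisible by {output_stride}; "
            f"consider padding to ({new_h}, {new_w})."
        )


check_input_shape(224, 224)  # passes silently: 224 = 7 * 32
try:
    check_input_shape(112, 100)
except RuntimeError as err:
    print(err)  # Input size (112, 100) is not divisible by 32; consider padding to (128, 128).
```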
- property encoder
Get the encoder of the model.
- forward(regular_frames: Tensor, residual_frames: Tensor) → Tensor
Forward pass of the model.
- Parameters:
regular_frames – Regular frames from the sequence.
residual_frames – Residual frames from the sequence.
- Returns:
Predicted mask logits.
- predict(regular_frames: Tensor, residual_frames: Tensor) → Tensor
Inference method. Switches the model to eval mode and calls .forward() under torch.no_grad().
- Parameters:
regular_frames – Regular frames from the sequence.
residual_frames – Residual frames from the sequence.
- Returns:
prediction – 4D torch tensor with shape (batch_size, classes, height, width).
- Return type:
Tensor
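The same eval-mode plus no_grad pattern can be sketched with a toy two-input model. ToyModel and its (B, num_frames, C, H, W) input layout are hypothetical stand-ins, not part of the package.

```python
import torch
from torch import nn


def predict(model: nn.Module, regular_frames: torch.Tensor, residual_frames: torch.Tensor) -> torch.Tensor:
    """Inference helper: eval mode plus torch.no_grad(), as described above."""
    if model.training:
        model.eval()  # disables dropout and uses running batch-norm statistics
    with torch.no_grad():
        return model(regular_frames, residual_frames)


class ToyModel(nn.Module):
    """Hypothetical two-input model returning (B, classes, H, W) logits."""

    def __init__(self, in_channels: int = 3, classes: int = 2) -> None:
        super().__init__()
        self.head = nn.Conv2d(in_channels, classes, kernel_size=1)

    def forward(self, regular: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        # Average both streams over the (assumed) frame axis, then classify pixels.
        return self.head(regular.mean(dim=1) + residual.mean(dim=1))


model = ToyModel()
out = predict(model, torch.randn(1, 5, 3, 16, 16), torch.randn(1, 5, 3, 16, 16))
print(out.shape, out.requires_grad, model.training)  # torch.Size([1, 2, 16, 16]) False False
```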
- class models.attention.segmentation_model.ResidualAttentionUnetPlusPlus(*args, **kwargs)
Bases: ResidualAttentionUnet
U-Net++ with Attention mechanism on residual frames.
models.attention.utils module
Helper classes and typedefs for the residual frames-based attention models.
Module contents
Residual frames-based attention U-Net and U-Net++ implementation.
- class models.attention.ResidualAttentionLightningModule(batch_size: int, metric: Metric | None = None, loss: Module | str | None = None, model_type: ModelType = ModelType.UNET, encoder_name: str = 'resnet34', encoder_depth: int = 5, encoder_weights: str | None = 'imagenet', in_channels: int = 3, classes: int = 1, num_frames: Literal[5, 10, 15, 20, 30] = 5, weights_from_ckpt_path: str | None = None, optimizer: Optimizer | str = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: LRScheduler | str = 'gradual_warmup_scheduler', scheduler_kwargs: dict[str, Any] | None = None, multiplier: int = 2, total_epochs: int = 50, alpha: float = 1.0, _beta: float = 0.0, learning_rate: float = 0.0001, dl_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, eval_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, residual_mode: ResidualMode = ResidualMode.SUBTRACT_NEXT_FRAME, loading_mode: LoadingMode = LoadingMode.RGB, dump_memory_snapshot: bool = False, flat_conv: bool = False, unet_activation: str | None = None, attention_reduction: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum', attention_only: bool = False, dummy_predict: DummyPredictMode = DummyPredictMode.NONE, temporal_conv_type: TemporalConvolutionalType = TemporalConvolutionalType.ORIGINAL, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, metric_div_zero: float = 1.0, single_attention_instance: bool = False)
Bases: CommonModelMixin
Attention mechanism-based U-Net.
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. See models.attention.lightning_module.ResidualAttentionLightningModule.configure_optimizers() for the accepted return values and the scheduler configuration.
- log_metrics(prefix) → None
Implements shared metric logging at the end of an epoch.
Note: This is implemented here to prevent circular imports with the logging module.
- predict_step(batch: tuple[Tensor, Tensor, Tensor, str | list[str]], batch_idx: int, dataloader_idx: int = 0) → tuple[Tensor, Tensor, str | list[str]]
Forward pass for the model for one minibatch of a prediction run.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
dataloader_idx – Index of the dataloader.
- Returns:
Mask predictions, original images, and filename(s).
- test_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)
Forward pass for the model for one minibatch of a test epoch.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- training_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int) → Tensor
Forward pass for the model with dataloader batches.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- Returns:
Training loss.
- Raises:
AssertionError – The prediction and ground-truth mask shapes differ.
- validation_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)
Forward pass for the model for one minibatch of a validation epoch.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- dl_classification_mode: ClassificationMode
Classification mode for the dataloader instances.
- eval_classification_mode: ClassificationMode
Classification mode for the evaluation process.
- other_metrics: dict[str, MetricCollection]
A collection of other metrics (recall, precision, jaccard).
- model: nn.Module
The internal model used.
- de_transform: Compose | InverseNormalize
The inverse of the transformation applied to the samples by the dataloaders during augmentation.
- batch_size
Batch size of dataloader.
- in_channels
Number of image channels.
- num_frames
Number of frames used.
- dump_memory_snapshot
Whether to dump a memory snapshot.
- dummy_predict
Whether to simply return the ground truth for visualisation.
- residual_mode
Residual frames generation mode.
- loading_mode
Image loading mode.
- multiplier
Learning rate multiplier.
- alpha
Loss scaling factor.
- learning_rate
Learning rate for training.
- weights_from_ckpt_path
Model checkpoint path to load weights from.
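de_transform maps normalised samples back to viewable images. Assuming InverseNormalize reverses a torchvision-style Normalize(mean, std), i.e. computes x * std + mean per channel, a minimal sketch follows. InverseNormalizeSketch is a hypothetical stand-in rather than the package's class, and the ImageNet statistics are assumed only because encoder_weights defaults to 'imagenet'.

```python
import torch


class InverseNormalizeSketch:
    """Undo (x - mean) / std channel-wise for (..., C, H, W) tensors."""

    def __init__(self, mean: list[float], std: list[float]) -> None:
        # Shape (C, 1, 1) so the statistics broadcast over height and width.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean


mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # ImageNet statistics
image = torch.rand(3, 4, 4)
normalized = (image - torch.tensor(mean).view(-1, 1, 1)) / torch.tensor(std).view(-1, 1, 1)
restored = InverseNormalizeSketch(mean, std)(normalized)
print(torch.allclose(restored, image, atol=1e-6))  # True
```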