models.attention.urr package
Submodules
models.attention.urr.lightning_module module
LightningModule wrappers for U-Nets with an attention mechanism and uncertain region refinement (URR).
- class models.attention.urr.lightning_module.URRResidualAttentionLightningModule(batch_size: int, metric: Metric | None = None, loss: Module | str | None = None, model_type: ModelType = ModelType.UNET, encoder_name: str = 'resnet34', encoder_depth: int = 5, encoder_weights: str | None = 'imagenet', in_channels: int = 3, classes: int = 1, num_frames: Literal[5, 10, 15, 20, 30] = 5, weights_from_ckpt_path: str | None = None, optimizer: Optimizer | str = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: LRScheduler | str = 'gradual_warmup_scheduler', scheduler_kwargs: dict[str, Any] | None = None, multiplier: int = 2, total_epochs: int = 50, alpha: float = 0.95, beta: float = 0.05, learning_rate: float = 0.0001, dl_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, eval_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, residual_mode: ResidualMode = ResidualMode.SUBTRACT_NEXT_FRAME, loading_mode: LoadingMode = LoadingMode.RGB, dump_memory_snapshot: bool = False, flat_conv: bool = False, unet_activation: str | None = None, attention_reduction: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum', attention_only: bool = False, dummy_predict: DummyPredictMode = DummyPredictMode.NONE, temporal_conv_type: TemporalConvolutionalType = TemporalConvolutionalType.ORIGINAL, urr_source: URRSource = URRSource.O3, uncertainty_mode: UncertaintyMode = UncertaintyMode.URR, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, metric_div_zero: float = 1.0, single_attention_instance: bool = False, show_r2_plots: bool = False)
Bases: ResidualAttentionLightningModule
Lightning Module wrapper for attention U-Nets with URR.
- batch_size
Batch size of dataloader.
- in_channels
Number of image channels.
- num_frames
Number of frames used.
- dump_memory_snapshot
Whether to dump a memory snapshot.
- dummy_predict
Whether to simply return the ground truth for visualisation.
- residual_mode
Residual frames generation mode.
- loading_mode
Image loading mode.
- multiplier
Learning rate multiplier.
- alpha
Loss scaling factor for segmentation loss.
- beta
Loss scaling factor for confidence loss.
- learning_rate
Learning rate for training.
- urr_source
URR low level features source.
- uncertainty_mode
Whether to apply uncertain region refinement or to use only the confidence loss.
- weights_from_ckpt_path
Model checkpoint path to load weights from.
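To illustrate the residual_mode attribute listed above, the sketch below shows one plausible reading of ResidualMode.SUBTRACT_NEXT_FRAME: each residual frame is the current frame minus the frame that follows it. The function name residual_frames, the sign convention, the flat-list frame representation, and the wrap-around at the last frame are all assumptions made for illustration; this page does not document the package's actual behaviour.

```python
from typing import Sequence


def residual_frames(frames: Sequence[Sequence[float]]) -> list[list[float]]:
    """Return frame[i] - frame[i + 1] for every frame (hypothetical helper).

    Each frame is a flat sequence of pixel values. The last frame is paired
    with the first one, an assumption that treats the sequence as cyclic.
    """
    n = len(frames)
    return [
        [cur - nxt for cur, nxt in zip(frames[i], frames[(i + 1) % n])]
        for i in range(n)
    ]


# Three tiny two-pixel "frames".
frames = [[1, 2], [3, 5], [6, 9]]
residuals = residual_frames(frames)
```

The output has the same number of frames as the input, which matches the batch layout of paired frames and residual frames described for training_step.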
- training_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)
Forward pass for the model with dataloader batches.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- Returns:
Training loss.
- Raises:
AssertionError – If the prediction and ground-truth mask shapes differ.
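The alpha and beta attributes are described as scaling factors for the segmentation loss and the confidence loss. A minimal sketch of how such a weighted training loss might be formed is given below. The function name combine_losses and the plain weighted sum are assumptions; the actual training_step may combine the terms differently.

```python
def combine_losses(
    seg_loss: float,
    conf_loss: float,
    alpha: float = 0.95,  # documented default weight for the segmentation loss
    beta: float = 0.05,   # documented default weight for the confidence loss
) -> float:
    """Weighted sum of the two loss terms (assumed form, not confirmed)."""
    return alpha * seg_loss + beta * conf_loss


# With the default weights the segmentation term dominates the total.
total = combine_losses(seg_loss=0.40, conf_loss=0.10)
```

With the documented defaults the weights sum to 1, so the total stays on the same scale as the individual terms.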
- predict_step(batch: tuple[Tensor, Tensor, Tensor, str | list[str]], batch_idx: int, dataloader_idx: int = 0)
Forward pass for the model for one minibatch of a test epoch.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
dataloader_idx – Index of the dataloader.
- Returns:
Mask predictions, original images, and filenames.
models.attention.urr.model module
Implementation of uncertain region refinement for residual-frame-based U-Net and U-Net++ architectures.
- class models.attention.urr.model.RegionRefiner(local_size: int, mdim_global: int, mdim_local: int, classes: int, attn_reduce: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum')
Bases: Module
Region Refinement adapted from AFB-URR.
- forward(unet_decoder_output: Tensor, spatial_encoder_r1: Tensor) tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
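The difference between calling a module instance and calling forward() directly can be seen in the sketch below. TinyModule is a hypothetical pure-Python stand-in that mimics only the call protocol described in the note; it is not PyTorch code. The hook signature (module, args, output) mirrors the one used by forward hooks in torch.nn.Module.

```python
class TinyModule:
    """Hypothetical stand-in for the Module call protocol (not PyTorch)."""

    def __init__(self) -> None:
        self._forward_hooks = []

    def register_forward_hook(self, hook) -> None:
        """Store a callable to run after every instance call."""
        self._forward_hooks.append(hook)

    def forward(self, x):
        """The computation itself; knows nothing about hooks."""
        return 2 * x

    def __call__(self, *args):
        # Calling the instance runs forward() and then every registered hook.
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output


seen = []
module = TinyModule()
module.register_forward_hook(lambda mod, args, out: seen.append(out))

via_call = module(3)             # forward() plus hooks
via_forward = module.forward(4)  # forward() only; hooks are skipped
```

After both calls, seen holds only the output of the instance call, which is why model(x) is preferred over model.forward(x) for the classes documented here.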
models.attention.urr.segmentation_model module
Implementation of URR-based attention module compatible with Segmentation Models PyTorch.
- class models.attention.urr.segmentation_model.UnetDecoderURR(encoder_channels, decoder_channels, n_blocks=5, use_batchnorm=True, attention_type=None, center=False)
Bases: UnetDecoder
U-Net decoder adapted to return upconv decoder layer outputs.
- forward(*features) tuple[Tensor, list[Tensor]]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class models.attention.urr.segmentation_model.UnetPlusPlusDecoderURR(encoder_channels, decoder_channels, n_blocks=5, use_batchnorm=True, attention_type=None, center=False)
Bases: UnetPlusPlusDecoder
U-Net++ decoder adapted to return upconv decoder layer outputs.
- forward(*features) tuple[Tensor, list[Tensor]]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class models.attention.urr.segmentation_model.URRDecoder(decoder: UnetDecoderURR | UnetPlusPlusDecoderURR, segmentation_head: SegmentationHead, refiner: RegionRefiner, num_classes: int, uncertainty_mode: UncertaintyMode)
Bases: Module
Wrapper for the decoder, segmentation head, and region refiner.
- forward(features: Sequence[Tensor], low_level_feature: Tensor | None = None) tuple[Tensor, Tensor | None, Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class models.attention.urr.segmentation_model.URRResidualAttentionUnet(*args, **kwargs)
Bases: ResidualAttentionUnet
Uncertain region refinement for U-Net with attention mechanism.
- forward(regular_frames: Tensor, residual_frames: Tensor) tuple[Tensor, Tensor | None, Tensor, Tensor]
Forward pass of the U-Net model.
- Parameters:
regular_frames – Normal cine CMR sequence.
residual_frames – Residual cine CMR sequence.
- Returns:
Predicted mask probabilities, initial uncertainty, final uncertainty, confidence loss.
- class models.attention.urr.segmentation_model.URRResidualAttentionUnetPlusPlus(*args, **kwargs)
Bases: URRResidualAttentionUnet
U-Net++ with attention mechanism and uncertain region refinement.
models.attention.urr.utils module
Utility functions for URR-augmented attention model.
- models.attention.urr.utils.calc_uncertainty(score: Tensor) Tensor
Calculate the uncertainty map of a rough segmentation.
- Parameters:
score – Rough segmentation mask.
- Returns:
Uncertainty map.
- Raises:
RuntimeError – If score has the wrong number of dimensions or another runtime error occurs.
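A minimal per-pixel sketch of an uncertainty score follows, assuming calc_uncertainty uses the AFB-URR formulation that RegionRefiner is adapted from: the ratio of the largest to the second-largest class score is mapped through exp(1 - ratio). This page does not confirm that formula. The name pixel_uncertainty, the eps constant, and the use of plain Python sequences instead of tensors are illustrative assumptions; a tensor version would typically take the top two scores along the class dimension.

```python
import math
from typing import Sequence


def pixel_uncertainty(scores: Sequence[float], eps: float = 1e-8) -> float:
    """Uncertainty of one pixel from its per-class scores (assumed formula).

    Close to 1 when the two best classes are tied, close to 0 when one
    class clearly dominates.
    """
    if len(scores) < 2:
        raise ValueError("need at least two class scores")
    top1, top2 = sorted(scores, reverse=True)[:2]
    return math.exp(1.0 - top1 / (top2 + eps))


ambiguous = pixel_uncertainty([0.5, 0.5])         # two classes tied
confident = pixel_uncertainty([0.98, 0.01, 0.01])  # one class dominates
```

Under this formulation, pixels near object boundaries, where two classes compete, receive the highest uncertainty and are the ones a refinement step would target.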
Module contents
Model architectures for attention with the URR mechanism.
- class models.attention.urr.URRResidualAttentionUnet(*args, **kwargs)
Bases: ResidualAttentionUnet
Uncertain region refinement for U-Net with attention mechanism.
- forward(regular_frames: Tensor, residual_frames: Tensor) tuple[Tensor, Tensor | None, Tensor, Tensor]
Forward pass of the U-Net model.
- Parameters:
regular_frames – Normal cine CMR sequence.
residual_frames – Residual cine CMR sequence.
- Returns:
Predicted mask probabilities, initial uncertainty, final uncertainty, confidence loss.
- class models.attention.urr.URRResidualAttentionUnetPlusPlus(*args, **kwargs)
Bases: URRResidualAttentionUnet
U-Net++ with attention mechanism and uncertain region refinement.
- class models.attention.urr.URRResidualAttentionLightningModule(batch_size: int, metric: Metric | None = None, loss: Module | str | None = None, model_type: ModelType = ModelType.UNET, encoder_name: str = 'resnet34', encoder_depth: int = 5, encoder_weights: str | None = 'imagenet', in_channels: int = 3, classes: int = 1, num_frames: Literal[5, 10, 15, 20, 30] = 5, weights_from_ckpt_path: str | None = None, optimizer: Optimizer | str = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: LRScheduler | str = 'gradual_warmup_scheduler', scheduler_kwargs: dict[str, Any] | None = None, multiplier: int = 2, total_epochs: int = 50, alpha: float = 0.95, beta: float = 0.05, learning_rate: float = 0.0001, dl_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, eval_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, residual_mode: ResidualMode = ResidualMode.SUBTRACT_NEXT_FRAME, loading_mode: LoadingMode = LoadingMode.RGB, dump_memory_snapshot: bool = False, flat_conv: bool = False, unet_activation: str | None = None, attention_reduction: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum', attention_only: bool = False, dummy_predict: DummyPredictMode = DummyPredictMode.NONE, temporal_conv_type: TemporalConvolutionalType = TemporalConvolutionalType.ORIGINAL, urr_source: URRSource = URRSource.O3, uncertainty_mode: UncertaintyMode = UncertaintyMode.URR, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, metric_div_zero: float = 1.0, single_attention_instance: bool = False, show_r2_plots: bool = False)
Bases: ResidualAttentionLightningModule
Lightning Module wrapper for attention U-Nets with URR.
- predict_step(batch: tuple[Tensor, Tensor, Tensor, str | list[str]], batch_idx: int, dataloader_idx: int = 0)
Forward pass for the model for one minibatch of a test epoch.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
dataloader_idx – Index of the dataloader.
- Returns:
Mask predictions, original images, and filenames.
- training_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)
Forward pass for the model with dataloader batches.
- Parameters:
batch – Batch of frames, residual frames, masks, and filenames.
batch_idx – Index of the batch in the epoch.
- Returns:
Training loss.
- Raises:
AssertionError – If the prediction and ground-truth mask shapes differ.
- dl_classification_mode: ClassificationMode
Classification mode for the dataloader instances.
- eval_classification_mode: ClassificationMode
Classification mode for the evaluation process.
- other_metrics: dict[str, MetricCollection]
A collection of other metrics (recall, precision, jaccard).
- model: nn.Module
The internal model used.
- de_transform: Compose | InverseNormalize
The inverse transformation from augmentation of the samples by the dataloaders.
- batch_size
Batch size of dataloader.
- in_channels
Number of image channels.
- num_frames
Number of frames used.
- dump_memory_snapshot
Whether to dump a memory snapshot.
- dummy_predict
Whether to simply return the ground truth for visualisation.
- residual_mode
Residual frames generation mode.
- loading_mode
Image loading mode.
- multiplier
Learning rate multiplier.
- alpha
Loss scaling factor for segmentation loss.
- beta
Loss scaling factor for confidence loss.
- learning_rate
Learning rate for training.
- urr_source
URR low level features source.
- uncertainty_mode
Whether to apply uncertain region refinement or to use only the confidence loss.
- weights_from_ckpt_path
Model checkpoint path to load weights from.