models.attention.urr package

Submodules

models.attention.urr.lightning_module module

LightningModule wrappers for U-Net with attention mechanism and URR.

class models.attention.urr.lightning_module.URRResidualAttentionLightningModule(batch_size: int, metric: Metric | None = None, loss: Module | str | None = None, model_type: ModelType = ModelType.UNET, encoder_name: str = 'resnet34', encoder_depth: int = 5, encoder_weights: str | None = 'imagenet', in_channels: int = 3, classes: int = 1, num_frames: Literal[5, 10, 15, 20, 30] = 5, weights_from_ckpt_path: str | None = None, optimizer: Optimizer | str = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: LRScheduler | str = 'gradual_warmup_scheduler', scheduler_kwargs: dict[str, Any] | None = None, multiplier: int = 2, total_epochs: int = 50, alpha: float = 0.95, beta: float = 0.05, learning_rate: float = 0.0001, dl_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, eval_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, residual_mode: ResidualMode = ResidualMode.SUBTRACT_NEXT_FRAME, loading_mode: LoadingMode = LoadingMode.RGB, dump_memory_snapshot: bool = False, flat_conv: bool = False, unet_activation: str | None = None, attention_reduction: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum', attention_only: bool = False, dummy_predict: DummyPredictMode = DummyPredictMode.NONE, temporal_conv_type: TemporalConvolutionalType = TemporalConvolutionalType.ORIGINAL, urr_source: URRSource = URRSource.O3, uncertainty_mode: UncertaintyMode = UncertaintyMode.URR, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, metric_div_zero: float = 1.0, single_attention_instance: bool = False, show_r2_plots: bool = False)

Bases: ResidualAttentionLightningModule

Lightning Module wrapper for attention U-Nets with URR.

batch_size

Batch size of dataloader.

in_channels

Number of image channels.

classes: int

Number of segmentation classes.

num_frames

Number of frames used.

dump_memory_snapshot

Whether to dump a memory snapshot.

dummy_predict

Whether to simply return the ground truth for visualisation.

residual_mode

Residual frames generation mode.

optimizer: type[Optimizer] | Optimizer | str

Optimizer for training.

optimizer_kwargs: dict[str, Any]

Optimizer kwargs.

scheduler: type[LRScheduler] | LRScheduler | str

Scheduler for training.

scheduler_kwargs: dict[str, Any]

Scheduler kwargs.

loading_mode

Image loading mode.

multiplier

Learning rate multiplier.

total_epochs: int

Number of total epochs for training.

alpha

Loss scaling factor for segmentation loss.

beta

Loss scaling factor for confidence loss.

learning_rate

Learning rate for training.

urr_source

Source of the low-level feature maps for URR.

uncertainty_mode

Whether to include uncertain-region refinement or to use only the confidence loss.

weights_from_ckpt_path

Model checkpoint path to load weights from.

forward(x_img: Tensor, x_res: Tensor) tuple[Tensor, Tensor, Tensor]

Forward pass of the model.

training_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)

Forward pass for the model with dataloader batches.

Parameters:
  • batch – Batch of frames, residual frames, masks, and filenames.

  • batch_idx – Index of the batch in the epoch.

Returns:

Training loss.

Raises:

AssertionError – If the prediction and ground truth mask shapes differ.

predict_step(batch: tuple[Tensor, Tensor, Tensor, str | list[str]], batch_idx: int, dataloader_idx: int = 0)

Forward pass of the model for one minibatch of a test epoch.

Parameters:
  • batch – Batch of frames, residual frames, masks, and filenames.

  • batch_idx – Index of the batch in the epoch.

  • dataloader_idx – Index of the dataloader.

Returns:

Mask predictions, original images, and filename.
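
A minimal usage sketch, not taken from the source: constructing the wrapper and handing it to a Lightning Trainer. The argument values, the lightning import (use pytorch_lightning instead if that is the project's dependency), and the omitted datamodule are assumptions.

    import lightning.pytorch as pl  # assumption: Lightning 2.x package layout

    from models.attention.urr.lightning_module import URRResidualAttentionLightningModule
    from models.attention.urr.utils import UncertaintyMode, URRSource

    # Illustrative hyperparameters; only batch_size is required by the constructor.
    module = URRResidualAttentionLightningModule(
        batch_size=2,
        classes=4,
        num_frames=10,
        uncertainty_mode=UncertaintyMode.URR,
        urr_source=URRSource.O3,
    )

    trainer = pl.Trainer(max_epochs=50)
    # trainer.fit(module, datamodule=...)  # requires a compatible cine CMR datamodule

The Trainer drives training_step and predict_step; they are not normally called directly.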

models.attention.urr.model module

Implementation of uncertain region refinement for residual-frame-based U-Net and U-Net++ architectures.

class models.attention.urr.model.RegionRefiner(local_size: int, mdim_global: int, mdim_local: int, classes: int, attn_reduce: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum')

Bases: Module

Region Refinement adapted from AFB-URR.

forward(unet_decoder_output: Tensor, spatial_encoder_r1: Tensor) tuple[Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
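
A construction-only sketch; the window size, channel dimensions, and class count below are assumptions for illustration. The refiner is typically passed to URRDecoder (documented below) rather than used on its own.

    from models.attention.urr.model import RegionRefiner

    # Assumed dimensions: 7x7 local window, 256-channel global features,
    # 32-channel local features, 4 segmentation classes.
    refiner = RegionRefiner(
        local_size=7,
        mdim_global=256,
        mdim_local=32,
        classes=4,
        attn_reduce="sum",
    )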

models.attention.urr.segmentation_model module

Implementation of a URR-based attention module compatible with Segmentation Models PyTorch.

class models.attention.urr.segmentation_model.UnetDecoderURR(encoder_channels, decoder_channels, n_blocks=5, use_batchnorm=True, attention_type=None, center=False)

Bases: UnetDecoder

U-Net decoder adapted to return upconv decoder layer outputs.

forward(*features) tuple[Tensor, list[Tensor]]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class models.attention.urr.segmentation_model.UnetPlusPlusDecoderURR(encoder_channels, decoder_channels, n_blocks=5, use_batchnorm=True, attention_type=None, center=False)

Bases: UnetPlusPlusDecoder

U-Net++ decoder adapted to return upconv decoder layer outputs.

forward(*features) tuple[Tensor, list[Tensor]]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class models.attention.urr.segmentation_model.URRDecoder(decoder: UnetDecoderURR | UnetPlusPlusDecoderURR, segmentation_head: SegmentationHead, refiner: RegionRefiner, num_classes: int, uncertainty_mode: UncertaintyMode)

Bases: Module

Wrapper for the decoder, segmentation head, and region refiner.

forward(features: Sequence[Tensor], low_level_feature: Tensor | None = None) tuple[Tensor, Tensor | None, Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class models.attention.urr.segmentation_model.URRResidualAttentionUnet(*args, **kwargs)

Bases: ResidualAttentionUnet

Uncertain region refinement for U-Net with attention mechanism.

forward(regular_frames: Tensor, residual_frames: Tensor) tuple[Tensor, Tensor | None, Tensor, Tensor]

Forward pass of the U-Net model.

Parameters:
  • regular_frames – Normal cine CMR sequence.

  • residual_frames – Residual cine CMR sequence.

Returns:

Predicted mask probabilities, initial uncertainty, final uncertainty, and confidence loss.
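
A sketch of how the four forward outputs might be combined into a training loss, using the alpha/beta scaling factors described for the Lightning wrapper above; the function name and wiring are illustrative, not the module's actual implementation.

    from torch import Tensor, nn

    def combine_urr_losses(
        outputs: tuple[Tensor, Tensor | None, Tensor, Tensor],
        masks: Tensor,
        seg_criterion: nn.Module,
        alpha: float = 0.95,
        beta: float = 0.05,
    ) -> Tensor:
        # Unpack the outputs documented above; the two uncertainty maps are not
        # needed for the loss itself in this sketch.
        probs, _initial_uncertainty, _final_uncertainty, confidence_loss = outputs
        # Scale the segmentation loss by alpha and the confidence loss by beta.
        return alpha * seg_criterion(probs, masks) + beta * confidence_loss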

class models.attention.urr.segmentation_model.URRResidualAttentionUnetPlusPlus(*args, **kwargs)

Bases: URRResidualAttentionUnet

U-Net++ with attention mechanism and uncertain region refinement.

models.attention.urr.utils module

Utility functions for URR-augmented attention model.

models.attention.urr.utils.calc_uncertainty(score: Tensor) Tensor

Calculate uncertainty.

Parameters:

score – Rough segmentation mask.

Returns:

Uncertainty map.

Raises:

RuntimeError – If the dimensions of score are invalid or another runtime error occurs.
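
A usage sketch under the assumption that calc_uncertainty accepts a 4-D (batch, classes, height, width) probability map; the shapes below are illustrative only.

    import torch

    from models.attention.urr.utils import calc_uncertainty

    logits = torch.randn(2, 4, 224, 224)   # assumed shape: batch of 2, 4 classes
    score = torch.softmax(logits, dim=1)   # rough segmentation mask probabilities
    uncertainty = calc_uncertainty(score)  # per-pixel uncertainty map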

class models.attention.urr.utils.URRSource(*values)

Bases: Enum

Source used to generate the low-level feature maps for URR.

O1 = 1

Output of temporal convolution.

O3 = 2

Aggregated output of the temporal convolution and attention mechanism.

class models.attention.urr.utils.UncertaintyMode(*values)

Bases: Enum

Which form of uncertainty handling to use: uncertain regions only (UR) or with refinement (URR).

UR = 1

Uncertain regions only.

URR = 2

Uncertain regions and refinement.
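
The snippet below sketches how these enums select behaviour when constructing URRResidualAttentionLightningModule; the configuration names are illustrative.

    from models.attention.urr.utils import UncertaintyMode, URRSource

    # Confidence loss only (UR) with features from the temporal convolution (O1),
    # versus full refinement (URR) with the aggregated attention output (O3).
    ablations = {
        "confidence_only": {"uncertainty_mode": UncertaintyMode.UR, "urr_source": URRSource.O1},
        "full_urr": {"uncertainty_mode": UncertaintyMode.URR, "urr_source": URRSource.O3},
    }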

Module contents

Model architecture for attention with the URR mechanism.

class models.attention.urr.URRResidualAttentionUnet(*args, **kwargs)

Bases: ResidualAttentionUnet

Uncertain region refinement for U-Net with attention mechanism.

forward(regular_frames: Tensor, residual_frames: Tensor) tuple[Tensor, Tensor | None, Tensor, Tensor]

Forward pass of the U-Net model.

Parameters:
  • regular_frames – Normal cine CMR sequence.

  • residual_frames – Residual cine CMR sequence.

Returns:

Predicted mask probabilities, initial uncertainty, final uncertainty, and confidence loss.

class models.attention.urr.URRResidualAttentionUnetPlusPlus(*args, **kwargs)

Bases: URRResidualAttentionUnet

U-Net++ with attention mechanism and uncertain region refinement.

class models.attention.urr.URRResidualAttentionLightningModule(batch_size: int, metric: Metric | None = None, loss: Module | str | None = None, model_type: ModelType = ModelType.UNET, encoder_name: str = 'resnet34', encoder_depth: int = 5, encoder_weights: str | None = 'imagenet', in_channels: int = 3, classes: int = 1, num_frames: Literal[5, 10, 15, 20, 30] = 5, weights_from_ckpt_path: str | None = None, optimizer: Optimizer | str = 'adamw', optimizer_kwargs: dict[str, Any] | None = None, scheduler: LRScheduler | str = 'gradual_warmup_scheduler', scheduler_kwargs: dict[str, Any] | None = None, multiplier: int = 2, total_epochs: int = 50, alpha: float = 0.95, beta: float = 0.05, learning_rate: float = 0.0001, dl_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, eval_classification_mode: ClassificationMode = ClassificationMode.MULTICLASS_MODE, residual_mode: ResidualMode = ResidualMode.SUBTRACT_NEXT_FRAME, loading_mode: LoadingMode = LoadingMode.RGB, dump_memory_snapshot: bool = False, flat_conv: bool = False, unet_activation: str | None = None, attention_reduction: Literal['sum', 'prod', 'cat', 'weighted', 'weighted_learnable'] = 'sum', attention_only: bool = False, dummy_predict: DummyPredictMode = DummyPredictMode.NONE, temporal_conv_type: TemporalConvolutionalType = TemporalConvolutionalType.ORIGINAL, urr_source: URRSource = URRSource.O3, uncertainty_mode: UncertaintyMode = UncertaintyMode.URR, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, metric_div_zero: float = 1.0, single_attention_instance: bool = False, show_r2_plots: bool = False)

Bases: ResidualAttentionLightningModule

Lightning Module wrapper for attention U-Nets with URR.

forward(x_img: Tensor, x_res: Tensor) tuple[Tensor, Tensor, Tensor]

Forward pass of the model.

predict_step(batch: tuple[Tensor, Tensor, Tensor, str | list[str]], batch_idx: int, dataloader_idx: int = 0)

Forward pass of the model for one minibatch of a test epoch.

Parameters:
  • batch – Batch of frames, residual frames, masks, and filenames.

  • batch_idx – Index of the batch in the epoch.

  • dataloader_idx – Index of the dataloader.

Returns:

Mask predictions, original images, and filename.

training_step(batch: tuple[Tensor, Tensor, Tensor, str], batch_idx: int)

Forward pass for the model with dataloader batches.

Parameters:
  • batch – Batch of frames, residual frames, masks, and filenames.

  • batch_idx – Index of the batch in the epoch.

Returns:

Training loss.

Raises:

AssertionError – If the prediction and ground truth mask shapes differ.

dl_classification_mode: ClassificationMode

Classification mode for the dataloader instances.

eval_classification_mode: ClassificationMode

Classification mode for the evaluation process.

dice_metrics: dict[str, MetricCollection | Metric]

A collection of dice score variants.

other_metrics: dict[str, MetricCollection]

A collection of other metrics (recall, precision, jaccard).

hausdorff_metrics: dict[str, MetricCollection]

A collection of Hausdorff distance metrics.

infarct_metrics: dict[str, MetricCollection]

A collection of infarct-related clinical heuristics.

model: nn.Module

The internal model used.

model_type: ModelType

The architecture of the model, if appropriate.

de_transform: Compose | InverseNormalize

The inverse of the augmentation transforms applied to the samples by the dataloaders.

classes: int

Number of segmentation classes.

optimizer: type[Optimizer] | Optimizer | str

Optimizer for training.

optimizer_kwargs: dict[str, Any]

Optimizer kwargs.

total_epochs: int

Number of total epochs for training.

scheduler: type[LRScheduler] | LRScheduler | str

Scheduler for training.

scheduler_kwargs: dict[str, Any]

Scheduler kwargs.

prepare_data_per_node: bool

If True, prepare_data() is called on every node; otherwise only on the main node (inherited from LightningModule).

allow_zero_length_dataloader_with_multiple_devices: bool

Whether dataloaders of zero length are permitted on a local rank when running on multiple devices (inherited from LightningModule).

training: bool

Whether the module is currently in training mode (inherited from torch.nn.Module).

batch_size

Batch size of dataloader.

in_channels

Number of image channels.

num_frames

Number of frames used.

dump_memory_snapshot

Whether to dump a memory snapshot.

dummy_predict

Whether to simply return the ground truth for visualisation.

residual_mode

Residual frames generation mode.

loading_mode

Image loading mode.

multiplier

Learning rate multiplier.

alpha

Loss scaling factor for segmentation loss.

beta

Loss scaling factor for confidence loss.

learning_rate

Learning rate for training.

urr_source

Source of the low-level feature maps for URR.

uncertainty_mode

Whether to include uncertain-region refinement or to use only the confidence loss.

weights_from_ckpt_path

Model checkpoint path to load weights from.