utils package

Submodules

utils.global_progress_bar module

A rich progress bar that estimates the total time remaining.

Taken from https://github.com/Lightning-AI/pytorch-lightning/issues/3009

class utils.global_progress_bar.RemainingTimeColumn(style: str | Style)

Bases: ProgressColumn

Show total remaining time in training.

max_refresh: float | None = 1.0
render(task) → Text

Should return a renderable object.
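The column's estimate can be sketched as a linear extrapolation over the whole run rather than only the current epoch. This assumes a constant time per batch and a fixed number of batches per epoch. `estimate_remaining_seconds` and its parameters are hypothetical names, not the column's code; the real column receives a rich task in `render(task)`.

```python
from datetime import timedelta
from typing import Optional


def estimate_remaining_seconds(
    elapsed_s: float,
    current_epoch: int,
    batches_done_in_epoch: int,
    batches_per_epoch: int,
    max_epochs: int,
) -> Optional[float]:
    """Project the time left for the whole run, not just the current epoch."""
    total_batches = max_epochs * batches_per_epoch
    # Batches finished so far, counting all completed epochs (epochs are 0-indexed).
    done = current_epoch * batches_per_epoch + batches_done_in_epoch
    if done == 0:
        return None  # no rate to extrapolate from yet
    seconds_per_batch = elapsed_s / done
    return seconds_per_batch * (total_batches - done)


# 120 s spent on 150 of 1000 batches (epoch index 1, batch 50, 10 epochs of 100).
remaining = estimate_remaining_seconds(120.0, 1, 50, 100, 10)
print(timedelta(seconds=round(remaining)))  # prints 0:11:20
```

Averaging over all finished batches smooths out per-batch jitter, but the estimate adapts slowly if later epochs run at a different speed.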

class utils.global_progress_bar.BetterProgressBar(refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(description='', progress_bar='#6206E0', progress_bar_finished='#6206E0', progress_bar_pulse='#6206E0', batch_progress='', time='dim', processing_speed='dim underline', metrics='italic', metrics_text_delimiter=' ', metrics_format='.3f'), console_kwargs: dict[str, Any] | None = None)

Bases: RichProgressBar

A progress bar that estimates the total time remaining.

configure_columns(trainer) → list

utils.logging module

CLI logging defaults.

utils.prediction_writer module

Module for writing the predicted images with masks to the output directory.

class utils.prediction_writer.MaskImageWriter(loading_mode: LoadingMode, output_dir: str | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch', inv_transform: InverseNormalize = InverseNormalize(mean=[-2.117902994155884, -2.0357131958007812, -1.8044434785842896], std=[4.366810321807861, 4.464283466339111, 4.444442272186279], inplace=False), format: Literal['apng', 'tiff', 'gif', 'webp', 'png'] = 'gif', raw_masks: bool = False, uncertainty: bool = False, drawn_classes: tuple[int, ...] = (0, 1, 2, 3), output: bool = True)

Bases: BasePredictionWriter

Writes the predicted images with masks to the output directory.

write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[tuple[Tensor, Tensor, list[str]]] | Sequence[Sequence[tuple[Tensor, Tensor, list[str]]]] | Sequence[tuple[Tensor, Tensor, Tensor, list[str]]] | Sequence[Sequence[tuple[Tensor, Tensor, Tensor, list[str]]]], batch_indices: Sequence[Any]) → None

Save the predicted images with masks to the output directory.

Parameters:
  • trainer – The trainer object.

  • pl_module – The lightning module.

  • predictions – The predictions from the model.

  • batch_indices – The indices of the batch.

utils.prediction_writer.get_output_dir_from_ckpt_path(ckpt_path: str | None)

Get the output directory from the checkpoint path.

utils.types module

Helper module containing type and class definitions.

class utils.types.InverseNormalize(mean: Sequence[float | int], std: Sequence[float | int])

Bases: Normalize

Inverts the normalization and returns the reconstructed input images.
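The default `inv_transform` values in `MaskImageWriter`'s signature match what you get by inverting the standard ImageNet statistics (mean 0.485, 0.456, 0.406; std 0.229, 0.224, 0.225). The sketch below shows that arithmetic on plain lists. It assumes the inverse is built this way, which the defaults are consistent with but this page does not state; `inverse_params` and `normalize` are hypothetical helpers, not the class's code.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def inverse_params(mean, std):
    # Normalize computes y = (x - m) / s. Normalizing y again with
    # m' = -m / s and s' = 1 / s gives (y + m / s) * s = x.
    inv_mean = [-m / s for m, s in zip(mean, std)]
    inv_std = [1.0 / s for s in std]
    return inv_mean, inv_std


def normalize(pixel, mean, std):
    # Per-channel (x - mean) / std, as a Normalize transform does.
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


inv_mean, inv_std = inverse_params(IMAGENET_MEAN, IMAGENET_STD)
print([round(v, 4) for v in inv_mean])  # [-2.1179, -2.0357, -1.8044]
print([round(v, 4) for v in inv_std])  # [4.3668, 4.4643, 4.4444]

# Round trip: normalize, then apply the inverse parameters.
pixel = [0.2, 0.5, 0.9]
restored = normalize(normalize(pixel, IMAGENET_MEAN, IMAGENET_STD), inv_mean, inv_std)
```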

class utils.types.ClassificationMode(*values)

Bases: Enum

The classification mode for the model.

MULTICLASS_MODE = 1

The model is trained to predict a single class for each pixel.

MULTILABEL_MODE = 2

The model is trained to predict multiple classes for each pixel.

BINARY_CLASS_3_MODE = 3

The model is trained on a per-pixel binary classification task for a single class.

MULTICLASS_1_2_MODE = 4

The model is trained to predict the LV myocardium and scar tissue (MI) regions.

num_classes() → int

Get the number of classes associated with the classification mode.

Returns:

Number of classes.

Return type:

int

class utils.types.ResidualMode(*values)

Bases: Enum

The residual frame calculation mode for the model.

SUBTRACT_NEXT_FRAME = 1

Subtracts the next frame from the current frame.

OPTICAL_FLOW_CPU = 2

Calculates the optical flow using the CPU.

OPTICAL_FLOW_GPU = 3

Calculates the optical flow using the GPU.
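`SUBTRACT_NEXT_FRAME` can be illustrated on a short clip, with nested lists standing in for image tensors. The sketch pairs each frame with its successor, so `n` frames yield `n - 1` residuals; how the project handles the last frame (dropping it, wrapping around, or padding) is not documented here, and `subtract_next_frame` is a hypothetical name.

```python
Frame = list[list[int]]


def subtract_next_frame(frames: list[Frame]) -> list[Frame]:
    """Residual t = frame t - frame t+1, element-wise, for consecutive pairs."""
    residuals = []
    for current, nxt in zip(frames, frames[1:]):
        residuals.append(
            [[c - n for c, n in zip(row_c, row_n)] for row_c, row_n in zip(current, nxt)]
        )
    return residuals


frames = [
    [[1, 2], [3, 4]],
    [[0, 2], [3, 1]],
    [[1, 1], [1, 1]],
]
print(subtract_next_frame(frames))  # [[[1, 0], [0, 3]], [[-1, 1], [2, 0]]]
```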

class utils.types.LoadingMode(*values)

Bases: Enum

Determines the image loading mode for the dataset.

RGB = 1

The images are loaded in RGB mode.

GREYSCALE = 2

The images are loaded in greyscale mode.

class utils.types.ModelType(*values)

Bases: Enum

Model architecture types.

UNET = 1

U-Net architecture.

UNET_PLUS_PLUS = 2

UNet++ architecture.

TRANS_UNET = 3

TransUNet architecture.

class utils.types.MetricMode(*values)

Bases: Enum

Metric calculation mode.

INCLUDE_EMPTY_CLASS = 1

Includes samples that contain no instances of the class.

IGNORE_EMPTY_CLASS = 2

Ignores samples that contain no instances of a class when computing that class's metrics.
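To see how the two modes can diverge, here is a per-class mean Dice score computed both ways on toy masks, represented as sets of pixel indices. The choice of Dice, the convention that an empty prediction on an empty target scores 1.0, and the helper names are assumptions for illustration; this page does not say which metric or 0/0 convention the project uses.

```python
from enum import Enum


class MetricMode(Enum):
    INCLUDE_EMPTY_CLASS = 1
    IGNORE_EMPTY_CLASS = 2


def dice(pred: set[int], target: set[int]) -> float:
    """Dice = 2 * |P & T| / (|P| + |T|), with 0/0 assumed to score 1.0."""
    denom = len(pred) + len(target)
    if denom == 0:
        return 1.0
    return 2 * len(pred & target) / denom


def mean_class_dice(samples: list[tuple[set[int], set[int]]], mode: MetricMode) -> float:
    scores = []
    for pred, target in samples:
        # IGNORE_EMPTY_CLASS drops samples whose ground truth lacks the class.
        if mode is MetricMode.IGNORE_EMPTY_CLASS and not target:
            continue
        scores.append(dice(pred, target))
    return sum(scores) / len(scores) if scores else float("nan")


# Pixel indices belonging to one class: (prediction, ground truth) per sample.
samples = [({1, 2}, {2, 3}), (set(), set())]
print(mean_class_dice(samples, MetricMode.INCLUDE_EMPTY_CLASS))  # 0.75
print(mean_class_dice(samples, MetricMode.IGNORE_EMPTY_CLASS))  # 0.5
```

Under this convention, including empty samples rewards correct "nothing here" predictions, while ignoring them scores only the samples where the class is actually present.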

class utils.types.DummyPredictMode(*values)

Bases: Enum

Dummy prediction mode.

NONE = 1

No-op.

GROUND_TRUTH = 2

Outputs the ground truth masks.

BLANK = 3

Outputs only the images.

utils.utils module

Utility functions for the project.

class utils.utils.LightningGradualWarmupScheduler(optimizer: Optimizer, multiplier: int = 2, total_epoch: int = 5, T_max=50, after_scheduler=None)

Bases: LRScheduler

Gradually warms up (increases) the learning rate in the optimizer.

step(epoch=None, metrics=None)

Perform a step.
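A sketch of the resulting learning-rate curve, assuming this class follows the widely copied `GradualWarmupScheduler` recipe: over the first `total_epoch` epochs the rate climbs linearly from the base rate to `base_lr * multiplier`, after which `after_scheduler` takes over. The `T_max=50` default suggests cosine annealing afterwards, so the sketch assumes cosine decay to zero; `warmup_then_cosine_lr` is a hypothetical name and the real `after_scheduler` may differ.

```python
import math


def warmup_then_cosine_lr(
    epoch: int,
    base_lr: float,
    multiplier: int = 2,
    total_epoch: int = 5,
    t_max: int = 50,
) -> float:
    """Learning rate at `epoch` under an assumed warm-up + cosine schedule."""
    if epoch <= total_epoch:
        if multiplier == 1:
            # With multiplier 1 the recipe warms up from 0 to base_lr.
            return base_lr * epoch / total_epoch
        # Linear ramp from base_lr (epoch 0) to base_lr * multiplier (epoch total_epoch).
        return base_lr * ((multiplier - 1) * epoch / total_epoch + 1)
    # Assumed after_scheduler: cosine annealing from the peak rate towards 0.
    peak_lr = base_lr * multiplier
    t = epoch - total_epoch
    return 0.5 * peak_lr * (1 + math.cos(math.pi * t / t_max))


for e in (0, 2, 5, 30, 55):
    print(e, round(warmup_then_cosine_lr(e, base_lr=0.1), 4))
```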

utils.utils.get_classification_mode(mode: str) → ClassificationMode

Get the classification mode from a string input.

Parameters:

mode – The classification mode

Raises:

KeyError – If the mode is not an implemented mode.
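`get_classification_mode`, `get_residual_mode`, `get_loading_mode` and `get_model_type` all document the same contract: map a string to an enum member and raise `KeyError` for anything unknown. Python's `Enum[name]` lookup already behaves that way, so one generic helper is enough for a sketch. `enum_from_string` is hypothetical, and case-insensitive matching on member names (for example `"multiclass_mode"`) is an assumption about which strings the real functions accept.

```python
from enum import Enum
from typing import TypeVar


class ClassificationMode(Enum):
    MULTICLASS_MODE = 1
    MULTILABEL_MODE = 2
    BINARY_CLASS_3_MODE = 3
    MULTICLASS_1_2_MODE = 4


class ModelType(Enum):
    UNET = 1
    UNET_PLUS_PLUS = 2
    TRANS_UNET = 3


E = TypeVar("E", bound=Enum)


def enum_from_string(enum_cls: type[E], name: str) -> E:
    """Look up a member by name; Enum[...] raises KeyError for unknown names."""
    # Assumption: inputs are member names in any letter case.
    return enum_cls[name.upper()]


print(enum_from_string(ClassificationMode, "multiclass_mode"))  # ClassificationMode.MULTICLASS_MODE
print(enum_from_string(ModelType, "unet_plus_plus"))  # ModelType.UNET_PLUS_PLUS
```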

utils.utils.get_residual_mode(mode: str) → ResidualMode

Get the residual calculation mode from a string input.

Parameters:

mode – The residual calculation mode.

Raises:

KeyError – If the mode is not an implemented mode.

utils.utils.get_loading_mode(mode: str) → LoadingMode

Get the image loading mode from a string input.

Parameters:

mode – The loading mode.

Raises:

KeyError – If the mode is not an implemented mode.

utils.utils.get_checkpoint_filename(version: str | None) → str | None

Get the checkpoint filename with the version name if it exists.

Parameters:

version – The version name of the model.

utils.utils.get_best_weighted_avg_dice_filename(version: str | None) → str

Get the filename for the best weighted average dice score.

Parameters:

version – The version name of the model.

utils.utils.get_best_macro_avg_dice_class_2_3_filename(version: str | None) → str

Get the filename for the best macro-average Dice score over classes 2 and 3.

Parameters:

version – The version name of the model.

utils.utils.get_last_checkpoint_filename(version: str | None) → str | None

Get the filename for the last checkpoint.

Parameters:

version – The version name of the model.

utils.utils.get_version_name(ckpt_path: str | None) → str | None

Get the version name from the checkpoint path.

Parameters:

ckpt_path – The path to the checkpoint.

utils.utils.configure_optimizers(module: CommonModelMixin) → dict[str, Optimizer | dict[str, Any]]

Configure the optimizer and learning rate scheduler for the model.

Parameters:

module – The LightningModule instance.

utils.utils.get_transforms(loading_mode: LoadingMode, augment: bool = False) → tuple[Compose, Compose, Compose]

Get default transformations for all datasets.

The default implementation resizes the images to (224, 224), casts them to float32, normalises them, and sets them to greyscale if the loading mode is not RGB.

Parameters:
  • loading_mode – The loading mode for the images.

  • augment – Whether to augment the images and masks together.

Returns:

The image, mask, combined, and final resize transformations

utils.utils.get_accumulate_grad_batches(devices: int, batch_size: int) → int

Get the number of batches to accumulate the gradients.

Parameters:
  • devices – Number of devices for training.

  • batch_size – The batch size for training.

Returns:

The number of batches to accumulate the gradients

Raises:

AssertionError – If the effective batch size of 8 is not divisible by batch_size * devices.
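The documented contract reduces to simple arithmetic: the effective batch size is fixed at 8, so the number of accumulated batches is 8 / (batch_size * devices). The sketch below is reconstructed from the docstring alone, not from the project's source; the constant name is hypothetical.

```python
EFFECTIVE_BATCH_SIZE = 8  # value from the docstring; constant name is hypothetical


def get_accumulate_grad_batches(devices: int, batch_size: int) -> int:
    """Sketch: batches to accumulate so that devices * batch_size * n == 8."""
    per_step = devices * batch_size
    assert EFFECTIVE_BATCH_SIZE % per_step == 0, (
        f"Effective batch size {EFFECTIVE_BATCH_SIZE} is not divisible by "
        f"batch_size * devices = {per_step}"
    )
    return EFFECTIVE_BATCH_SIZE // per_step


print(get_accumulate_grad_batches(devices=1, batch_size=2))  # 4
print(get_accumulate_grad_batches(devices=2, batch_size=4))  # 1
```

Because the check uses `assert`, it is skipped when Python runs with `-O`; raising `ValueError` instead would keep the check active in optimized runs.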

utils.utils.get_model_type(enum_str: str) → ModelType

Get the ModelType variant matching an input string.

Parameters:

enum_str – String to match with enum.

Returns:

Resultant enum variant.

Raises:

KeyError – If the variant requested is not found.

Module contents

Contains the utility functions and classes for the project.