utils package
Submodules
utils.global_progress_bar module
A rich progress bar that estimates the total time remaining.
Taken from https://github.com/Lightning-AI/pytorch-lightning/issues/3009
- class utils.global_progress_bar.RemainingTimeColumn(style: str | Style)
Bases: ProgressColumn
Show total remaining time in training.
- render(task) → Text
Return a renderable showing the total remaining time.
- class utils.global_progress_bar.BetterProgressBar(refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(description='', progress_bar='#6206E0', progress_bar_finished='#6206E0', progress_bar_pulse='#6206E0', batch_progress='', time='dim', processing_speed='dim underline', metrics='italic', metrics_text_delimiter=' ', metrics_format='.3f'), console_kwargs: dict[str, Any] | None = None)
Bases: RichProgressBar
A progress bar that estimates the total time remaining.
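A minimal sketch of how a whole-run remaining-time estimate can be computed, assuming it extrapolates the average time per completed step over every remaining step of every epoch. The helpers `estimate_remaining` and `total_training_steps` are hypothetical and use only the standard library; the actual `RemainingTimeColumn` may track time differently.

```python
from __future__ import annotations

import datetime


def estimate_remaining(
    elapsed_seconds: float, completed_steps: int, total_steps: int
) -> datetime.timedelta | None:
    """Estimate the remaining time from the average time per completed step."""
    if completed_steps <= 0:
        return None  # no timing information yet
    seconds_per_step = elapsed_seconds / completed_steps
    remaining_steps = max(total_steps - completed_steps, 0)
    # Round to whole seconds so the rendered value stays stable between refreshes.
    return datetime.timedelta(seconds=round(seconds_per_step * remaining_steps))


def total_training_steps(max_epochs: int, batches_per_epoch: int) -> int:
    """Hypothetical helper: steps across all epochs, not just the current one."""
    return max_epochs * batches_per_epoch
```

For example, 10 epochs of 100 batches with 250 batches done after 500 seconds gives 2 s per step and 750 steps left, so `str(estimate_remaining(500.0, 250, 1000))` is `0:25:00`.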
utils.logging module
CLI logging defaults.
utils.prediction_writer module
Module for writing the predicted images with masks to the output directory.
- class utils.prediction_writer.MaskImageWriter(loading_mode: LoadingMode, output_dir: str | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch', inv_transform: InverseNormalize = InverseNormalize(mean=[-2.117902994155884, -2.0357131958007812, -1.8044434785842896], std=[4.366810321807861, 4.464283466339111, 4.444442272186279], inplace=False), format: Literal['apng', 'tiff', 'gif', 'webp', 'png'] = 'gif', raw_masks: bool = False, uncertainty: bool = False, drawn_classes: tuple[int, ...] = (0, 1, 2, 3), output: bool = True)
Bases: BasePredictionWriter
Writes the predicted images with masks to the output directory.
- write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[tuple[Tensor, Tensor, list[str]]] | Sequence[Sequence[tuple[Tensor, Tensor, list[str]]]] | Sequence[tuple[Tensor, Tensor, Tensor, list[str]]] | Sequence[Sequence[tuple[Tensor, Tensor, Tensor, list[str]]]], batch_indices: Sequence[Any]) → None
Save the predicted images with masks to the output directory.
- Parameters:
trainer – The trainer object.
pl_module – The lightning module.
predictions – The predictions from the model.
batch_indices – The indices of the batch.
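The type hints of `write_on_epoch_end` allow `predictions` to be either a flat sequence of batch tuples or a sequence of such sequences (one per predict dataloader), with either 3-tuples `(images, masks, names)` or 4-tuples carrying an extra tensor. A hedged sketch of normalising both shapes before writing; `flatten_predictions` and `split_batch` are hypothetical helpers, and reading the extra tensor as an uncertainty map (tied to `uncertainty=True`) is an assumption.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any


def flatten_predictions(predictions: Sequence[Any]) -> list[tuple[Any, ...]]:
    """Return one flat list of per-batch prediction tuples."""
    flat: list[tuple[Any, ...]] = []
    for item in predictions:
        if isinstance(item, tuple):
            # Flat shape: each item is already a batch tuple.
            flat.append(item)
        else:
            # Nested shape (assumed to be lists): one sequence per dataloader.
            flat.extend(item)
    return flat


def split_batch(batch: tuple[Any, ...]) -> tuple[Any, Any, Any | None, list[str]]:
    """Split a batch into images, masks, the optional extra tensor and file names."""
    if len(batch) == 3:
        images, masks, names = batch
        return images, masks, None, names
    # Assumption: the third element of a 4-tuple is an uncertainty map.
    images, masks, uncertainty, names = batch
    return images, masks, uncertainty, names
```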
utils.types module
Helper module containing type and class definitions.
- class utils.types.InverseNormalize(mean: Sequence[float | int], std: Sequence[float | int])
Bases: Normalize
Inverts the normalization to reconstruct the original input images.
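For a normalization `y = (x - m) / s`, running `Normalize` again with `m' = -m / s` and `s' = 1 / s` gives `(y - m') / s' = (y + m / s) * s = x`. The default `inv_transform` of `MaskImageWriter` matches these values for the ImageNet statistics `m = (0.485, 0.456, 0.406)` and `s = (0.229, 0.224, 0.225)` (for example, `-0.485 / 0.229 ≈ -2.1179` and `1 / 0.229 ≈ 4.3668`), so this is likely how `InverseNormalize` is parameterised, though its internals are not shown here. The plain-Python check below verifies the round trip per channel; `inverse_params` and `normalize` are hypothetical helpers.

```python
def inverse_params(mean, std):
    """Parameters that make a second Normalize undo Normalize(mean, std)."""
    inv_mean = [-m / s for m, s in zip(mean, std)]
    inv_std = [1.0 / s for s in std]
    return inv_mean, inv_std


def normalize(values, mean, std):
    # Per-channel arithmetic of torchvision's Normalize: (x - mean) / std
    return [(v - m) / s for v, m, s in zip(values, mean, std)]


IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

inv_mean, inv_std = inverse_params(IMAGENET_MEAN, IMAGENET_STD)
pixel = [0.2, 0.5, 0.9]  # one RGB pixel with values in [0, 1]
restored = normalize(normalize(pixel, IMAGENET_MEAN, IMAGENET_STD), inv_mean, inv_std)
```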
- class utils.types.ClassificationMode(*values)
Bases: Enum
The classification mode for the model.
- MULTICLASS_MODE = 1
The model is trained to predict a single class for each pixel.
- MULTILABEL_MODE = 2
The model is trained to predict multiple classes for each pixel.
- BINARY_CLASS_3_MODE = 3
The model is trained to predict a single class per pixel as a binary classification task.
- MULTICLASS_1_2_MODE = 4
The model is trained to predict the LV myocardium and scar tissue (MI) regions.
- class utils.types.ResidualMode(*values)
Bases: Enum
The residual frame calculation mode for the model.
- SUBTRACT_NEXT_FRAME = 1
Subtracts the next frame from the current frame.
- OPTICAL_FLOW_CPU = 2
Calculates the optical flow using the CPU.
- OPTICAL_FLOW_GPU = 3
Calculates the optical flow using the GPU.
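A minimal sketch of `SUBTRACT_NEXT_FRAME` on frames stored as nested lists, computing `r_t = f_t - f_{t+1}` element-wise. How the project handles the last frame, which has no successor, is not documented; this sketch simply drops it, so `n` frames yield `n - 1` residuals. `subtract_next_frame` is a hypothetical name.

```python
def subtract_next_frame(frames):
    """Residual frames r_t = f_t - f_{t+1}, computed element-wise per pixel.

    Each frame is a list of rows; each row is a list of pixel values.
    """
    return [
        [
            [a - b for a, b in zip(row_cur, row_next)]
            for row_cur, row_next in zip(cur, nxt)
        ]
        # Pair every frame with its successor; the last frame has none.
        for cur, nxt in zip(frames, frames[1:])
    ]
```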
- class utils.types.LoadingMode(*values)
Bases: Enum
Determines the image loading mode for the dataset.
- RGB = 1
The images are loaded in RGB mode.
- GREYSCALE = 2
The images are loaded in greyscale mode.
- class utils.types.ModelType(*values)
Bases: Enum
Model architecture types.
- UNET = 1
U-Net architecture.
- UNET_PLUS_PLUS = 2
UNet++ architecture.
- TRANS_UNET = 3
TransUNet architecture.
utils.utils module
Utility functions for the project.
- class utils.utils.LightningGradualWarmupScheduler(optimizer: Optimizer, multiplier: int = 2, total_epoch: int = 5, T_max=50, after_scheduler=None)
Bases: LRScheduler
Gradually warm up (increase) the learning rate in the optimizer.
- step(epoch=None, metrics=None)
Perform a step.
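The docstring does not spell out the schedule. A sketch modelled on the widely used GradualWarmupScheduler formula, where `multiplier > 1` ramps the rate linearly from `base_lr` to `base_lr * multiplier` over `total_epoch` epochs and `multiplier == 1` ramps from 0 to `base_lr`. Treating `T_max` as the period of a cosine annealing phase that follows warm-up, and `eta_min = 0`, are assumptions; `warmup_cosine_lr` is a hypothetical pure function, not the class's API.

```python
import math


def warmup_cosine_lr(
    epoch: int,
    base_lr: float,
    multiplier: float = 2,
    total_epoch: int = 5,
    t_max: int = 50,
    eta_min: float = 0.0,
) -> float:
    """Learning rate at `epoch` for linear warm-up followed by cosine annealing."""
    peak_lr = base_lr * multiplier
    if epoch <= total_epoch:
        if multiplier == 1:
            # Warm up from 0 to base_lr.
            return base_lr * epoch / total_epoch
        # Warm up linearly from base_lr to base_lr * multiplier.
        return base_lr * ((multiplier - 1) * epoch / total_epoch + 1)
    # Cosine annealing from peak_lr towards eta_min, counted from the end of warm-up.
    t = epoch - total_epoch
    return eta_min + (peak_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2
```

With the defaults and `base_lr = 0.1`, the rate rises from 0.1 to 0.2 over five epochs, passes 0.1 halfway through the cosine phase, and reaches 0 after `T_max` further epochs.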
- utils.utils.get_classification_mode(mode: str) → ClassificationMode
Get the classification mode from a string input.
- Parameters:
mode – The classification mode.
- Raises:
KeyError – If the mode is not an implemented mode.
- utils.utils.get_residual_mode(mode: str) → ResidualMode
Get the residual calculation mode from a string input.
- Parameters:
mode – The residual calculation mode.
- Raises:
KeyError – If the mode is not an implemented mode.
- utils.utils.get_loading_mode(mode: str) → LoadingMode
Get the loading mode from a string input.
- Parameters:
mode – The loading mode.
- Raises:
KeyError – If the mode is not an implemented mode.
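The three getters document a `KeyError` for unknown modes, which is what an `Enum` name lookup (`EnumClass[name]`) raises. A sketch assuming the input string is the exact member name; the enum is redeclared locally so the block runs on its own, and whether the real functions normalise case or spelling is not documented.

```python
from enum import Enum


class LoadingMode(Enum):
    """Local copy of the enum documented above, for a self-contained sketch."""

    RGB = 1
    GREYSCALE = 2


def get_loading_mode(mode: str) -> LoadingMode:
    """Look up an enum member by its name."""
    # Enum[name] raises KeyError when `name` is not a member name.
    return LoadingMode[mode]
```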
- utils.utils.get_checkpoint_filename(version: str | None) → str | None
Get the checkpoint filename with the version name if it exists.
- Parameters:
version – The version name of the model.
- utils.utils.get_best_weighted_avg_dice_filename(version: str | None) → str
Get the filename for the best weighted average Dice score.
- Parameters:
version – The version name of the model.
- utils.utils.get_best_macro_avg_dice_class_2_3_filename(version: str | None) → str
Get the filename for the best macro average Dice score for classes 2 and 3.
- Parameters:
version – The version name of the model.
- utils.utils.get_last_checkpoint_filename(version: str | None) → str | None
Get the filename for the last checkpoint.
- Parameters:
version – The version name of the model.
- utils.utils.get_version_name(ckpt_path: str | None) → str | None
Get the version name from the checkpoint path.
- Parameters:
ckpt_path – The path to the checkpoint.
- utils.utils.configure_optimizers(module: CommonModelMixin) → dict[str, Optimizer | dict[str, Any]]
Configure the optimizer and learning rate scheduler for the model.
- Parameters:
module – The LightningModule instance.
- utils.utils.get_transforms(loading_mode: LoadingMode, augment: bool = False) → tuple[Compose, Compose, Compose]
Get default transformations for all datasets.
The default implementation resizes the images to (224, 224), casts them to float32, normalises them, and sets them to greyscale if the loading mode is not RGB.
- Parameters:
loading_mode – The loading mode for the images.
augment – Whether to augment the images and masks together.
- Returns:
The image, mask, combined, and final resize transformations
- utils.utils.get_accumulate_grad_batches(devices: int, batch_size: int) → int
Get the number of batches to accumulate the gradients.
- Parameters:
devices – Number of devices for training.
batch_size – The batch size for training.
- Returns:
The number of batches to accumulate the gradients
- Raises:
AssertionError – If the effective batch size of 8 is not divisible by batch_size * devices.
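The docstring implies the relation `batch_size * devices * accumulate_grad_batches = 8`. A sketch of that arithmetic, assuming the effective batch size is fixed at 8 as stated; this is an illustrative reimplementation, not the project's code.

```python
EFFECTIVE_BATCH_SIZE = 8  # value stated in the docstring above


def get_accumulate_grad_batches(devices: int, batch_size: int) -> int:
    """Number of batches to accumulate so each optimizer step sees 8 samples."""
    per_step = batch_size * devices  # samples processed per forward pass
    assert EFFECTIVE_BATCH_SIZE % per_step == 0, (
        f"Effective batch size {EFFECTIVE_BATCH_SIZE} is not divisible by "
        f"batch_size * devices = {per_step}"
    )
    return EFFECTIVE_BATCH_SIZE // per_step
```

For example, one device with a batch size of 2 accumulates 4 batches, while two devices with a batch size of 2 accumulate 2. Because the check uses `assert`, it is skipped when Python runs with `-O`; raising `ValueError` would keep the check active in that case.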
Module contents
Contains the utility functions and classes for the project.