metrics package

Submodules

metrics.dice module

Dice score metrics.

class metrics.dice.GeneralizedDiceScoreVariant(num_classes: int, include_background: bool = True, per_class: bool = False, weight_type: Literal['square', 'simple', 'linear'] = 'square', weighted_average: bool = False, only_for_classes: list[bool] | list[int] | None = None, return_type: Literal['weighted_avg', 'macro_avg', 'per_class'] = 'weighted_avg', dist_sync_on_step: bool = False, zero_division: float = 1.0, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, **kwargs: Any)

Bases: GeneralizedDiceScore

Generalized Dice score metric with additional options.

class_occurrences: Tensor
score_running: Tensor
macro_avg_metric: Tensor
per_class_metric: Tensor
weighted_avg_metric: Tensor
count: Tensor
class_weights: Tensor
compute() Tensor

Compute the final generalized dice score.

update(preds: Tensor, target: Tensor) None

Update the state with new data.
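
Example usage (a minimal sketch, assuming the standard torchmetrics update()/compute() protocol and one-hot encoded inputs; shapes and constructor values are illustrative):

    import torch
    import torch.nn.functional as F

    from metrics.dice import GeneralizedDiceScoreVariant

    # Illustrative one-hot masks: batch of 2, 4 classes, 16x16 pixels.
    preds = F.one_hot(torch.randint(0, 4, (2, 16, 16)), num_classes=4).movedim(-1, 1)
    target = F.one_hot(torch.randint(0, 4, (2, 16, 16)), num_classes=4).movedim(-1, 1)

    dice = GeneralizedDiceScoreVariant(
        num_classes=4,
        per_class=True,
        weighted_average=True,
        return_type="per_class",
    )
    dice.update(preds, target)
    print(dice.compute())  # one score per class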

metrics.hausdorff module

Hausdorff distance metric.

class metrics.hausdorff.HausdorffDistanceVariant(num_classes: int, include_classes: list[int], distance_metric: Literal['euclidean', 'chessboard', 'taxicab'] = 'euclidean', spacing: Tensor | list[float] | None = None, directed: bool = False, input_format: Literal['one-hot', 'index'] = 'one-hot', zero_division: float = 1.0, **kwargs: Mapping[Any, Any])

Bases: HausdorffDistance

Hausdorff Distance class which ignores +inf values.

is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False
plot_lower_bound: float = 0.0
score: Tensor
total: Tensor
update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

compute() Tensor

Compute final Hausdorff distance over states.
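
Example usage (a minimal sketch assuming one-hot inputs, matching the default input_format; shapes and the include_classes selection are illustrative):

    import torch
    import torch.nn.functional as F

    from metrics.hausdorff import HausdorffDistanceVariant

    # Illustrative one-hot masks: batch of 2, 3 classes, 32x32 pixels.
    preds = F.one_hot(torch.randint(0, 3, (2, 32, 32)), num_classes=3).movedim(-1, 1)
    target = F.one_hot(torch.randint(0, 3, (2, 32, 32)), num_classes=3).movedim(-1, 1)

    hd = HausdorffDistanceVariant(num_classes=3, include_classes=[1, 2])
    hd.update(preds, target)
    print(hd.compute())  # aggregated distance with +inf contributions ignored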

metrics.hausdorff.hausdorff_distance_variant(preds: Tensor, target: Tensor, num_classes: int, include_classes: list[int], distance_metric: Literal['euclidean', 'chessboard', 'taxicab'] = 'euclidean', spacing: Tensor | list[float] | None = None, directed: bool = False, input_format: Literal['one-hot', 'index'] = 'one-hot') Tensor

Calculate Hausdorff Distance for semantic segmentation.

See hausdorff_distance()

Parameters:
  • preds – Predicted binarized segmentation map.

  • target – Target binarized segmentation map.

  • num_classes – Number of classes.

  • include_classes – Which classes to include in the computation.

  • distance_metric – Type of distance metric to use.

  • spacing – Spacing between pixels along each spatial dimension. If not provided, the spacing is assumed to be 1.

  • directed – Whether to calculate directed or undirected Hausdorff distance.

  • input_format – What kind of input the function receives. Choose between "one-hot" for one-hot encoded tensors or "index" for index tensors.
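
Example (a sketch using the documented signature; the input shapes follow the one-hot convention and are illustrative):

    import torch
    import torch.nn.functional as F

    from metrics.hausdorff import hausdorff_distance_variant

    preds = F.one_hot(torch.randint(0, 3, (1, 32, 32)), num_classes=3).movedim(-1, 1)
    target = F.one_hot(torch.randint(0, 3, (1, 32, 32)), num_classes=3).movedim(-1, 1)

    # Distance computed only for classes 1 and 2, with unit pixel spacing.
    dist = hausdorff_distance_variant(
        preds, target, num_classes=3, include_classes=[1, 2],
        distance_metric="euclidean", directed=False,
    )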

metrics.infarct module

Clinical metrics for myocardial infarction.

class metrics.infarct.InfarctResults(infarct_area: Tensor, ratio: Tensor, span: ndarray[tuple[int, ...], dtype[_ScalarType_co]], transmurality: ndarray[tuple[int, ...], dtype[_ScalarType_co]])

Bases: object

Computed infarct results.

Contains the infarct area in pixels, extent of myocardial infarction as a ratio to the LV myocardium area, span of the myocardial infarction in radians, and transmurality of the myocardial infarct region within its occupying span of the LV myocardium. These metrics may be batched.

infarct_area

Area of the infarct in pixels.

Type:

torch.Tensor

ratio

Extent of myocardial infarct as a ratio to the LV myocardium area.

Type:

torch.Tensor

span

Occupying span of the myocardial infarction in radians.

Type:

numpy.ndarray

transmurality

Extent of myocardial infarct as a ratio to its occupying span of the LV myocardium.

Type:

numpy.ndarray

infarct_area: Tensor
ratio: Tensor
span: ndarray[tuple[int, ...], dtype[_ScalarType_co]]
transmurality: ndarray[tuple[int, ...], dtype[_ScalarType_co]]
is_close(other: Self) bool

Check if this result and another result are close.

to_tensor() Tensor

Cast to tensor of size (bs, 4).

Returns:

Tensor of shape (bs, 4). Order of elements is infarct area, ratio, span, and transmurality.

Return type:

Tensor
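
Example (a sketch constructing a batched result directly from the documented fields and stacking it with to_tensor(); the values are illustrative):

    import numpy as np
    import torch

    from metrics.infarct import InfarctResults

    # Batch of two illustrative results.
    results = InfarctResults(
        infarct_area=torch.tensor([120.0, 85.0]),
        ratio=torch.tensor([0.25, 0.18]),
        span=np.array([1.2, 0.9]),
        transmurality=np.array([0.6, 0.4]),
    )
    stacked = results.to_tensor()  # shape (2, 4): area, ratio, span, transmurality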

class metrics.infarct.InfarctHeuristics(lv_index: int = 1, infarct_index: int = 2)

Bases: Module

Infarct heuristic calculation torch module.

forward(segmentation_mask: Tensor, lv_index: int | None = None, infarct_index: int | None = None) InfarctResults

Compute infarct metrics for scar tissue and LV myocardium.

Parameters:
  • segmentation_mask – Ground truth or predicted mask in one-hot format.

  • lv_index – Mask index of the LV myocardium.

  • infarct_index – Mask index of the myocardial infarction.

Returns:

Computed metrics.

Return type:

InfarctResults
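
Example usage (a minimal sketch; the one-hot mask layout with the LV myocardium at index 1 and the infarct at index 2 follows the constructor defaults, and shapes are illustrative):

    import torch
    import torch.nn.functional as F

    from metrics.infarct import InfarctHeuristics

    # One-hot mask with channels: background (0), LV myocardium (1), infarct (2).
    mask = F.one_hot(torch.randint(0, 3, (1, 64, 64)), num_classes=3).movedim(-1, 1)

    heuristics = InfarctHeuristics(lv_index=1, infarct_index=2)
    results = heuristics(mask)  # InfarctResults: area, ratio, span, transmurality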

class metrics.infarct.InfarctVisualisation(classification_mode: ClassificationMode, lv_index: int = 1, infarct_index: int = 2, **kwargs)

Bases: object

Visualise the infarct clinical metrics.

viz(cine_image: Tensor, segmentation_mask: Tensor, output_raw_annotation: bool = False, _debug: bool = False) Image

Visualise the input with masks and annotations.

Parameters:
  • cine_image – Tensor of cine image sequence.

  • segmentation_mask – Segmentation mask of cine input.

  • output_raw_annotation – Whether to output just the annotation.

Returns:

Annotated image or raw annotation.

Return type:

Image.Image
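
Example usage (a sketch only: the ClassificationMode import path and member name below are hypothetical placeholders for the project's own enum, and the cine/mask shapes are illustrative):

    import torch
    import torch.nn.functional as F

    from metrics.infarct import InfarctVisualisation
    from utils.types import ClassificationMode  # hypothetical import path

    viz = InfarctVisualisation(
        ClassificationMode.MULTICLASS_MODE,  # hypothetical member name
        lv_index=1,
        infarct_index=2,
    )

    cine = torch.rand(30, 224, 224)  # illustrative cine image sequence
    mask = F.one_hot(torch.randint(0, 3, (224, 224)), num_classes=3).movedim(-1, 0)

    annotated = viz.viz(cine, mask)  # PIL.Image with masks and annotations
    annotated.save("infarct_viz.png")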

metrics.jaccard module

Compute the mJaccard Index.

class metrics.jaccard.MulticlassMJaccardIndex(num_classes: int, average: Literal['macro', 'none'] | None, ignore_index: int | None = None, validate_args: bool = True, zero_division: float = 0, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, **kwargs: Any)

Bases: MulticlassJaccardIndex

Calculate the mJaccard Index for multiclass tasks.

mJaccard_running: Tensor
samples: Tensor
compute()

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.
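
Example usage (a minimal sketch assuming class-index predictions and targets as in the torchmetrics base class; the multilabel and binary variants below follow the same update()/compute() pattern):

    import torch

    from metrics.jaccard import MulticlassMJaccardIndex

    mjaccard = MulticlassMJaccardIndex(num_classes=4, average="macro")

    # Illustrative class-index maps for 2 samples of 16x16 pixels.
    preds = torch.randint(0, 4, (2, 16, 16))
    target = torch.randint(0, 4, (2, 16, 16))

    mjaccard.update(preds, target)
    print(mjaccard.compute())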

class metrics.jaccard.MultilabelMJaccardIndex(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', ignore_index: int | None = None, validate_args: bool = True, **kwargs: Any)

Bases: MultilabelJaccardIndex

Calculate the mJaccard Index for multilabel tasks.

mJaccard_running: Tensor
samples: Tensor
compute()

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

class metrics.jaccard.BinaryMJaccardIndex(threshold: float = 0.5, ignore_index: int | None = None, validate_args: bool = True, zero_division: float = 1.0, **kwargs: Any)

Bases: BinaryJaccardIndex

Calculate the mJaccard Index for binary tasks.

mJaccard_running: Tensor
samples: Tensor
compute()

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

metrics.logging module

Logging utilities for metrics.

metrics.logging.shared_metric_calculation(module: CommonModelMixin, masks: Tensor, masks_proba: Tensor, prefix: Literal['train', 'val', 'test'] = 'train')

Calculate the metrics for the model.

Parameters:
  • module – The LightningModule instance.

  • masks – The ground truth masks.

  • masks_proba – The predicted masks.

  • prefix – The runtime mode (train, val, test).

metrics.logging.setup_metrics(module: CommonModelMixin, metric: Metric | None, classes: int, metric_mode: MetricMode, division_by_zero: float)

Set up the metrics (dice, jaccard, precision, recall) for the model.

Parameters:
  • module – The LightningModule instance.

  • metric – The metric to use. If None, the default is the GeneralizedDiceScoreVariant.

  • classes – The number of classes in the dataset.

  • metric_mode – Metric calculation mode.

  • division_by_zero – How to handle division by zero operations.

metrics.logging.shared_metric_logging_epoch_end(module: CommonModelMixin, prefix: str)

Log the metrics for the model. This is called at the end of the epoch.

This method only handles the logging of Dice scores.

Parameters:
  • module – The LightningModule instance.

  • prefix – The runtime mode (train, val, test).
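
The three helpers above are intended to be called from a LightningModule that uses CommonModelMixin. A hedged sketch of the call sites (import paths for CommonModelMixin and MetricMode are project-specific and therefore omitted; hook names follow standard Lightning conventions):

    from metrics.logging import (
        setup_metrics,
        shared_metric_calculation,
        shared_metric_logging_epoch_end,
    )

    def configure(module, num_classes, metric_mode):
        # Called once at construction; None selects the default
        # GeneralizedDiceScoreVariant.
        setup_metrics(module, None, num_classes, metric_mode, division_by_zero=1.0)

    def validation_step(module, batch, batch_idx):
        # Called per batch: accumulate metric states for this prefix.
        images, masks = batch
        masks_proba = module(images)
        shared_metric_calculation(module, masks, masks_proba, prefix="val")

    def on_validation_epoch_end(module):
        # Called once per epoch: compute and log the accumulated scores.
        shared_metric_logging_epoch_end(module, "val")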

metrics.loss module

Implementation of loss functions.

class metrics.loss.StructureLoss(size_average=None, reduce=None, reduction: str = 'mean')

Bases: _Loss

Structure loss using Binary Cross-Entropy.

forward(input: Tensor, target: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
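
Example usage (a minimal sketch assuming raw logits and binary targets; shapes are illustrative):

    import torch

    from metrics.loss import StructureLoss

    criterion = StructureLoss()

    logits = torch.randn(2, 1, 64, 64, requires_grad=True)
    target = torch.randint(0, 2, (2, 1, 64, 64)).float()

    loss = criterion(logits, target)
    loss.backward()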

class metrics.loss.JointEdgeSegLoss(num_classes: int, edge_weight: float = 0.3, seg_weight: float = 1.0, inv_weight: float = 0.3, att_weight: float = 0.1, size_average=None, reduce=None, reduction: str = 'mean')

Bases: _Loss

Joint Edge + Segmentation Structure Loss for Vivim.

bce2d(input: Tensor, target: Tensor) Tensor

Compute the binary cross-entropy loss.

edge_attention(input: Tensor, target: Tensor, edge: Tensor) Tensor

Compute edge attention loss.

forward(inputs: tuple[Tensor, Tensor], targets: tuple[Tensor, Tensor]) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
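
Example usage (a sketch only: the ordering of the (segmentation, edge) pairs and the tensor shapes are assumptions, not confirmed by the signature):

    import torch

    from metrics.loss import JointEdgeSegLoss

    criterion = JointEdgeSegLoss(num_classes=4, edge_weight=0.3, seg_weight=1.0)

    # Assumed (segmentation, edge) ordering within each tuple.
    seg_logits = torch.randn(2, 4, 64, 64, requires_grad=True)
    edge_logits = torch.randn(2, 1, 64, 64, requires_grad=True)
    seg_target = torch.randint(0, 4, (2, 64, 64))
    edge_target = torch.randint(0, 2, (2, 1, 64, 64)).float()

    loss = criterion((seg_logits, edge_logits), (seg_target, edge_target))
    loss.backward()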

class metrics.loss.WeightedDiceLoss(num_classes: int, mode: Literal['binary', 'multiclass', 'multilabel'], weight: Tensor | None = None, classes: List[int] | None = None, log_loss: bool = False, from_logits: bool = True, smooth: float = 0, ignore_index: int | None = None, eps: float = 1e-07, reduction: Literal['mean', 'sum', 'none'] = 'mean')

Bases: DiceLoss, _WeightedLoss

Dice loss with class weights.

forward(y_pred: Tensor, y_true: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
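
Example usage (a minimal sketch following segmentation-models-pytorch DiceLoss conventions for multiclass mode, i.e. logits of shape (N, C, H, W) and class-index targets; the class weights are illustrative):

    import torch

    from metrics.loss import WeightedDiceLoss

    # Illustrative per-class weights, ordered by class index.
    weights = torch.tensor([0.1, 1.0, 2.0, 2.0])
    criterion = WeightedDiceLoss(num_classes=4, mode="multiclass", weight=weights)

    y_pred = torch.randn(2, 4, 64, 64, requires_grad=True)  # logits (from_logits=True)
    y_true = torch.randint(0, 4, (2, 64, 64))               # class indices

    loss = criterion(y_pred, y_true)
    loss.backward()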

metrics.precision_recall module

Compute mPrecision and mRecall.

class metrics.precision_recall.MulticlassMPrecision(num_classes: int, top_k: int = 1, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MulticlassPrecision

Calculates mPrecision for multiclass tasks.

mPrecision_running: Tensor
samples: Tensor
compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

class metrics.precision_recall.MultilabelMPrecision(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MultilabelPrecision

Calculates mPrecision for multilabel tasks.

mPrecision_running: Tensor
samples: Tensor
compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

class metrics.precision_recall.MulticlassMRecall(num_classes: int, top_k: int = 1, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MulticlassRecall

Compute mRecall for multiclass tasks.

mRecall_running: Tensor
samples: Tensor
compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

class metrics.precision_recall.MultilabelMRecall(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MultilabelRecall

Calculates mRecall for multilabel tasks.

mRecall_running: Tensor
samples: Tensor
compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

class metrics.precision_recall.MulticlassMF1Score(num_classes: int, top_k: int = 1, average: Literal['macro', 'none'] = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, zero_division: float = 0.0, **kwargs: Any)

Bases: MulticlassF1Score

Compute mF1 Score for multiclass tasks.

mF1_running: Tensor
samples: Tensor
compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

class metrics.precision_recall.MultilabelMF1Score(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, zero_division: float = 0.0, **kwargs: Any)

Bases: MultilabelF1Score

Compute mF1 Score for multilabel tasks.

mF1_running: Tensor
samples: Tensor
compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.
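
Example usage combining several of the metrics above in a torchmetrics MetricCollection (a sketch assuming class-index inputs as in the torchmetrics base classes; the multilabel variants follow the same pattern with num_labels):

    import torch
    from torchmetrics import MetricCollection

    from metrics.precision_recall import (
        MulticlassMF1Score,
        MulticlassMPrecision,
        MulticlassMRecall,
    )

    metrics = MetricCollection(
        {
            "mPrecision": MulticlassMPrecision(num_classes=4, average="macro"),
            "mRecall": MulticlassMRecall(num_classes=4, average="macro"),
            "mF1": MulticlassMF1Score(num_classes=4, average="macro"),
        }
    )

    preds = torch.randint(0, 4, (2, 16, 16))
    target = torch.randint(0, 4, (2, 16, 16))

    metrics.update(preds, target)
    print(metrics.compute())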

metrics.utils module

Utility functions for metrics.

Module contents

Metrics implementation for the project.

class metrics.HausdorffDistanceVariant(num_classes: int, include_classes: list[int], distance_metric: Literal['euclidean', 'chessboard', 'taxicab'] = 'euclidean', spacing: Tensor | list[float] | None = None, directed: bool = False, input_format: Literal['one-hot', 'index'] = 'one-hot', zero_division: float = 1.0, **kwargs: Mapping[Any, Any])

Bases: HausdorffDistance

Hausdorff Distance class which ignores +inf values.

compute() Tensor

Compute final Hausdorff distance over states.

full_state_update: bool = False
higher_is_better: bool = False
is_differentiable: bool = True
plot_lower_bound: float = 0.0
update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

score: Tensor
total: Tensor
distance_metric: Literal['euclidean', 'chessboard', 'taxicab']
input_format: Literal['one-hot', 'index']
training: bool
metrics.hausdorff_distance_variant(preds: Tensor, target: Tensor, num_classes: int, include_classes: list[int], distance_metric: Literal['euclidean', 'chessboard', 'taxicab'] = 'euclidean', spacing: Tensor | list[float] | None = None, directed: bool = False, input_format: Literal['one-hot', 'index'] = 'one-hot') Tensor

Calculate Hausdorff Distance for semantic segmentation.

See hausdorff_distance()

Parameters:
  • preds – Predicted binarized segmentation map.

  • target – Target binarized segmentation map.

  • num_classes – Number of classes.

  • include_classes – Which classes to include in the computation.

  • distance_metric – Type of distance metric to use.

  • spacing – Spacing between pixels along each spatial dimension. If not provided, the spacing is assumed to be 1.

  • directed – Whether to calculate directed or undirected Hausdorff distance.

  • input_format – What kind of input the function receives. Choose between "one-hot" for one-hot encoded tensors or "index" for index tensors.

class metrics.InfarctResults(infarct_area: Tensor, ratio: Tensor, span: ndarray[tuple[int, ...], dtype[_ScalarType_co]], transmurality: ndarray[tuple[int, ...], dtype[_ScalarType_co]])

Bases: object

Computed infarct results.

Contains the infarct area in pixels, extent of myocardial infarction as a ratio to the LV myocardium area, span of the myocardial infarction in radians, and transmurality of the myocardial infarct region within its occupying span of the LV myocardium. These metrics may be batched.

infarct_area

Area of the infarct in pixels.

Type:

torch.Tensor

ratio

Extent of myocardial infarct as a ratio to the LV myocardium area.

Type:

torch.Tensor

span

Occupying span of the myocardial infarction in radians.

Type:

numpy.ndarray

transmurality

Extent of myocardial infarct as a ratio to its occupying span of the LV myocardium.

Type:

numpy.ndarray

is_close(other: Self) bool

Check if this result and another result are close.

to_tensor() Tensor

Cast to tensor of size (bs, 4).

Returns:

Tensor of shape (bs, 4). Order of elements is infarct area, ratio, span, and transmurality.

Return type:

Tensor

infarct_area: Tensor
ratio: Tensor
span: ndarray[tuple[int, ...], dtype[_ScalarType_co]]
transmurality: ndarray[tuple[int, ...], dtype[_ScalarType_co]]
class metrics.InfarctHeuristics(lv_index: int = 1, infarct_index: int = 2)

Bases: Module

Infarct heuristic calculation torch module.

forward(segmentation_mask: Tensor, lv_index: int | None = None, infarct_index: int | None = None) InfarctResults

Compute infarct metrics for scar tissue and LV myocardium.

Parameters:
  • segmentation_mask – Ground truth or predicted mask in one-hot format.

  • lv_index – Mask index of the LV myocardium.

  • infarct_index – Mask index of the myocardial infarction.

Returns:

Computed metrics.

Return type:

InfarctResults

class metrics.InfarctArea(classification_mode: ClassificationMode, lv_index: int = 1, infarct_index: int = 2, plot_type: Literal['default', 'advanced'] = 'default', **kwargs)

Bases: InfarctMetricBase

Computes the R² value of scar tissue area as a % of the total image size.

update(preds: Tensor, target: Tensor) None

Update state with predictions and target in one-hot encoded form.

Parameters:
  • preds – Prediction tensor.

  • target – Target tensor.

Warning

This will fail in subclasses which call self.preprocessing if the tensors are not one-hot encoded.

class metrics.InfarctAreaRatio(classification_mode: ClassificationMode, lv_index: int = 1, infarct_index: int = 2, plot_type: Literal['default', 'advanced'] = 'default', **kwargs)

Bases: InfarctMetricBase

Computes the R² value of scar tissue area as a % of LV myocardium area.

update(preds: Tensor, target: Tensor) None

Update state with predictions and target in one-hot encoded form.

Parameters:
  • preds – Prediction tensor.

  • target – Target tensor.

Warning

This will fail in subclasses which call self.preprocessing if the tensors are not one-hot encoded.

class metrics.InfarctSpan(classification_mode: ClassificationMode, lv_index: int = 1, infarct_index: int = 2, plot_type: Literal['default', 'advanced'] = 'default', **kwargs)

Bases: InfarctMetricBase

Computes the R² value of the infarct region as a span of the LV myocardium in radians.

update(preds: Tensor, target: Tensor) None

Update state with predictions and target in one-hot encoded form.

Parameters:
  • preds – Prediction tensor.

  • target – Target tensor.

Warning

This will fail in subclasses which call self.preprocessing if the tensors are not one-hot encoded.

class metrics.InfarctTransmuralities(classification_mode: ClassificationMode, lv_index: int = 1, infarct_index: int = 2, plot_type: Literal['default', 'advanced'] = 'default', **kwargs)

Bases: InfarctMetricBase

Computes the R² value of the infarct region as a % of the span it occupies within the LV myocardium.

update(preds: Tensor, target: Tensor) None

Update state with predictions and target in one-hot encoded form.

Parameters:
  • preds – Prediction tensor.

  • target – Target tensor.

Warning

This will fail in subclasses which call self.preprocessing if the tensors are not one-hot encoded.
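
Example usage for the four R² metrics above (a sketch only: the ClassificationMode import path and member name are hypothetical placeholders for the project's own enum, and compute() is assumed to follow the usual torchmetrics protocol inherited from InfarctMetricBase):

    import torch
    import torch.nn.functional as F

    from metrics import InfarctAreaRatio
    from utils.types import ClassificationMode  # hypothetical import path

    ratio_r2 = InfarctAreaRatio(
        ClassificationMode.MULTICLASS_MODE,  # hypothetical member name
        lv_index=1,
        infarct_index=2,
    )

    # Illustrative one-hot predictions and targets (background, LV myo, infarct).
    preds = F.one_hot(torch.randint(0, 3, (4, 64, 64)), num_classes=3).movedim(-1, 1)
    target = F.one_hot(torch.randint(0, 3, (4, 64, 64)), num_classes=3).movedim(-1, 1)

    ratio_r2.update(preds, target)
    print(ratio_r2.compute())  # R² between predicted and ground-truth ratios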

class metrics.InfarctPredictionWriter(lv_myo_index: int = 1, infarct_index: int = 2, loading_mode: LoadingMode = LoadingMode.GREYSCALE, output_dir: str | None = None, write_interval: Literal['batch', 'epoch', 'batch_and_epoch'] = 'epoch', inv_transform: InverseNormalize = InverseNormalize(mean=[-1.986724853515625], std=[4.424777030944824], inplace=False), format: Literal['apng', 'tiff', 'gif', 'webp', 'png'] = 'gif', output_samples_to_dirs: bool = False, output: bool = True)

Bases: BasePredictionWriter

Prediction writer for infarct visualisation.

write_on_epoch_end(trainer: Trainer, pl_module: LightningModule, predictions: Sequence[tuple[Tensor, Tensor, list[str]]] | Sequence[Sequence[tuple[Tensor, Tensor, list[str]]]] | Sequence[tuple[Tensor, Tensor, Tensor, list[str]]] | Sequence[Sequence[tuple[Tensor, Tensor, Tensor, list[str]]]], batch_indices: Sequence[Any]) None

Override with the logic to write all batches.
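
Example usage (a sketch: the writer is attached as a Trainer callback; the Trainer import path depends on the installed Lightning distribution, and the output directory is illustrative):

    from lightning.pytorch import Trainer  # or pytorch_lightning, depending on install

    from metrics import InfarctPredictionWriter

    writer = InfarctPredictionWriter(
        lv_myo_index=1,
        infarct_index=2,
        output_dir="predictions/infarct",
        write_interval="epoch",
        format="gif",
    )

    trainer = Trainer(callbacks=[writer])
    # trainer.predict(model, datamodule=dm)  # model/dm defined elsewhere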

metrics.shared_metric_calculation(module: CommonModelMixin, masks: Tensor, masks_proba: Tensor, prefix: Literal['train', 'val', 'test'] = 'train')

Calculate the metrics for the model.

Parameters:
  • module – The LightningModule instance.

  • masks – The ground truth masks.

  • masks_proba – The predicted masks.

  • prefix – The runtime mode (train, val, test).

metrics.setup_metrics(module: CommonModelMixin, metric: Metric | None, classes: int, metric_mode: MetricMode, division_by_zero: float)

Set up the metrics (dice, jaccard, precision, recall) for the model.

Parameters:
  • module – The LightningModule instance.

  • metric – The metric to use. If None, the default is the GeneralizedDiceScoreVariant.

  • classes – The number of classes in the dataset.

  • metric_mode – Metric calculation mode.

  • division_by_zero – How to handle division by zero operations.

metrics.shared_metric_logging_epoch_end(module: CommonModelMixin, prefix: str)

Log the metrics for the model. This is called at the end of the epoch.

This method only handles the logging of Dice scores.

Parameters:
  • module – The LightningModule instance.

  • prefix – The runtime mode (train, val, test).

class metrics.MulticlassMJaccardIndex(num_classes: int, average: Literal['macro', 'none'] | None, ignore_index: int | None = None, validate_args: bool = True, zero_division: float = 0, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, **kwargs: Any)

Bases: MulticlassJaccardIndex

Calculate the mJaccard Index for multiclass tasks.

compute()

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mJaccard_running: Tensor
samples: Tensor
average: Literal['macro', 'none'] | None
confmat: Tensor
training: bool
class metrics.MultilabelMJaccardIndex(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', ignore_index: int | None = None, validate_args: bool = True, **kwargs: Any)

Bases: MultilabelJaccardIndex

Calculate the mJaccard Index for multilabel tasks.

compute()

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mJaccard_running: Tensor
samples: Tensor
average: Literal['macro', 'none'] | None
confmat: Tensor
training: bool
class metrics.StructureLoss(size_average=None, reduce=None, reduction: str = 'mean')

Bases: _Loss

Structure loss using Binary Cross-Entropy.

forward(input: Tensor, target: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class metrics.JointEdgeSegLoss(num_classes: int, edge_weight: float = 0.3, seg_weight: float = 1.0, inv_weight: float = 0.3, att_weight: float = 0.1, size_average=None, reduce=None, reduction: str = 'mean')

Bases: _Loss

Joint Edge + Segmentation Structure Loss for Vivim.

bce2d(input: Tensor, target: Tensor) Tensor

Compute the binary cross-entropy loss.

edge_attention(input: Tensor, target: Tensor, edge: Tensor) Tensor

Compute edge attention loss.

forward(inputs: tuple[Tensor, Tensor], targets: tuple[Tensor, Tensor]) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class metrics.WeightedDiceLoss(num_classes: int, mode: Literal['binary', 'multiclass', 'multilabel'], weight: Tensor | None = None, classes: List[int] | None = None, log_loss: bool = False, from_logits: bool = True, smooth: float = 0, ignore_index: int | None = None, eps: float = 1e-07, reduction: Literal['mean', 'sum', 'none'] = 'mean')

Bases: DiceLoss, _WeightedLoss

Dice loss with class weights.

forward(y_pred: Tensor, y_true: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class metrics.GeneralizedDiceScoreVariant(num_classes: int, include_background: bool = True, per_class: bool = False, weight_type: Literal['square', 'simple', 'linear'] = 'square', weighted_average: bool = False, only_for_classes: list[bool] | list[int] | None = None, return_type: Literal['weighted_avg', 'macro_avg', 'per_class'] = 'weighted_avg', dist_sync_on_step: bool = False, zero_division: float = 1.0, metric_mode: MetricMode = MetricMode.INCLUDE_EMPTY_CLASS, **kwargs: Any)

Bases: GeneralizedDiceScore

Generalized Dice score metric with additional options.

compute() Tensor

Compute the final generalized dice score.

update(preds: Tensor, target: Tensor) None

Update the state with new data.

class_occurrences: Tensor
score_running: Tensor
macro_avg_metric: Tensor
per_class_metric: Tensor
weighted_avg_metric: Tensor
class_weights: Tensor
count: Tensor
return_type: Literal['weighted_avg', 'macro_avg', 'per_class']
score: Tensor
samples: Tensor
training: bool
class metrics.MulticlassMPrecision(num_classes: int, top_k: int = 1, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MulticlassPrecision

Calculates mPrecision for multiclass tasks.

compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mPrecision_running: Tensor
samples: Tensor
average: Literal['macro', 'none'] | None
multidim_average: Literal['global', 'samplewise']
tp: List[Tensor] | Tensor
fp: List[Tensor] | Tensor
tn: List[Tensor] | Tensor
fn: List[Tensor] | Tensor
training: bool
class metrics.MulticlassMRecall(num_classes: int, top_k: int = 1, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MulticlassRecall

Compute mRecall for multiclass tasks.

compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mRecall_running: Tensor
samples: Tensor
average: Literal['macro', 'none'] | None
multidim_average: Literal['global', 'samplewise']
tp: List[Tensor] | Tensor
fp: List[Tensor] | Tensor
tn: List[Tensor] | Tensor
fn: List[Tensor] | Tensor
training: bool
class metrics.MulticlassMF1Score(num_classes: int, top_k: int = 1, average: Literal['macro', 'none'] = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, zero_division: float = 0.0, **kwargs: Any)

Bases: MulticlassF1Score

Compute mF1 Score for multiclass tasks.

compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mF1_running: Tensor
samples: Tensor
average: Literal['macro', 'none']
multidim_average: Literal['global', 'samplewise']
tp: List[Tensor] | Tensor
fp: List[Tensor] | Tensor
tn: List[Tensor] | Tensor
fn: List[Tensor] | Tensor
training: bool
class metrics.MultilabelMPrecision(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MultilabelPrecision

Calculates mPrecision for multilabel tasks.

compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mPrecision_running: Tensor
samples: Tensor
average: Literal['macro', 'none'] | None
multidim_average: Literal['global', 'samplewise']
tp: List[Tensor] | Tensor
fp: List[Tensor] | Tensor
tn: List[Tensor] | Tensor
fn: List[Tensor] | Tensor
training: bool
class metrics.MultilabelMRecall(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, metric_mode: MetricMode = MetricMode.IGNORE_EMPTY_CLASS, **kwargs: Any)

Bases: MultilabelRecall

Calculates mRecall for multilabel tasks.

compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mRecall_running: Tensor
samples: Tensor
average: Literal['macro', 'none'] | None
multidim_average: Literal['global', 'samplewise']
tp: List[Tensor] | Tensor
fp: List[Tensor] | Tensor
tn: List[Tensor] | Tensor
fn: List[Tensor] | Tensor
training: bool
class metrics.MultilabelMF1Score(num_labels: int, threshold: float = 0.5, average: Literal['macro', 'none'] | None = 'macro', multidim_average: Literal['global', 'samplewise'] = 'global', ignore_index: int | None = None, validate_args: bool = True, zero_division: float = 0.0, **kwargs: Any)

Bases: MultilabelF1Score

Compute mF1 Score for multilabel tasks.

compute() Tensor

Compute metric.

update(preds: Tensor, target: Tensor) None

Update state with predictions and targets.

mF1_running: Tensor
samples: Tensor
average: Literal['macro', 'none'] | None
multidim_average: Literal['global', 'samplewise']
tp: List[Tensor] | Tensor
fp: List[Tensor] | Tensor
tn: List[Tensor] | Tensor
fn: List[Tensor] | Tensor
training: bool