
calibrax.metrics.functional.segmentation

Segmentation metrics for evaluating pixel-level predictions. Provides Intersection over Union (Jaccard index), Dice coefficient, and pixel accuracy for binary and multiclass segmentation tasks.

Segmentation metrics for pixel/voxel-level evaluation.

Pure functions for evaluating segmentation quality by comparing predicted masks against ground truth masks. Supports binary and multiclass segmentation with multiple averaging modes.

Includes three functions: iou, dice_coefficient, and pixel_accuracy.

iou(predictions: Any, targets: Any, *, num_classes: int | None = None, average: str = 'binary') -> Any

Intersection over Union (Jaccard index) for segmentation.

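Measures overlap between predicted and ground truth masks. For binary masks, IoU = |P ∩ T| / |P ∪ T|, where P and T are the predicted and ground-truth foreground pixels. For multiclass masks, IoU is computed per class and then averaged according to the average parameter.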
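The binary formula can be sketched directly with boolean masks. This is a minimal illustration, not the calibrax implementation: the helper name binary_iou_sketch is hypothetical, NumPy stands in for jax.numpy, label 1 is assumed to be the foreground, and returning 1.0 when both masks are empty is an assumed convention (the library may treat that case differently).

```python
import numpy as np


def binary_iou_sketch(predictions, targets):
    """Binary IoU |P ∩ T| / |P ∪ T| (hypothetical helper, not calibrax's code)."""
    p = np.asarray(predictions) == 1  # predicted foreground pixels (label 1 assumed)
    t = np.asarray(targets) == 1      # ground-truth foreground pixels
    intersection = np.logical_and(p, t).sum()
    union = np.logical_or(p, t).sum()
    # Assumption: two empty masks count as perfect overlap.
    if union == 0:
        return 1.0
    return float(intersection / union)


preds = np.array([1, 1, 0, 0])
truth = np.array([1, 0, 0, 0])
print(binary_iou_sketch(preds, truth))  # 0.5, matching the example below
```

Because the masks are compared elementwise, the same helper works unchanged on 2-D images or 3-D volumes.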

Note

Direction: HIGHER (1.0 = perfect overlap). Range: [0, 1]. Equivalent to 1 - Jaccard distance. Related to Dice via dice = 2 * iou / (1 + iou).

Parameters:

predictions (Any, required): Predicted integer mask (0/1 for binary, 0..num_classes-1 for multiclass).
targets (Any, required): Ground truth integer mask.
num_classes (int | None, default None): Number of classes. Required for "macro" and "weighted"; inferred from the data for "binary".
average (str, default 'binary'): Averaging mode: "binary" for single-class, "macro" for the unweighted mean across classes, "weighted" for the frequency-weighted mean.

Returns:

Any: IoU score as a scalar value.

Examples:

>>> import jax.numpy as jnp
>>> preds = jnp.array([1, 1, 0, 0])
>>> truth = jnp.array([1, 0, 0, 0])
>>> iou(preds, truth)  # binary IoU
0.5
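For the "macro" and "weighted" modes described in the parameters above, per-class IoU can be computed from three per-class pixel counts. The sketch below is hypothetical and does not reproduce calibrax's behaviour: the helper name is illustrative, NumPy stands in for jax.numpy, labels are assumed to be integers in [0, num_classes), classes absent from both masks are skipped, and "frequency-weighted" is assumed to mean weighting by each class's ground-truth pixel count.

```python
import numpy as np


def multiclass_iou_sketch(predictions, targets, num_classes, average="macro"):
    """Per-class IoU, then a macro or weighted mean (hypothetical helper)."""
    p = np.asarray(predictions).ravel()
    t = np.asarray(targets).ravel()
    # Per-class counts; assumes integer labels in [0, num_classes).
    inter = np.bincount(t[p == t], minlength=num_classes)  # |P_c ∩ T_c|
    pred_count = np.bincount(p, minlength=num_classes)     # |P_c|
    true_count = np.bincount(t, minlength=num_classes)     # |T_c|
    union = pred_count + true_count - inter                # |P_c ∪ T_c|
    # Assumption: classes absent from both masks (union == 0) are skipped.
    present = union > 0
    ious = inter[present] / union[present]
    if average == "macro":
        return float(ious.mean())
    if average == "weighted":
        # Assumption: weights are ground-truth pixel counts per class.
        weights = true_count[present]
        return float((ious * weights).sum() / weights.sum())
    raise ValueError(f"unsupported average: {average!r}")


preds = np.array([0, 0, 1, 1, 2, 2])
truth = np.array([0, 0, 0, 1, 2, 2])
print(multiclass_iou_sketch(preds, truth, num_classes=3, average="macro"))     # 13/18 ≈ 0.722
print(multiclass_iou_sketch(preds, truth, num_classes=3, average="weighted"))  # 0.75
```

Here the class IoUs are 2/3, 1/2, and 1. The macro mean treats them equally, while the weighted mean gives class 0 (three ground-truth pixels) the most influence.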

dice_coefficient(predictions: Any, targets: Any, *, num_classes: int | None = None, average: str = 'binary') -> Any

Dice coefficient (F1 for segmentation).

Measures overlap as 2|P ∩ T| / (|P| + |T|). Equivalent to the F1 score of pixel-level classification (computed per class in the multiclass case).
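Writing |P| = TP + FP and |T| = TP + FN shows why Dice and F1 coincide: 2TP / (2TP + FP + FN) is the harmonic mean of precision and recall. The hypothetical helpers below use NumPy in place of jax.numpy, assume label 1 is the foreground, and assume an empty-mask result of 1.0 for Dice; they compute both quantities on the example from this section.

```python
import numpy as np


def binary_dice_sketch(predictions, targets):
    """Dice 2|P ∩ T| / (|P| + |T|) (hypothetical helper, not calibrax's code)."""
    p = np.asarray(predictions) == 1
    t = np.asarray(targets) == 1
    denom = p.sum() + t.sum()
    if denom == 0:
        return 1.0  # assumption: two empty masks count as perfect overlap
    return float(2 * np.logical_and(p, t).sum() / denom)


def f1_from_counts(predictions, targets):
    """F1 from precision and recall of the foreground class (hypothetical helper)."""
    p = np.asarray(predictions) == 1
    t = np.asarray(targets) == 1
    tp = np.logical_and(p, t).sum()
    fp = np.logical_and(p, ~t).sum()
    fn = np.logical_and(~p, t).sum()
    if tp == 0:
        return 0.0  # precision/recall undefined or zero; differs from Dice only for two empty masks
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return float(2 * precision * recall / (precision + recall))


preds = np.array([1, 1, 0, 0])
truth = np.array([1, 0, 0, 0])
print(binary_dice_sketch(preds, truth))  # ≈ 0.667
print(f1_from_counts(preds, truth))      # ≈ 0.667 (precision 0.5, recall 1.0)
```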

Note

Direction: HIGHER (1.0 = perfect overlap). Range: [0, 1]. Related to IoU via dice = 2 * iou / (1 + iou) and iou = dice / (2 - dice).
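Both conversions follow from writing S = |P| + |T| and I = |P ∩ T|, so that |P ∪ T| = S - I, Dice = 2I/S, and IoU = I/(S - I). A small round-trip check in plain Python follows; the function names are illustrative only and are not part of calibrax.

```python
def dice_from_iou(iou):
    """Dice = 2 * IoU / (1 + IoU)."""
    return 2 * iou / (1 + iou)


def iou_from_dice(dice):
    """IoU = Dice / (2 - Dice)."""
    return dice / (2 - dice)


# IoU 0.5 corresponds to Dice 2/3, matching the iou and dice_coefficient
# examples, which use the same preds/truth pair.
print(dice_from_iou(0.5))    # ≈ 0.667
print(iou_from_dice(2 / 3))  # ≈ 0.5
```

Because the mapping is monotonic, IoU and Dice always rank two predictions the same way; they differ only in scale.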

Parameters:

predictions (Any, required): Predicted integer mask (0/1 for binary, 0..num_classes-1 for multiclass).
targets (Any, required): Ground truth integer mask.
num_classes (int | None, default None): Number of classes. Required for "macro" and "weighted"; inferred from the data for "binary".
average (str, default 'binary'): Averaging mode: "binary" for single-class, "macro" for the unweighted mean across classes, "weighted" for the frequency-weighted mean.

Returns:

Any: Dice coefficient as a scalar value.

Examples:

>>> import jax.numpy as jnp
>>> preds = jnp.array([1, 1, 0, 0])
>>> truth = jnp.array([1, 0, 0, 0])
>>> dice_coefficient(preds, truth)  # 2*1 / (2+1)
0.666...

pixel_accuracy(predictions: Any, targets: Any) -> Any

Fraction of correctly classified pixels.

A simple accuracy metric for segmentation: the proportion of pixels whose predicted label matches the ground-truth label.
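Pixel accuracy reduces to the mean of an elementwise equality test. A minimal sketch with a hypothetical helper name, using NumPy in place of jax.numpy; it works for masks of any shape as long as predictions and targets have the same shape.

```python
import numpy as np


def pixel_accuracy_sketch(predictions, targets):
    """Fraction of pixels where prediction == target (hypothetical helper)."""
    p = np.asarray(predictions)
    t = np.asarray(targets)
    # Mean of a boolean array = fraction of True entries.
    return float(np.mean(p == t))


preds = np.array([0, 1, 1, 0])
truth = np.array([0, 1, 0, 0])
print(pixel_accuracy_sketch(preds, truth))  # 0.75
```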

Note

Direction: HIGHER (1.0 = all pixels correct). Range: [0, 1]. Can be misleadingly high for imbalanced classes. Prefer IoU or Dice for class-imbalanced segmentation.
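The imbalance caveat is easy to reproduce. In this made-up 10x10 example (illustrative data, not from the library), only 2 of 100 pixels are foreground and a degenerate model predicts background everywhere. Pixel accuracy stays high while the foreground IoU collapses; NumPy again stands in for jax.numpy.

```python
import numpy as np

# Illustrative 10x10 mask: only 2 of 100 pixels are foreground (class 1).
truth = np.zeros((10, 10), dtype=int)
truth[4, 4:6] = 1
preds = np.zeros_like(truth)  # degenerate model: always predicts background

accuracy = np.mean(preds == truth)  # 98 of 100 pixels match

fg_p, fg_t = preds == 1, truth == 1
fg_iou = np.logical_and(fg_p, fg_t).sum() / np.logical_or(fg_p, fg_t).sum()

print(accuracy)  # 0.98
print(fg_iou)    # 0.0: the foreground object is missed entirely
```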

Parameters:

predictions (Any, required): Predicted integer mask.
targets (Any, required): Ground truth integer mask.

Returns:

Any: Pixel accuracy as a scalar value.

Examples:

>>> import jax.numpy as jnp
>>> preds = jnp.array([0, 1, 1, 0])
>>> truth = jnp.array([0, 1, 0, 0])
>>> pixel_accuracy(preds, truth)
0.75