
calibrax.metrics.functional.image

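Image quality metrics for comparing generated or reconstructed images against references. The Tier 0 functions are PSNR, SSIM, MS-SSIM, and Vendi Score (which measures the diversity of a set of samples rather than comparing against a reference). All are written in pure JAX with no external dependencies.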

Image quality metrics: pure math on pixel arrays.

Every metric in this module is a pure mathematical operation on arrays: pixel arrays for PSNR, SSIM, and MS-SSIM, and a precomputed similarity matrix for Vendi Score. No pretrained models, no neural networks, and no dependencies beyond JAX.

Includes PSNR, SSIM, MS-SSIM, and Vendi Score, all registered with domain="image".

psnr(predictions: Any, targets: Any, *, max_val: float = 1.0) -> Any

Peak Signal-to-Noise Ratio.

PSNR = 10 * log10(max_val^2 / MSE), where MSE is the mean squared error between predictions and targets. Measured in dB; higher is better. For identical images (MSE = 0), the result is clamped to a very large finite value instead of inf.
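To make the formula concrete, here is a minimal sketch of the same computation in plain JAX. The name psnr_reference and the eps floor on the MSE are hypothetical: they stand in for the library's clamping, whose exact value is not documented here. The sketch also shows that rescaling images and max_val together leaves PSNR unchanged, and that a pure function like this can be batched with jax.vmap.

```python
import jax
import jax.numpy as jnp


def psnr_reference(predictions, targets, max_val=1.0, eps=1e-10):
    """Hypothetical reference: PSNR = 10 * log10(max_val^2 / MSE).

    eps is an ASSUMED floor on the MSE so identical images give a large
    finite value instead of inf; the library's actual clamp may differ.
    """
    predictions = jnp.asarray(predictions, jnp.float32)
    targets = jnp.asarray(targets, jnp.float32)
    mse = jnp.mean((predictions - targets) ** 2)
    return 10.0 * jnp.log10(max_val ** 2 / jnp.maximum(mse, eps))


a = jnp.full((8, 8), 0.5)
b = jnp.full((8, 8), 0.6)
print(psnr_reference(a, b))                             # ~20 dB, since MSE = 0.01
print(psnr_reference(a * 255, b * 255, max_val=255.0))  # same ~20 dB on a [0, 255] scale

# PSNR for each image pair along a leading batch axis.
batch_psnr = jax.vmap(psnr_reference, in_axes=(0, 0))
print(batch_psnr(jnp.stack([a, b]), jnp.stack([b, b])))
```

Because max_val and the pixel values enter as a ratio, max_val must match the pixel range of the inputs: passing uint8-range images with the default max_val=1.0 would shift every score by a constant.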

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Any | Predicted image array (any shape). | required |
| targets | Any | Ground truth image array (same shape as predictions). | required |
| max_val | float | Maximum pixel value (1.0 for images in [0, 1], 255 for uint8). | 1.0 |

Returns:

| Type | Description |
|---|---|
| Any | PSNR value in dB. |

Examples:

>>> import jax.numpy as jnp
>>> img = jnp.ones((8, 8)) * 0.5
>>> psnr(img, img)  # Very high value (identical images)
...

ssim(predictions: Any, targets: Any, *, max_val: float = 1.0, filter_size: int = 11, filter_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03) -> Any

Structural Similarity Index Measure.

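Compares local luminance, contrast, and structure between the two images using a Gaussian-weighted window (set by filter_size and filter_sigma). For multi-channel images, SSIM is computed per channel and the results are averaged across channels.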
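The sketch below spells out the standard SSIM formula with a Gaussian window, to show what filter_size, filter_sigma, k1, and k2 control. It is a minimal illustration, not the library's implementation: the names gaussian_window, _ssim_2d, and ssim_reference are hypothetical, and the choice of "valid" filtering (no padding) is an assumption, so values near image borders may differ from ssim.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_window(filter_size: int = 11, filter_sigma: float = 1.5) -> jax.Array:
    """Normalized 2-D Gaussian window: outer product of two 1-D Gaussians."""
    coords = jnp.arange(filter_size, dtype=jnp.float32) - (filter_size - 1) / 2.0
    g = jnp.exp(-(coords ** 2) / (2.0 * filter_sigma ** 2))
    g = g / jnp.sum(g)
    return jnp.outer(g, g)


def _ssim_2d(x, y, window, c1, c2):
    """Mean SSIM of one (H, W) channel, using 'valid' filtering (ASSUMED, no padding)."""
    def filt(z):
        return convolve2d(z, window, mode="valid")

    mu_x, mu_y = filt(x), filt(y)                    # local means (luminance)
    sigma_x2 = filt(x * x) - mu_x ** 2               # local variances (contrast)
    sigma_y2 = filt(y * y) - mu_y ** 2
    sigma_xy = filt(x * y) - mu_x * mu_y             # local covariance (structure)
    num = (2.0 * mu_x * mu_y + c1) * (2.0 * sigma_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x2 + sigma_y2 + c2)
    return jnp.mean(num / den)


def ssim_reference(predictions, targets, max_val=1.0, filter_size=11,
                   filter_sigma=1.5, k1=0.01, k2=0.03):
    """Hypothetical reference SSIM for (H, W) or (H, W, C) arrays."""
    x = jnp.asarray(predictions, jnp.float32)
    y = jnp.asarray(targets, jnp.float32)
    window = gaussian_window(filter_size, filter_sigma)
    c1 = (k1 * max_val) ** 2  # luminance stability constant
    c2 = (k2 * max_val) ** 2  # contrast stability constant
    if x.ndim == 2:
        return _ssim_2d(x, y, window, c1, c2)
    # (H, W, C): SSIM per channel, then average across channels.
    per_channel = [_ssim_2d(x[..., c], y[..., c], window, c1, c2)
                   for c in range(x.shape[-1])]
    return jnp.mean(jnp.stack(per_channel))


img = jnp.full((32, 32), 0.5)
noise = jax.random.uniform(jax.random.PRNGKey(0), (32, 32))
print(ssim_reference(img, img))    # 1.0 (up to float rounding) for identical images
print(ssim_reference(img, noise))  # far below 1.0: structure does not match
```

The constants c1 = (k1 * max_val)^2 and c2 = (k2 * max_val)^2 only keep the ratios stable where local means or variances are near zero, which is why they scale with max_val.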

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Any | Predicted image, shape (H, W) or (H, W, C). | required |
| targets | Any | Ground truth image, same shape as predictions. | required |
| max_val | float | Maximum pixel value. | 1.0 |
| filter_size | int | Gaussian window size (should be odd). | 11 |
| filter_sigma | float | Gaussian standard deviation. | 1.5 |
| k1 | float | Luminance stability constant. | 0.01 |
| k2 | float | Contrast stability constant. | 0.03 |

Returns:

| Type | Description |
|---|---|
| Any | SSIM value, at most 1.0 and typically in [0, 1]; 1.0 = identical images. |

Examples:

>>> import jax.numpy as jnp
>>> img = jnp.ones((32, 32)) * 0.5
>>> ssim(img, img)
1.0

ms_ssim(predictions: Any, targets: Any, *, max_val: float = 1.0, power_factors: tuple[float, ...] | None = None) -> Any

Multi-Scale Structural Similarity Index.

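Computes SSIM terms at multiple scales, downsampling the images between scales, and combines them as a product in which each scale's term is raised to its power factor. The images must be large enough to support every scale; the example below uses 160 x 160 with the default five scales.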
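As a minimal sketch, assuming the standard Wang et al. MS-SSIM formulation, the combination step looks like this: a contrast-structure term cs_j is computed at every scale, luminance l_M only at the coarsest scale M, and MS-SSIM = l_M ** beta_M * prod_j cs_j ** beta_j. The helper names downsample2x, combine_scales, and DEFAULT_POWER_FACTORS are hypothetical. The 2 x 2 average pooling and the clipping of negative terms to zero (which avoids NaN from fractional powers) are assumptions, and the library may handle both differently.

```python
import jax.numpy as jnp

# Default per-scale weights, as documented for ms_ssim.
DEFAULT_POWER_FACTORS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)


def downsample2x(img):
    """Hypothetical helper: 2 x 2 average pooling over the first two axes.

    Odd rows/columns are cropped. Trailing axes (e.g. channels) are kept.
    """
    h, w = img.shape[0] // 2 * 2, img.shape[1] // 2 * 2
    img = img[:h, :w]
    return img.reshape(h // 2, 2, w // 2, 2, *img.shape[2:]).mean(axis=(1, 3))


def combine_scales(cs_per_scale, luminance_last, power_factors=None):
    """Hypothetical helper: weighted product of per-scale SSIM terms.

    cs_per_scale: contrast-structure term for each scale (finest first).
    luminance_last: luminance term at the coarsest scale only.
    """
    if power_factors is None:
        power_factors = DEFAULT_POWER_FACTORS
    pf = jnp.asarray(power_factors, jnp.float32)
    # ASSUMPTION: clip negatives to 0 so fractional powers stay real-valued.
    cs = jnp.maximum(jnp.asarray(cs_per_scale, jnp.float32), 0.0)
    lum = jnp.maximum(jnp.asarray(luminance_last, jnp.float32), 0.0)
    if cs.shape != pf.shape:
        raise ValueError("need exactly one cs term per power factor")
    return lum ** pf[-1] * jnp.prod(cs ** pf)


print(downsample2x(jnp.ones((160, 160))).shape)      # (80, 80)
print(combine_scales([1.0] * 5, 1.0))                # 1.0: identical at every scale
print(combine_scales([1.0, 1.0, 0.5, 1.0, 1.0], 1.0))  # < 1.0: mismatch at scale 3
```

With a geometric (product) combination, a poor match at a heavily weighted middle scale pulls the score down more than the same mismatch at the finest scale, whose default weight is only 0.0448.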

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Any | Predicted image, shape (H, W) or (H, W, C). | required |
| targets | Any | Ground truth image, same shape as predictions. | required |
| max_val | float | Maximum pixel value. | 1.0 |
| power_factors | tuple[float, ...] \| None | One weight (exponent) per scale. If None, defaults to (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). | None |

Returns:

| Type | Description |
|---|---|
| Any | MS-SSIM value in [0, 1]. 1.0 = identical images. |

Examples:

>>> import jax.numpy as jnp
>>> img = jnp.ones((160, 160)) * 0.5
>>> ms_ssim(img, img)
1.0

vendi_score(similarity_matrix: Any) -> Any

Vendi Score: diversity measure via eigenvalue entropy.

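Computes the exponential of the Shannon entropy of the eigenvalues of K / n, where K is the n x n similarity matrix: VS = exp(-sum_i lambda_i * log(lambda_i)). Dividing by n makes the eigenvalues sum to 1 when K has ones on its diagonal. Higher values indicate more diversity.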
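A minimal sketch of this definition in JAX, following the Vendi Score formula above; the name vendi_score_reference is hypothetical, and clamping tiny negative eigenvalues (from floating-point round-off on a positive semi-definite K) is an assumption about how such noise is handled.

```python
import jax.numpy as jnp


def vendi_score_reference(similarity_matrix):
    """Hypothetical reference: exp(Shannon entropy of the eigenvalues of K / n)."""
    K = jnp.asarray(similarity_matrix, jnp.float32)
    n = K.shape[0]
    # K is symmetric, so eigvalsh applies; eigenvalues of K / n sum to
    # trace(K) / n = 1 when the diagonal is all ones.
    lam = jnp.linalg.eigvalsh(K / n)
    # ASSUMPTION: clamp tiny negative eigenvalues caused by round-off.
    lam = jnp.maximum(lam, 0.0)
    # Shannon entropy with the convention 0 * log(0) = 0.
    safe = jnp.where(lam > 0, lam, 1.0)
    entropy = -jnp.sum(jnp.where(lam > 0, lam * jnp.log(safe), 0.0))
    return jnp.exp(entropy)


print(vendi_score_reference(jnp.eye(3)))        # ~3.0: three orthogonal items
print(vendi_score_reference(jnp.ones((3, 3))))  # ~1.0: three identical items
```

The score can be read as an "effective number" of distinct items: partially similar items land strictly between 1 and n.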

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| similarity_matrix | Any | Square similarity matrix of shape (n, n): symmetric and positive semi-definite, with values in [0, 1] and ones on the diagonal. | required |

Returns:

| Type | Description |
|---|---|
| Any | Vendi score >= 1.0. A score of 1.0 means all items are identical; a score of n means maximum diversity (all items orthogonal). |

Examples:

>>> import jax.numpy as jnp
>>> identity = jnp.eye(3)  # 3 orthogonal items
>>> vendi_score(identity)  # 3.0
...

Plugin Metrics (Tier 1-2)

Optional Dependency

FID, Inception Score, and LPIPS require pretrained backbones. Install them with:

uv pip install "calibrax[image]"

Import directly from the plugin module:

from calibrax.metrics.plugins.image import FIDMetric, InceptionScoreMetric, LPIPSMetric

See Stateful Metrics for the base class API.