
calibrax.metrics.functional.divergence

Divergence functions for comparing probability distributions. Covers KL and reverse-KL divergence, Jensen-Shannon, total variation, Hellinger, chi-squared, Renyi, and the unified f-divergence framework. Also includes sample-based divergences: Wasserstein-1D, sliced Wasserstein, Sinkhorn, MMD, and Bregman divergence.

Statistical divergence functions between probability distributions.

Pure functions for measuring dissimilarity between distributions. Divergences are generally asymmetric (Finslerian in nature), which distinguishes them from true distance metrics. This module covers f-divergences (KL, JS, Hellinger, chi-squared, TV) and the closely related Renyi divergence, optimal transport metrics (Wasserstein, Sinkhorn, sliced Wasserstein), kernel-based metrics (MMD), and Bregman divergences.

Includes 13 functions: kl_divergence, js_divergence, wasserstein_1d, mmd, total_variation, reverse_kl_divergence, hellinger_distance, chi_squared_divergence, renyi_divergence, f_divergence, sinkhorn_divergence, sliced_wasserstein, bregman_divergence.

kl_divergence(p: Any, q: Any) -> Any

Kullback-Leibler divergence: sum(p * log(p / q)).

Measures the information lost when q is used to approximate p. NOT symmetric: in general KL(p||q) != KL(q||p).

Note

Direction: LOWER (0.0 = identical distributions). Range: [0, inf). Not symmetric, not a true metric. Requires p and q to be probability vectors (sum to ~1).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | True distribution (probability vector). | required |
| q | Any | Approximate distribution (probability vector). | required |

Returns:

| Type | Description |
| --- | --- |
| Any | KL divergence as a scalar value. |

Examples:

>>> import jax.numpy as jnp
>>> from calibrax.metrics.functional.divergence import kl_divergence
>>> p = jnp.array([0.5, 0.5])
>>> float(kl_divergence(p, p))
0.0
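The formula can be checked with a minimal NumPy sketch. `kl_reference` is a hypothetical helper, not calibrax's implementation; it uses NumPy instead of JAX, drops terms where p = 0, and clamps q at an assumed `eps` where the exact divergence would be infinite.

```python
import numpy as np


def kl_reference(p, q, eps=1e-12):
    """Hypothetical reference for sum(p * log(p / q)); not calibrax's code."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Convention 0 * log(0 / q) = 0: drop outcomes where p has no mass.
    mask = p > 0
    # Assumed eps clamp: the exact KL is infinite when q = 0 where p > 0.
    return float(np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], eps))))


p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
print(kl_reference(p, p))  # 0.0
print(kl_reference(p, q))  # ≈ 0.511
print(kl_reference(q, p))  # ≈ 0.368, so KL(p||q) != KL(q||p)
```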

reverse_kl_divergence(p: Any, q: Any) -> Any

Reverse KL divergence: KL(q || p).

Mode-seeking variant, useful for variational inference. Penalizes q for placing mass where p has none.

Note

Direction: LOWER (0.0 = identical distributions). Range: [0, inf). Not symmetric. reverse_kl(p,q) = kl(q,p).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | First distribution. | required |
| q | Any | Second distribution. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Reverse KL divergence as a scalar value. |
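To illustrate the mode-seeking behaviour, the sketch below evaluates the documented identity reverse_kl(p, q) = kl(q, p) with NumPy. `kl_reference` and `reverse_kl_reference` are hypothetical helpers, and the `eps` clamp is an assumption that keeps the otherwise infinite penalty finite.

```python
import numpy as np


def kl_reference(p, q, eps=1e-12):
    # Hypothetical helper: sum(p * log(p / q)), skipping p == 0 terms and
    # clamping q at an assumed eps where the exact value would be infinite.
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = p > 0
    return float(np.sum(p[m] * np.log(p[m] / np.maximum(q[m], eps))))


def reverse_kl_reference(p, q):
    # Documented identity: reverse_kl(p, q) = kl(q, p).
    return kl_reference(q, p)


p = np.array([0.5, 0.5, 0.0])          # target with no mass on the third outcome
q_mode = np.array([1.0, 0.0, 0.0])     # q collapses onto one mode of p
q_leaky = np.array([0.5, 0.25, 0.25])  # q puts mass where p has none

print(reverse_kl_reference(p, q_mode))   # ln(2) ≈ 0.693: dropping a mode is cheap
print(reverse_kl_reference(p, q_leaky))  # ≈ 6.39 here; infinite without the eps clamp
```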

js_divergence(p: Any, q: Any) -> Any

Jensen-Shannon divergence.

Symmetric, bounded version of KL divergence: JS(p,q) = 0.5 * KL(p||m) + 0.5 * KL(q||m) where m = 0.5*(p+q).

Note

Direction: LOWER (0.0 = identical distributions). Range: [0, ln(2)] with natural log. Symmetric. Square root of JS is a true metric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | First probability vector. | required |
| q | Any | Second probability vector. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | JS divergence as a scalar value. |

Examples:

>>> import jax.numpy as jnp
>>> from calibrax.metrics.functional.divergence import js_divergence
>>> p = jnp.array([0.5, 0.5])
>>> float(js_divergence(p, p))
0.0
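A NumPy sketch of the mixture construction, using hypothetical helpers `kl_reference` and `js_reference` rather than calibrax itself. For distributions with disjoint supports it reaches the ln(2) upper bound, and swapping the arguments leaves the value unchanged.

```python
import numpy as np


def kl_reference(p, q):
    # Hypothetical helper: sum(p * log(p / q)), skipping terms where p == 0.
    m = p > 0
    return float(np.sum(p[m] * np.log(p[m] / q[m])))


def js_reference(p, q):
    # Hypothetical helper: 0.5 * KL(p||m) + 0.5 * KL(q||m) with m = 0.5 * (p + q).
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = 0.5 * (p + q)  # m > 0 wherever p > 0 or q > 0, so no division by zero
    return 0.5 * kl_reference(p, m) + 0.5 * kl_reference(q, m)


p = np.array([1.0, 0.0])
q = np.array([0.0, 1.0])
print(js_reference(p, q))  # ln(2) ≈ 0.693, the upper bound, for disjoint supports
print(js_reference(q, p))  # same value: JS is symmetric
```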

total_variation(p: Any, q: Any) -> Any

Total variation distance: 0.5 * sum(|p - q|).

Both an f-divergence and an integral probability metric. Bounded [0, 1] for probability vectors.

Note

Direction: LOWER (0.0 = identical). Range: [0, 1]. True metric. Symmetric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | First probability vector. | required |
| q | Any | Second probability vector. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Total variation distance as a scalar value. |
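Total variation also equals the largest difference in probability that p and q assign to any single event, max over A of |P(A) - Q(A)|. The hypothetical NumPy helper below computes the 0.5 * sum(|p - q|) form and checks it against a brute-force search over events, which is only feasible for small supports.

```python
from itertools import combinations

import numpy as np


def tv_reference(p, q):
    # Hypothetical helper: 0.5 * sum(|p - q|).
    p, q = np.asarray(p, float), np.asarray(q, float)
    return 0.5 * float(np.sum(np.abs(p - q)))


p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.3, 0.5])

# Brute-force the equivalent form: max over non-empty events A of |P(A) - Q(A)|.
events = [list(a) for r in range(1, len(p) + 1) for a in combinations(range(len(p)), r)]
sup_gap = max(abs(p[a].sum() - q[a].sum()) for a in events)

print(tv_reference(p, q), sup_gap)  # both ≈ 0.3, attained e.g. by the event {0}
```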

hellinger_distance(p: Any, q: Any) -> Any

Hellinger distance between probability distributions.

H(p,q) = sqrt(0.5 * sum((sqrt(p) - sqrt(q))^2)). Bounds total variation from both sides: H^2 <= TV <= sqrt(2) * H. (Pinsker's inequality relates TV to KL, not to Hellinger.)

Note

Direction: LOWER (0.0 = identical). Range: [0, 1]. True metric. Symmetric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | First probability vector. | required |
| q | Any | Second probability vector. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Hellinger distance as a scalar value. |
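A small NumPy experiment, assuming the hypothetical helpers `hellinger_reference` and `tv_reference` (not calibrax code), checks the standard inequalities H^2 <= TV <= sqrt(2) * H on random probability vectors drawn from a Dirichlet distribution.

```python
import numpy as np


def hellinger_reference(p, q):
    # Hypothetical helper: sqrt(0.5 * sum((sqrt(p) - sqrt(q))^2)).
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)))


def tv_reference(p, q):
    # Hypothetical helper: 0.5 * sum(|p - q|).
    p, q = np.asarray(p, float), np.asarray(q, float)
    return 0.5 * float(np.sum(np.abs(p - q)))


rng = np.random.default_rng(0)
for _ in range(1000):
    p = rng.dirichlet(np.ones(5))
    q = rng.dirichlet(np.ones(5))
    h, tv = hellinger_reference(p, q), tv_reference(p, q)
    # H^2 <= TV <= sqrt(2) * H, with a small tolerance for rounding.
    assert h**2 <= tv + 1e-12 and tv <= np.sqrt(2.0) * h + 1e-12

print(hellinger_reference([1.0, 0.0], [0.0, 1.0]))  # 1.0 for disjoint supports
```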

chi_squared_divergence(p: Any, q: Any) -> Any

Pearson chi-squared divergence: sum((p - q)^2 / q).

NOT symmetric. Sensitive to q near zero.

Note

Direction: LOWER (0.0 = identical). Range: [0, inf). Not symmetric. Uses safe_divide for numerical stability.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | Observed distribution. | required |
| q | Any | Expected distribution. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Chi-squared divergence as a scalar value. |
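The sketch below shows both the asymmetry and the sensitivity to small entries of q. `chi2_reference` is a hypothetical NumPy helper; its `eps` clamp stands in for calibrax's safe_divide, whose exact behaviour is not specified here.

```python
import numpy as np


def chi2_reference(p, q, eps=1e-12):
    # Hypothetical helper: sum((p - q)^2 / q). The eps clamp stands in for
    # calibrax's safe_divide, whose exact behaviour is an assumption here.
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum((p - q) ** 2 / np.maximum(q, eps)))


p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
print(chi2_reference(p, q))  # 0.16/0.9 + 0.16/0.1 ≈ 1.778: dominated by the small q entry
print(chi2_reference(q, p))  # 0.16/0.5 + 0.16/0.5 ≈ 0.64: arguments do not commute
```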

renyi_divergence(p: Any, q: Any, *, alpha: float = 0.5) -> Any

Renyi alpha-divergence.

D_alpha(p||q) = 1/(alpha-1) * log(sum(p^alpha * q^(1-alpha))). Recovers KL(p||q) in the limit alpha -> 1.

Note

Direction: LOWER (0.0 = identical). Range: [0, inf). Not symmetric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | First probability vector. | required |
| q | Any | Second probability vector. | required |
| alpha | float | Order parameter. Must not equal 1.0 (use KL instead). | 0.5 |

Returns:

| Type | Description |
| --- | --- |
| Any | Renyi divergence as a scalar value. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If alpha equals 1.0. |
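A NumPy sketch of the Renyi formula (hypothetical helper `renyi_reference`, written to raise on alpha = 1.0 like the documented function) shows the value increasing toward KL(p||q) as alpha approaches 1 from below.

```python
import numpy as np


def renyi_reference(p, q, alpha=0.5):
    # Hypothetical helper: 1/(alpha - 1) * log(sum(p^alpha * q^(1 - alpha))).
    if alpha == 1.0:
        raise ValueError("alpha must not equal 1.0; use KL divergence instead")
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.log(np.sum(p**alpha * q ** (1.0 - alpha))) / (alpha - 1.0))


p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
kl = float(np.sum(p * np.log(p / q)))  # ≈ 0.511

for alpha in (0.5, 0.9, 0.99, 0.9999):
    # Non-decreasing in alpha, approaching KL(p||q) as alpha -> 1.
    print(alpha, renyi_reference(p, q, alpha))
```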

f_divergence(p: Any, q: Any, *, generator: Callable[[Any], Any]) -> Any

Unified f-divergence with arbitrary convex generator.

D_f(p||q) = sum(q * f(p / q)) where f is convex with f(1) = 0.

Note

Direction: LOWER (0.0 = identical if f(1)=0). Range: [0, inf). Recovers KL (f(u) = u*log(u)), TV (f(u) = 0.5*|u-1|), squared Hellinger up to a constant factor (f(u) = (sqrt(u)-1)^2), and chi-squared (f(u) = (u-1)^2) as special cases.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | First probability vector. | required |
| q | Any | Second probability vector. | required |
| generator | Callable[[Any], Any] | Convex function f with f(1) = 0. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | f-divergence as a scalar value. |
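The special cases can be checked directly. The hypothetical NumPy helper `f_divergence_reference` assumes q > 0 everywhere; each generator is compared against the closed form of the corresponding divergence. With the generator (sqrt(u) - 1)^2 the result is sum((sqrt(p) - sqrt(q))^2), which is 2 * H(p,q)^2 under the 0.5-normalized Hellinger distance.

```python
import numpy as np


def f_divergence_reference(p, q, generator):
    # Hypothetical helper: D_f(p||q) = sum(q * f(p / q)); assumes q > 0 everywhere.
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum(q * generator(p / q)))


p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.3, 0.5])

kl = f_divergence_reference(p, q, lambda u: u * np.log(u))
tv = f_divergence_reference(p, q, lambda u: 0.5 * np.abs(u - 1.0))
chi2 = f_divergence_reference(p, q, lambda u: (u - 1.0) ** 2)
hel_sq_x2 = f_divergence_reference(p, q, lambda u: (np.sqrt(u) - 1.0) ** 2)

# Each generator reproduces the closed form of its divergence.
print(np.isclose(kl, np.sum(p * np.log(p / q))))                      # True
print(np.isclose(tv, 0.5 * np.sum(np.abs(p - q))))                    # True
print(np.isclose(chi2, np.sum((p - q) ** 2 / q)))                     # True
print(np.isclose(hel_sq_x2, np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)))  # True
```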

wasserstein_1d(p: Any, q: Any) -> Any

1D Wasserstein-1 (Earth Mover's) distance between samples.

For 1D data: sort both samples and take the mean absolute difference between the sorted values. Operates on sample arrays, not probability vectors.

Note

Direction: LOWER (0.0 = identical distributions). Range: [0, inf). True metric. Symmetric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Any | First sample array. | required |
| q | Any | Second sample array. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Wasserstein-1 distance as a scalar value. |
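A minimal NumPy sketch of the sort-and-match rule, assuming equal sample sizes (how calibrax handles unequal sizes is not stated here). Sorting both samples gives the optimal 1D coupling, so a pure shift moves every matched pair by the same amount. `wasserstein_1d_reference` is a hypothetical helper.

```python
import numpy as np


def wasserstein_1d_reference(x, y):
    # Hypothetical helper for equal-size 1D samples: sorting both samples gives
    # the optimal 1D coupling, then average the absolute differences.
    x, y = np.sort(np.asarray(x, float)), np.sort(np.asarray(y, float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.mean(np.abs(x - y)))


x = np.array([0.0, 1.0, 2.0, 3.0])
print(wasserstein_1d_reference(x, x + 2.5))  # 2.5: a shift moves every matched pair by 2.5
print(wasserstein_1d_reference(x, x[::-1]))  # 0.0: sample order is irrelevant
```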

mmd(x: Any, y: Any, *, kernel: str = 'rbf', bandwidth: float = 1.0) -> Any

Maximum Mean Discrepancy between sample distributions.

Measures the distance between the kernel mean embeddings of the two samples. The estimation error decays as O(n^(-1/2)) regardless of dimension.

Note

Direction: LOWER (0.0 = identical distributions). Range: [0, inf). True metric for characteristic kernels such as RBF and Laplace. Symmetric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Any | First sample matrix (n_samples, n_features). | required |
| y | Any | Second sample matrix (n_samples, n_features). | required |
| kernel | str | Kernel type: "rbf" or "laplace". | 'rbf' |
| bandwidth | float | Kernel bandwidth parameter. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Any | MMD as a scalar value. |
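The sketch below computes the biased (V-statistic) estimate of MMD with an RBF kernel in NumPy. The kernel convention exp(-||a - b||^2 / (2 * bandwidth^2)), the choice of returning MMD rather than MMD^2, and the helper names are assumptions; calibrax may use a different estimator or bandwidth convention.

```python
import numpy as np


def rbf_kernel(a, b, bandwidth):
    # Assumed convention: k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2)).
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * bandwidth**2))


def mmd_reference(x, y, bandwidth=1.0):
    # Hypothetical helper: biased (V-statistic) estimate of the RKHS distance
    # between the empirical kernel mean embeddings of x and y.
    kxx = rbf_kernel(x, x, bandwidth).mean()
    kyy = rbf_kernel(y, y, bandwidth).mean()
    kxy = rbf_kernel(x, y, bandwidth).mean()
    # Clamp tiny negative values caused by floating-point rounding.
    return float(np.sqrt(max(kxx + kyy - 2.0 * kxy, 0.0)))


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3))
y_same = rng.normal(size=(200, 3))
y_shift = rng.normal(loc=1.0, size=(200, 3))
print(mmd_reference(x, y_same))   # small: same distribution, only sampling noise
print(mmd_reference(x, y_shift))  # larger: mean shifted by 1 in every coordinate
```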

sinkhorn_divergence(x: Any, y: Any, *, regularization: float = 0.1, max_iter: int = 100, threshold: float = 1e-05) -> Any

Debiased Sinkhorn divergence (entropic optimal transport).

S(x,y) = OT_reg(x,y) - 0.5*(OT_reg(x,x) + OT_reg(y,y)). Differentiable and JIT-compatible.

Note

Direction: LOWER (0.0 = identical distributions). Range: [0, inf). Symmetric. Differentiable. Debiased: S(x,x) = 0.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Any | First sample matrix (n_samples, n_features). | required |
| y | Any | Second sample matrix (n_samples, n_features). | required |
| regularization | float | Entropic regularization strength. | 0.1 |
| max_iter | int | Maximum Sinkhorn iterations. | 100 |
| threshold | float | Convergence threshold. | 1e-05 |

Returns:

| Type | Description |
| --- | --- |
| Any | Sinkhorn divergence as a scalar value. |
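A log-domain Sinkhorn sketch in NumPy, assuming uniform sample weights and a squared-Euclidean cost, neither of which is stated above. `entropic_ot` returns the dual value <a, f> + <b, g> of the entropic problem, and `sinkhorn_divergence_reference` applies the debiasing formula; both are hypothetical helpers, not calibrax's implementation.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def entropic_ot(x, y, reg, max_iter, threshold):
    # Hypothetical helper: log-domain Sinkhorn, uniform weights, assumed
    # squared-Euclidean cost. Returns the dual value <a, f> + <b, g>.
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a = np.full(len(x), -np.log(len(x)))
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(max_iter):
        f_new = -reg * _logsumexp(log_b[None, :] + (g[None, :] - cost) / reg, axis=1)
        g = -reg * _logsumexp(log_a[:, None] + (f_new[:, None] - cost) / reg, axis=0)
        converged = np.max(np.abs(f_new - f)) < threshold
        f = f_new
        if converged:
            break
    # With uniform weights, <a, f> and <b, g> are plain means.
    return float(np.mean(f) + np.mean(g))


def sinkhorn_divergence_reference(x, y, regularization=0.1, max_iter=100, threshold=1e-5):
    # Debiasing: S(x, y) = OT(x, y) - 0.5 * (OT(x, x) + OT(y, y)).
    def ot(a, b):
        return entropic_ot(a, b, regularization, max_iter, threshold)

    return ot(x, y) - 0.5 * (ot(x, x) + ot(y, y))


rng = np.random.default_rng(0)
x = rng.uniform(size=(30, 2))
near = sinkhorn_divergence_reference(x, x + np.array([0.5, 0.0]))
far = sinkhorn_divergence_reference(x, x + np.array([2.0, 0.0]))
print(sinkhorn_divergence_reference(x, x))  # 0.0: debiased by construction
print(near, far)  # larger shifts give larger divergences
```

Working in the log domain with a log-sum-exp reduction avoids the underflow that exp(-cost / regularization) would cause for small regularization values.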

sliced_wasserstein(x: Any, y: Any, *, num_projections: int = 50, p: float = 2.0, key: Any | None = None) -> Any

Sliced Wasserstein distance.

Projects both samples onto random 1D directions, computes the exact 1D Wasserstein distance along each direction, and averages over directions. Practical for comparing high-dimensional distributions.

Note

Direction: LOWER (0.0 = identical distributions). Range: [0, inf). True metric. Symmetric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Any | First sample matrix (n_samples, n_features). | required |
| y | Any | Second sample matrix (n_samples, n_features). | required |
| num_projections | int | Number of random 1D projections. | 50 |
| p | float | Order of Wasserstein distance. | 2.0 |
| key | Any \| None | JAX PRNG key for reproducibility. Uses fixed seed if None. | None |

Returns:

| Type | Description |
| --- | --- |
| Any | Sliced Wasserstein distance as a scalar value. |
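A NumPy sketch of the projection scheme, assuming equal sample sizes and a NumPy seed in place of a JAX PRNG key. It aggregates as (mean over directions of W_p^p)^(1/p), one common convention; whether calibrax averages W_p or W_p^p is not stated above. `sliced_wasserstein_reference` is a hypothetical helper.

```python
import numpy as np


def sliced_wasserstein_reference(x, y, num_projections=50, p=2.0, seed=0):
    # Hypothetical helper: equal sample sizes assumed; a NumPy seed replaces
    # the JAX PRNG key. Aggregates as (mean of W_p^p over directions)^(1/p).
    rng = np.random.default_rng(seed)
    dirs = rng.normal(size=(num_projections, x.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # uniform directions on the sphere
    # Project and sort: column k holds the sorted 1D sample along direction k.
    x_proj = np.sort(x @ dirs.T, axis=0)
    y_proj = np.sort(y @ dirs.T, axis=0)
    # Exact 1D W_p^p per direction by matching sorted projections.
    wp_per_dir = np.mean(np.abs(x_proj - y_proj) ** p, axis=0)
    return float(np.mean(wp_per_dir) ** (1.0 / p))


rng = np.random.default_rng(1)
x = rng.normal(size=(500, 10))
y = x + 1.0
print(sliced_wasserstein_reference(x, x))  # 0.0
print(sliced_wasserstein_reference(x, y))  # positive: y is x shifted along the all-ones direction
```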

bregman_divergence(x: Any, y: Any, *, generator: Callable[[Any], Any], generator_grad: Callable[[Any], Any] | None = None) -> Any

Bregman divergence with arbitrary convex generator.

D_psi(x, y) = psi(x) - psi(y) - <grad_psi(y), x - y>. Unifies squared Euclidean distance, (generalized) KL, Itakura-Saito, and squared Mahalanobis distance as special cases.

Note

Direction: LOWER (0.0 = identical points). Range: [0, inf). Not symmetric in general.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Any | First point or batch. | required |
| y | Any | Second point or batch. | required |
| generator | Callable[[Any], Any] | Strictly convex differentiable function psi. | required |
| generator_grad | Callable[[Any], Any] \| None | Gradient of generator. If None, computed via jax.grad(generator). | None |

Returns:

| Type | Description |
| --- | --- |
| Any | Bregman divergence as a scalar value. |
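Two of the named special cases can be checked with a NumPy sketch. `bregman_reference` is a hypothetical helper that, unlike the documented function, requires the gradient explicitly instead of falling back to jax.grad. With psi(v) = ||v||^2 it gives the squared Euclidean distance, and with the negative entropy psi(v) = sum(v * log(v)) it gives the generalized KL divergence sum(x * log(x / y) - x + y).

```python
import numpy as np


def bregman_reference(x, y, generator, generator_grad):
    # Hypothetical helper: psi(x) - psi(y) - <grad psi(y), x - y>.
    # Unlike the documented API, the gradient must be supplied explicitly.
    x, y = np.asarray(x, float), np.asarray(y, float)
    return float(generator(x) - generator(y) - np.dot(generator_grad(y), x - y))


x = np.array([0.2, 0.3, 0.5])
y = np.array([0.4, 0.4, 0.2])

# psi(v) = ||v||^2 gives the squared Euclidean distance.
sq = bregman_reference(x, y, lambda v: np.dot(v, v), lambda v: 2.0 * v)

# psi(v) = sum(v * log(v)) (negative entropy) gives the generalized KL divergence.
gkl = bregman_reference(x, y, lambda v: np.sum(v * np.log(v)), lambda v: np.log(v) + 1.0)
gkl_swapped = bregman_reference(y, x, lambda v: np.sum(v * np.log(v)), lambda v: np.log(v) + 1.0)

print(np.isclose(sq, np.sum((x - y) ** 2)))                # True
print(np.isclose(gkl, np.sum(x * np.log(x / y) - x + y)))  # True
print(gkl, gkl_swapped)  # differ: not symmetric in general
```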