
calibrax.metrics.functional.information

Information-theoretic functions: entropy, cross-entropy, mutual information, conditional entropy, and normalized mutual information (all built on Shannon entropy), plus the Fisher information matrix.

Information-theoretic functions.

Pure functions for measuring information content, entropy, and dependence between random variables. Based on Shannon's information theory and Fisher's information geometry.

Includes 6 functions: entropy, cross_entropy, mutual_information, conditional_entropy, normalized_mutual_information, fisher_information_matrix.

entropy(p: Any) -> Any

Shannon entropy: -sum(p * log(p)).

Measures the uncertainty of a probability distribution. It is maximal (log(n)) for the uniform distribution over n outcomes and zero for a deterministic (one-hot) distribution.

Note

Direction: INFO (neither higher nor lower is inherently better). Range: [0, log(n)] where n is the number of outcomes. Uses natural logarithm.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| p | Any | Probability vector (must sum to ~1). | required |

Returns:

| Type | Description |
|------|-------------|
| Any | Shannon entropy as a scalar value. |

Examples:

>>> import jax.numpy as jnp
>>> float(entropy(jnp.array([0.5, 0.5])))  # ln(2) ≈ 0.693
0.693...
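To make the formula and its two extremes concrete, here is a minimal sketch of Shannon entropy in nats. It is an illustration, not calibrax's implementation: `entropy_sketch` is a hypothetical name, and it assumes the usual convention 0 * log(0) = 0, implemented with `jax.scipy.special.xlogy`.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def entropy_sketch(p):
    """Shannon entropy in nats (hypothetical helper, not calibrax's code)."""
    p = jnp.asarray(p)
    # xlogy(p, p) = p * log(p), defined as 0 wherever p == 0
    return -jnp.sum(xlogy(p, p))


uniform = jnp.full(4, 0.25)                # maximum uncertainty over 4 outcomes
one_hot = jnp.array([0.0, 1.0, 0.0, 0.0])  # a deterministic outcome

print(float(entropy_sketch(uniform)))         # ≈ log(4) ≈ 1.386
print(float(entropy_sketch(one_hot)) == 0.0)  # True
```

Using `xlogy` rather than `p * jnp.log(p)` avoids `0 * (-inf) = nan` for zero-probability outcomes.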

cross_entropy(p: Any, q: Any) -> Any

Cross-entropy: -sum(p * log(q)).

Measures the average number of nats needed to encode samples drawn from p using a code optimized for q. Always >= entropy(p), with equality exactly when q = p (Gibbs' inequality).

Note

Direction: LOWER (lower = better approximation of p by q). Range: [0, inf). cross_entropy(p, q) = entropy(p) + kl_divergence(p, q).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| p | Any | True distribution (probability vector). | required |
| q | Any | Coding distribution (probability vector). | required |

Returns:

| Type | Description |
|------|-------------|
| Any | Cross-entropy as a scalar value. |

Examples:

>>> import jax.numpy as jnp
>>> p = jnp.array([0.5, 0.5])
>>> float(cross_entropy(p, p))  # equals entropy(p) ≈ 0.693
0.693...
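The identity cross_entropy(p, q) = entropy(p) + kl_divergence(p, q) from the note above can be checked numerically. This is a minimal sketch with hypothetical helpers (`entropy_sketch`, `cross_entropy_sketch`, `kl_sketch`), not the library's functions, assuming 0 * log(0) = 0 via `jax.scipy.special.xlogy`.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def entropy_sketch(p):
    # -sum p log p, with 0 * log(0) = 0
    return -jnp.sum(xlogy(p, p))


def cross_entropy_sketch(p, q):
    # -sum p log q; outcomes with p == 0 contribute 0 even if q == 0
    return -jnp.sum(xlogy(p, q))


def kl_sketch(p, q):
    # KL(p || q) = sum p * (log p - log q)
    return jnp.sum(xlogy(p, p) - xlogy(p, q))


p = jnp.array([0.7, 0.2, 0.1])
q = jnp.array([0.4, 0.4, 0.2])

lhs = cross_entropy_sketch(p, q)
rhs = entropy_sketch(p) + kl_sketch(p, q)
print(bool(jnp.allclose(lhs, rhs)))    # True: H(p, q) = H(p) + KL(p || q)
print(bool(lhs >= entropy_sketch(p)))  # True: Gibbs' inequality
```

Because KL(p || q) >= 0, minimizing cross-entropy over q is the same as minimizing KL(p || q), which is why lower values indicate a better approximation of p.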

mutual_information(joint: Any) -> Any

Mutual information from a joint probability table.

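MI(X;Y) = sum_{i,j} p(i,j) * log(p(i,j) / (p_X(i) * p_Y(j))), where p_X and p_Y are the row and column marginals of the joint table. Measures the statistical dependence between X and Y; it is zero exactly when X and Y are independent.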

Note

Direction: HIGHER (more information = stronger dependence). Range: [0, min(H(X), H(Y))]. Symmetric: MI(X;Y) = MI(Y;X).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| joint | Any | 2D joint probability table of shape (n_x, n_y). Must sum to ~1. | required |

Returns:

| Type | Description |
|------|-------------|
| Any | Mutual information as a scalar value. |

Examples:

>>> import jax.numpy as jnp
>>> # Independent: p(x,y) = p(x)*p(y)
>>> joint = jnp.array([[0.25, 0.25], [0.25, 0.25]])
>>> float(mutual_information(joint))
0.0
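The double sum above vectorizes naturally: compute both marginals, broadcast their outer product against the joint table, and sum. The sketch below is illustrative only; `mutual_information_sketch` is a hypothetical name, and it assumes rows index X and columns index Y, as in the shape (n_x, n_y) described in the parameter table.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def mutual_information_sketch(joint):
    """MI(X;Y) in nats from a joint table of shape (n_x, n_y); rows index X."""
    joint = jnp.asarray(joint)
    p_x = joint.sum(axis=1, keepdims=True)  # marginal of X, shape (n_x, 1)
    p_y = joint.sum(axis=0, keepdims=True)  # marginal of Y, shape (1, n_y)
    # p(i,j) * [log p(i,j) - log(p_x(i) * p_y(j))]; empty cells contribute 0
    return jnp.sum(xlogy(joint, joint) - xlogy(joint, p_x * p_y))


independent = jnp.outer(jnp.array([0.3, 0.7]), jnp.array([0.6, 0.4]))
identical = jnp.array([[0.5, 0.0], [0.0, 0.5]])  # Y = X, X a fair coin

print(float(mutual_information_sketch(independent)))  # ≈ 0.0
print(float(mutual_information_sketch(identical)))    # ≈ log(2) ≈ 0.693
```

The second table reaches the upper bound min(H(X), H(Y)) = log(2), since knowing X removes all uncertainty about Y.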

conditional_entropy(joint: Any) -> Any

Conditional entropy H(Y|X) from a joint probability table.

H(Y|X) = H(X,Y) - H(X). Measures remaining uncertainty about Y given knowledge of X.

Note

Direction: LOWER (less uncertainty = better prediction). Range: [0, H(Y)].

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| joint | Any | 2D joint probability table of shape (n_x, n_y). Must sum to ~1. | required |

Returns:

| Type | Description |
|------|-------------|
| Any | Conditional entropy as a scalar value. |

Examples:

>>> import jax.numpy as jnp
>>> # Perfect dependence: knowing X determines Y
>>> joint = jnp.array([[0.5, 0.0], [0.0, 0.5]])
>>> float(conditional_entropy(joint))
0.0
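The chain rule H(Y|X) = H(X,Y) - H(X) needs only two entropy evaluations: one over the whole table and one over its row marginal. Below is a minimal sketch (hypothetical helpers `_h` and `conditional_entropy_sketch`, not calibrax's code; rows are assumed to index X) on a noisy channel, which also recovers MI(X;Y) = H(Y) - H(Y|X).

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def _h(p):
    # Shannon entropy (nats) of an array of probabilities, with 0 * log(0) = 0
    return -jnp.sum(xlogy(p, p))


def conditional_entropy_sketch(joint):
    """H(Y|X) = H(X, Y) - H(X); rows of `joint` index X, columns index Y."""
    joint = jnp.asarray(joint)
    return _h(joint) - _h(joint.sum(axis=1))


# X is a fair coin; Y copies X with probability 0.9 and flips it otherwise.
noisy = jnp.array([[0.45, 0.05], [0.05, 0.45]])
h_y_given_x = conditional_entropy_sketch(noisy)
h_y = _h(noisy.sum(axis=0))

print(float(h_y_given_x))        # binary entropy of 0.1 ≈ 0.325
print(float(h_y - h_y_given_x))  # MI(X;Y) = H(Y) - H(Y|X) ≈ 0.368
```

Here H(Y|X) lies strictly between the two extremes of the stated range: 0 (Y fully determined by X) and H(Y) = log(2) (Y independent of X).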

normalized_mutual_information(joint: Any) -> Any

Normalized mutual information: MI / sqrt(H(X) * H(Y)).

Rescales MI to [0, 1] so that values are comparable across pairs of variables with different marginal entropies.

Note

Direction: HIGHER (1.0 = perfect dependence). Range: [0, 1]. Symmetric.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| joint | Any | 2D joint probability table of shape (n_x, n_y). Must sum to ~1. | required |

Returns:

| Type | Description |
|------|-------------|
| Any | Normalized MI as a scalar value. |

Examples:

>>> import jax.numpy as jnp
>>> joint = jnp.array([[0.5, 0.0], [0.0, 0.5]])
>>> float(normalized_mutual_information(joint))
1.0
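Why normalize: raw MI for a perfect one-to-one dependence equals the entropy of X, so it grows with the number of outcomes, while the normalized value stays at 1. The sketch below illustrates this with hypothetical helpers (`mi_sketch`, `nmi_sketch`), computing MI as H(X) + H(Y) - H(X,Y). It is not calibrax's implementation and does not guard the 0 / 0 case where X or Y is constant.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def _h(p):
    # Shannon entropy (nats), with 0 * log(0) = 0
    return -jnp.sum(xlogy(p, p))


def mi_sketch(joint):
    # MI(X;Y) = H(X) + H(Y) - H(X, Y); rows index X, columns index Y
    joint = jnp.asarray(joint)
    return _h(joint.sum(axis=1)) + _h(joint.sum(axis=0)) - _h(joint)


def nmi_sketch(joint):
    # MI / sqrt(H(X) * H(Y)); undefined (0 / 0) if either variable is constant
    joint = jnp.asarray(joint)
    h_x = _h(joint.sum(axis=1))
    h_y = _h(joint.sum(axis=0))
    return mi_sketch(joint) / jnp.sqrt(h_x * h_y)


# Y = X with X uniform over 2 and then 4 outcomes.
for joint in (jnp.eye(2) / 2, jnp.eye(4) / 4):
    # Raw MI grows (≈ 0.693, then ≈ 1.386); NMI stays ≈ 1.0 in both cases.
    print(float(mi_sketch(joint)), float(nmi_sketch(joint)))
```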

fisher_information_matrix(log_prob_fn: Callable[..., Any], params: Any) -> Any

Fisher information matrix at given parameter values.

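I(theta) = -E_{x ~ p(x|theta)}[nabla_theta^2 log p(x|theta)], the expected negative Hessian of the log-likelihood. By Chentsov's theorem, the Fisher metric is (up to a constant factor) the unique Riemannian metric on a statistical manifold that is invariant under sufficient statistics.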

Note

NOT registered as a scalar metric, because it returns a matrix. Useful for natural gradient methods and information geometry.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| log_prob_fn | Callable[..., Any] | Log-probability function taking params as input. | required |
| params | Any | Parameter values at which to compute the Fisher matrix. | required |

Returns:

| Type | Description |
|------|-------------|
| Any | Fisher information matrix as a JAX array of shape (num_params, num_params). |

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> def log_prob(theta):
...     return -0.5 * jnp.sum(theta ** 2)
>>> fim = fisher_information_matrix(log_prob, jnp.array([1.0, 2.0]))
>>> fim.shape
(2, 2)
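Because log_prob_fn takes only params, any data, or any expectation over x, has to be captured inside that function. The sketch below assumes the matrix is obtained as the negative Hessian of log_prob_fn using `jax.hessian`; this is an assumption about how such a function could work, not calibrax's implementation, and `observed_fisher_sketch`, `bernoulli_expected_log_prob`, and `THETA_TRUE` are hypothetical names. It reproduces the example above and checks a Bernoulli model against the closed form 1 / (theta * (1 - theta)).

```python
import jax
import jax.numpy as jnp


def observed_fisher_sketch(log_prob_fn, params):
    """Negative Hessian of a scalar log-probability at `params` (hypothetical)."""
    return -jax.hessian(log_prob_fn)(params)


def log_prob(theta):
    # Same function as in the example above: -0.5 * ||theta||^2
    return -0.5 * jnp.sum(theta ** 2)


THETA_TRUE = 0.3  # hypothetical data-generating Bernoulli parameter


def bernoulli_expected_log_prob(theta):
    # E_{x ~ Bernoulli(THETA_TRUE)}[log p(x | theta)], written out in closed form
    return THETA_TRUE * jnp.log(theta[0]) + (1 - THETA_TRUE) * jnp.log(1 - theta[0])


fim = observed_fisher_sketch(log_prob, jnp.array([1.0, 2.0]))
print(fim.shape)                              # (2, 2)
print(bool(jnp.allclose(fim, jnp.eye(2))))    # True: the Hessian of log_prob is -I

bern = observed_fisher_sketch(bernoulli_expected_log_prob, jnp.array([THETA_TRUE]))
print(float(bern[0, 0]))  # ≈ 1 / (0.3 * 0.7) ≈ 4.762
```

Evaluating the Hessian of the expected log-likelihood at the true parameter gives the expected Fisher information; evaluating the Hessian of a log-likelihood over one observed dataset gives the observed information instead.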