calibrax.metrics.functional.clustering

Clustering evaluation metrics for measuring partition quality. Covers extrinsic measures (adjusted Rand index, NMI, AMI, V-measure) that compare against ground-truth labels, and intrinsic measures (silhouette, Calinski-Harabasz, Davies-Bouldin) that evaluate structure without labels.

Clustering evaluation metrics for unsupervised learning.

Pure functions for evaluating clustering quality. Divided into two categories:

External evaluation (requires ground truth labels):
- adjusted_rand_index, normalized_mutual_information_clustering, adjusted_mutual_information, v_measure

Internal evaluation (no ground truth, uses feature distances):
- silhouette_score, calinski_harabasz_score, davies_bouldin_score

All functions accept integer label arrays; the internal metrics additionally require a feature matrix. They are registered with domain="clustering" and signature=MetricSignature.FEATURES_LABELS.

adjusted_rand_index(labels_true: Any, labels_pred: Any) -> Any

Adjusted Rand Index for clustering agreement.

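Measures similarity between two clusterings, corrected for chance. ARI = (RI - E[RI]) / (max(RI) - E[RI]), where RI is the Rand index (the fraction of sample pairs on which the two clusterings agree) and E[RI] is its expected value under random labeling.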

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| labels_true | Any | Ground truth integer labels, shape (n,). | required |
| labels_pred | Any | Predicted integer labels, shape (n,). | required |

Returns:

| Type | Description |
| --- | --- |
| Any | ARI value in [-1, 1]. 1.0 = perfect agreement, 0.0 = random, negative = worse than random. |

Examples:

>>> import jax.numpy as jnp
>>> adjusted_rand_index(jnp.array([0, 0, 1, 1]), jnp.array([0, 0, 1, 1]))
1.0
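
The ARI formula above is usually evaluated in its equivalent pair-counting form over the contingency table of the two labelings. The sketch below is a plain-Python reference for cross-checking results on small inputs. The name `ari_reference` is hypothetical, and nothing here is taken from calibrax's implementation. It assumes at least two samples and plain Python sequences as inputs (convert arrays with `.tolist()` first).

```python
from collections import Counter
from math import comb


def ari_reference(labels_true, labels_pred):
    """Pair-counting ARI for small inputs (hypothetical helper, n >= 2)."""
    n = len(labels_true)
    # Contingency counts n_ij, plus row sums a_i and column sums b_j.
    cells = Counter(zip(labels_true, labels_pred))
    rows = Counter(labels_true)
    cols = Counter(labels_pred)

    # Number of sample pairs placed together by both clusterings.
    index = sum(comb(c, 2) for c in cells.values())
    row_pairs = sum(comb(c, 2) for c in rows.values())
    col_pairs = sum(comb(c, 2) for c in cols.values())

    # Expected value of `index` when cluster sizes are fixed and labels shuffled.
    expected = row_pairs * col_pairs / comb(n, 2)
    max_index = 0.5 * (row_pairs + col_pairs)
    if max_index == expected:
        # Assumed convention for degenerate inputs (e.g. one cluster each).
        return 1.0
    return (index - expected) / (max_index - expected)
```

With this sketch, `ari_reference([0, 0, 1, 1], [1, 1, 0, 0])` gives 1.0 because ARI ignores how clusters are named, while the crossed labeling `[0, 1, 0, 1]` gives -0.5, which is worse than random.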

normalized_mutual_information_clustering(labels_true: Any, labels_pred: Any, *, average: str = 'arithmetic') -> Any

Normalized Mutual Information for clustering.

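NMI = MI(true, pred) / normalizer, where the normalizer combines the entropies of the two labelings, H(true) and H(pred), according to average. Range [0, 1].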

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| labels_true | Any | Ground truth integer labels, shape (n,). | required |
| labels_pred | Any | Predicted integer labels, shape (n,). | required |
| average | str | Normalizer type. One of "arithmetic" (default), "geometric", "min", "max". | 'arithmetic' |

Returns:

| Type | Description |
| --- | --- |
| Any | NMI value in [0, 1]. 1.0 = perfect agreement. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If average is not one of the supported options. |

Examples:

>>> import jax.numpy as jnp
>>> normalized_mutual_information_clustering(
...     jnp.array([0, 0, 1, 1]), jnp.array([0, 0, 1, 1])
... )
1.0
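
To make the effect of `average` concrete, here is a minimal plain-Python sketch of MI divided by each of the four normalizers. The helper names (`entropy`, `mutual_information`, `nmi_reference`) are hypothetical. The handling of zero-entropy inputs is an assumed convention and may differ from what calibrax does.

```python
from collections import Counter
from math import log, sqrt


def entropy(labels):
    """Shannon entropy (natural log) of a label sequence."""
    n = len(labels)
    return sum((c / n) * log(n / c) for c in Counter(labels).values())


def mutual_information(labels_true, labels_pred):
    """MI between two labelings, computed from contingency counts."""
    n = len(labels_true)
    rows, cols = Counter(labels_true), Counter(labels_pred)
    cells = Counter(zip(labels_true, labels_pred))
    # Empty cells never appear in `cells`, so they contribute nothing.
    return sum(
        (c / n) * log(c * n / (rows[i] * cols[j]))
        for (i, j), c in cells.items()
    )


def nmi_reference(labels_true, labels_pred, average="arithmetic"):
    """Hypothetical reference NMI; not calibrax's implementation."""
    h_true, h_pred = entropy(labels_true), entropy(labels_pred)
    normalizers = {
        "arithmetic": 0.5 * (h_true + h_pred),
        "geometric": sqrt(h_true * h_pred),
        "min": min(h_true, h_pred),
        "max": max(h_true, h_pred),
    }
    if average not in normalizers:
        raise ValueError(f"unsupported average: {average!r}")
    norm = normalizers[average]
    if norm == 0.0:
        # Assumed convention: two single-cluster labelings agree perfectly.
        return 1.0 if h_true == h_pred == 0.0 else 0.0
    return mutual_information(labels_true, labels_pred) / norm
```

The normalizers only differ when the two labelings have different entropies. For `[0, 0, 1, 1]` against `[0, 0, 1, 2]`, the "min" normalizer yields 1.0 and "max" yields about 0.667, with "arithmetic" (0.8) and "geometric" in between.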

adjusted_mutual_information(labels_true: Any, labels_pred: Any) -> Any

Adjusted Mutual Information for clustering.

Chance-adjusted version of NMI: AMI = (MI - E[MI]) / (max(H_true, H_pred) - E[MI]), where E[MI] is the expected mutual information between random labelings. More robust than NMI when comparing clusterings with different numbers of clusters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| labels_true | Any | Ground truth integer labels, shape (n,). | required |
| labels_pred | Any | Predicted integer labels, shape (n,). | required |

Returns:

| Type | Description |
| --- | --- |
| Any | AMI value in [-1, 1]. 1.0 = perfect agreement, 0.0 = random labeling. |

Examples:

>>> import jax.numpy as jnp
>>> adjusted_mutual_information(
...     jnp.array([0, 0, 1, 1]), jnp.array([0, 0, 1, 1])
... )
1.0
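
The E[MI] term is the least obvious part of the AMI formula. The sketch below assumes the usual permutation model, in which cluster sizes are held fixed and each contingency cell follows a hypergeometric distribution. Whether calibrax computes E[MI] exactly this way is an assumption. `expected_mutual_information` and `ami_reference` are hypothetical names, and the loop is only practical for small inputs.

```python
from collections import Counter
from math import comb, log


def expected_mutual_information(row_sizes, col_sizes, n):
    """E[MI] when cluster sizes are fixed and labels are shuffled."""
    total = 0.0
    for a in row_sizes:
        for b in col_sizes:
            # Overlaps n_ij that are possible for clusters of size a and b.
            for nij in range(max(1, a + b - n), min(a, b) + 1):
                # Hypergeometric probability of observing this overlap.
                prob = comb(a, nij) * comb(n - a, b - nij) / comb(n, b)
                total += prob * (nij / n) * log(n * nij / (a * b))
    return total


def ami_reference(labels_true, labels_pred):
    """Hypothetical reference AMI using the max(H_true, H_pred) normalizer."""
    n = len(labels_true)
    rows, cols = Counter(labels_true), Counter(labels_pred)
    cells = Counter(zip(labels_true, labels_pred))
    mi = sum(
        (c / n) * log(c * n / (rows[i] * cols[j]))
        for (i, j), c in cells.items()
    )
    h_true = sum((c / n) * log(n / c) for c in rows.values())
    h_pred = sum((c / n) * log(n / c) for c in cols.values())
    emi = expected_mutual_information(list(rows.values()), list(cols.values()), n)
    denom = max(h_true, h_pred) - emi
    if denom == 0.0:
        # Assumed convention for degenerate inputs (one cluster each).
        return 1.0
    return (mi - emi) / denom
```

For `[0, 0, 1, 1]` against the crossed labeling `[0, 1, 0, 1]`, MI is 0 but E[MI] is positive, so this sketch returns -0.5, whereas unadjusted NMI would report 0.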

v_measure(labels_true: Any, labels_pred: Any, *, beta: float = 1.0) -> Any

V-measure: harmonic mean of homogeneity and completeness.

Equivalent to NMI with arithmetic normalizer when beta=1.0. beta > 1 weights completeness more, beta < 1 weights homogeneity more.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| labels_true | Any | Ground truth integer labels, shape (n,). | required |
| labels_pred | Any | Predicted integer labels, shape (n,). | required |
| beta | float | Weight parameter. 1.0 = equal weight. >1 = favor completeness. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Any | V-measure in [0, 1]. 1.0 = perfect clustering. |

Examples:

>>> import jax.numpy as jnp
>>> v_measure(jnp.array([0, 0, 1, 1]), jnp.array([0, 0, 1, 1]))
1.0
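
The role of `beta` is easiest to see with homogeneity and completeness written out. The sketch below assumes the standard definitions, homogeneity = MI / H(true) and completeness = MI / H(pred), combined as (1 + beta) * h * c / (beta * h + c). `v_measure_reference` is a hypothetical name, and the zero-entropy conventions are assumptions.

```python
from collections import Counter
from math import log


def v_measure_reference(labels_true, labels_pred, beta=1.0):
    """Hypothetical reference V-measure; not calibrax's implementation."""
    n = len(labels_true)
    rows, cols = Counter(labels_true), Counter(labels_pred)
    cells = Counter(zip(labels_true, labels_pred))
    mi = sum(
        (c / n) * log(c * n / (rows[i] * cols[j]))
        for (i, j), c in cells.items()
    )
    h_true = sum((c / n) * log(n / c) for c in rows.values())
    h_pred = sum((c / n) * log(n / c) for c in cols.values())

    # Homogeneity: each predicted cluster contains a single true class.
    homogeneity = mi / h_true if h_true > 0.0 else 1.0
    # Completeness: each true class lands in a single predicted cluster.
    completeness = mi / h_pred if h_pred > 0.0 else 1.0

    denom = beta * homogeneity + completeness
    if denom == 0.0:
        return 0.0
    return (1.0 + beta) * homogeneity * completeness / denom
```

For `[0, 0, 1, 1]` against `[0, 0, 1, 2]`, homogeneity is 1.0 and completeness is 2/3, so `beta=1.0` gives 0.8 (the same value as NMI with the arithmetic normalizer), and `beta=2.0` pulls the score down to 0.75 because the weaker completeness term counts more.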

silhouette_score(features: Any, labels: Any) -> Any

Mean silhouette coefficient across all samples.

For each sample: s = (b - a) / max(a, b), where a is the mean distance to the other samples in the same cluster and b is the smallest mean distance to the samples of any other cluster. O(n^2) complexity in the number of samples n.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| features | Any | Feature matrix of shape (n, d). | required |
| labels | Any | Cluster assignment integer labels, shape (n,). | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Mean silhouette in [-1, 1]. Higher = better separated clusters. |

Examples:

>>> import jax.numpy as jnp
>>> features = jnp.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]])
>>> labels = jnp.array([0, 0, 1, 1])
>>> silhouette_score(features, labels)  # Close to 1.0
...
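
The per-sample definition above can be spelled out directly. This plain-Python sketch assumes Euclidean distance and assigns a score of 0 to samples in singleton clusters. Both are assumptions, since the documentation above does not state which distance or convention calibrax uses. `silhouette_reference` is a hypothetical name, and the input must contain at least two clusters.

```python
from math import dist


def silhouette_reference(features, labels):
    """Mean silhouette over all samples (hypothetical helper, O(n^2))."""
    n = len(features)
    clusters = set(labels)
    scores = []
    for i in range(n):
        # Mean distance from sample i to the other members of each cluster.
        mean_dist = {}
        for k in clusters:
            members = [j for j in range(n) if labels[j] == k and j != i]
            if members:
                total = sum(dist(features[i], features[j]) for j in members)
                mean_dist[k] = total / len(members)
        own = labels[i]
        if own not in mean_dist:
            scores.append(0.0)  # assumed convention for singleton clusters
            continue
        a = mean_dist[own]
        b = min(d for k, d in mean_dist.items() if k != own)
        denom = max(a, b)
        scores.append((b - a) / denom if denom > 0.0 else 0.0)
    return sum(scores) / n
```

On the four-point example above this sketch returns a value just above 0.99, and swapping to the labeling `[0, 1, 0, 1]` makes the score negative because each sample is then closer to the other cluster than to its own.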

calinski_harabasz_score(features: Any, labels: Any) -> Any

Calinski-Harabasz Index (Variance Ratio Criterion).

Ratio of between-cluster to within-cluster dispersion, adjusted for cluster and sample counts. Higher = better-separated clusters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| features | Any | Feature matrix of shape (n, d). | required |
| labels | Any | Cluster assignment integer labels, shape (n,). | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Calinski-Harabasz score (>= 0). Higher is better. |

Examples:

>>> import jax.numpy as jnp
>>> features = jnp.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]])
>>> labels = jnp.array([0, 0, 1, 1])
>>> calinski_harabasz_score(features, labels)  # Large value
...
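
The dispersion ratio can be written as CH = (B / (k - 1)) / (W / (n - k)), where B is the between-cluster dispersion, W the within-cluster dispersion, k the number of clusters and n the number of samples. The following plain-Python sketch assumes that standard definition with squared Euclidean distances. `calinski_harabasz_reference` is a hypothetical name, and the sketch requires 2 <= k < n and W > 0.

```python
def calinski_harabasz_reference(features, labels):
    """Variance ratio criterion (hypothetical helper, not calibrax's code)."""
    n, d = len(features), len(features[0])

    def centroid(rows):
        return [sum(r[t] for r in rows) / len(rows) for t in range(d)]

    def sqdist(u, v):
        return sum((x - y) ** 2 for x, y in zip(u, v))

    # Group the feature rows by cluster label.
    groups = {}
    for row, lab in zip(features, labels):
        groups.setdefault(lab, []).append(row)
    k = len(groups)

    overall = centroid(features)
    between = 0.0  # B: size-weighted squared distances of centroids to the mean
    within = 0.0   # W: squared distances of samples to their own centroid
    for rows in groups.values():
        c = centroid(rows)
        between += len(rows) * sqdist(c, overall)
        within += sum(sqdist(r, c) for r in rows)

    return (between / (k - 1)) / (within / (n - k))
```

For the four-point example above, B = 200 and W = 0.01 with k = 2 and n = 4, so this sketch returns approximately 40000, which matches the "large value" comment in the doctest.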

davies_bouldin_score(features: Any, labels: Any) -> Any

Davies-Bouldin Index for cluster separation.

For each cluster, takes the worst-case (largest) similarity ratio with any other cluster, then averages these ratios over all clusters. Lower = better separated clusters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| features | Any | Feature matrix of shape (n, d). | required |
| labels | Any | Cluster assignment integer labels, shape (n,). | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Davies-Bouldin score (>= 0). Lower is better. |

Examples:

>>> import jax.numpy as jnp
>>> features = jnp.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]])
>>> labels = jnp.array([0, 0, 1, 1])
>>> davies_bouldin_score(features, labels)  # Close to 0
...
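
The similarity ratio between clusters i and j is conventionally R_ij = (s_i + s_j) / d_ij, where s_i is the mean distance from the members of cluster i to its centroid and d_ij is the distance between the two centroids. The sketch below assumes that definition with Euclidean distance, which may not match calibrax's choices exactly. `davies_bouldin_reference` is a hypothetical name, and the sketch needs at least two clusters with distinct centroids.

```python
from math import dist


def davies_bouldin_reference(features, labels):
    """Davies-Bouldin index (hypothetical helper, not calibrax's code)."""
    d = len(features[0])

    # Group the feature rows by cluster label.
    groups = {}
    for row, lab in zip(features, labels):
        groups.setdefault(lab, []).append(row)

    centroids, scatters = [], []
    for rows in groups.values():
        c = [sum(r[t] for r in rows) / len(rows) for t in range(d)]
        centroids.append(c)
        # s_i: mean distance from the cluster's members to its centroid.
        scatters.append(sum(dist(r, c) for r in rows) / len(rows))

    k = len(centroids)
    # For each cluster, keep the largest ratio against any other cluster.
    worst = [
        max(
            (scatters[i] + scatters[j]) / dist(centroids[i], centroids[j])
            for j in range(k)
            if j != i
        )
        for i in range(k)
    ]
    return sum(worst) / k
```

For the four-point example above, both clusters have scatter 0.05 and the centroids are sqrt(200) apart, so this sketch returns roughly 0.007, in line with the "close to 0" comment in the doctest.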