
calibrax.metrics.functional.fairness

Algorithmic fairness metrics for auditing predictions across demographic groups. Includes demographic parity ratio, equalized odds difference, equal opportunity difference, and disparate impact ratio (80% rule).

Fairness metrics for algorithmic bias evaluation.

Pure functions for assessing disparities between demographic groups. Each one takes a protected_attribute array in addition to predictions (and targets, for metrics that use them).

Impossibility results (Chouldechova 2017; Kleinberg et al. 2016) show that when base rates differ across groups, an imperfect classifier cannot simultaneously satisfy predictive value parity (calibration) and equal error rates across groups. Demographic parity is likewise generally incompatible with equalized odds in that setting. Because no single criterion can hold in every case, this module provides several fairness criteria so users can examine the inherent trade-offs.
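The tension between demographic parity and equalized odds is visible even with a perfect classifier. The following sketch uses hypothetical toy data and plain NumPy (not calibrax code): group 0 has a base rate of 0.75 and group 1 a base rate of 0.25, so predicting the targets exactly satisfies equalized odds while violating demographic parity.

```python
import numpy as np

# Hypothetical toy data: group 0 base rate 0.75, group 1 base rate 0.25.
targets = np.array([1, 1, 1, 0, 1, 0, 0, 0])
groups = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# A perfect classifier: predictions equal the targets.
preds = targets.copy()

# Positive prediction rate, TPR and FPR per group.
pos_rate = {g: float(np.mean(preds[groups == g])) for g in (0, 1)}
tpr = {g: float(np.mean(preds[(groups == g) & (targets == 1)])) for g in (0, 1)}
fpr = {g: float(np.mean(preds[(groups == g) & (targets == 0)])) for g in (0, 1)}

# Equalized odds holds exactly: TPR is 1.0 and FPR is 0.0 in both groups.
eod = max(abs(tpr[0] - tpr[1]), abs(fpr[0] - fpr[1]))

# Demographic parity fails: positive rates are 0.75 vs 0.25, so the ratio is 1/3.
dpr = min(pos_rate.values()) / max(pos_rate.values())
```

Here eod is 0.0 while dpr is 1/3; forcing equal positive rates would require moving some predictions away from the targets, which breaks equalized odds.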

Registered with domain="fairness", signature=MetricSignature.CUSTOM.

demographic_parity_ratio(predictions: Any, protected_attribute: Any) -> Any

Ratio of positive prediction rates across demographic groups.

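Computes min(rate_a / rate_b, rate_b / rate_a) for every pair of groups and returns the smallest pairwise ratio. When all rates are nonzero, this equals the lowest group positive-prediction rate divided by the highest. Does NOT use targets.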
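The pairwise computation described above can be sketched in NumPy as follows. This is an illustrative re-implementation, not calibrax's code: the name demographic_parity_ratio_sketch is hypothetical, the positive rate is assumed to be the group mean of predictions (so probabilities also work), and the handling of zero rates is an assumption.

```python
from itertools import combinations

import numpy as np


def demographic_parity_ratio_sketch(predictions, protected_attribute):
    """Hypothetical sketch: minimum pairwise ratio of positive prediction rates."""
    predictions = np.asarray(predictions, dtype=float)
    protected_attribute = np.asarray(protected_attribute)
    # Positive prediction rate per group, taken as the mean prediction (assumption).
    rates = [predictions[protected_attribute == g].mean()
             for g in np.unique(protected_attribute)]
    ratios = []
    for ra, rb in combinations(rates, 2):
        if ra == 0.0 and rb == 0.0:
            ratios.append(1.0)  # assumption: two all-negative groups count as parity
        elif ra == 0.0 or rb == 0.0:
            ratios.append(0.0)  # assumption: one all-negative group gives ratio 0
        else:
            ratios.append(min(ra / rb, rb / ra))
    # With a single group there are no pairs; treat that as parity (assumption).
    return float(min(ratios)) if ratios else 1.0
```

For example, predictions [1, 1, 1, 0, 0, 1] with groups [0, 0, 0, 1, 1, 1] give rates 1.0 and 1/3, so the sketch returns 1/3, the same value as min(rate) / max(rate).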

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Any | Binary predictions or probabilities, shape (n,). | required |
| protected_attribute | Any | Group membership labels, shape (n,). | required |

Returns:

| Type | Description |
|---|---|
| Any | DPR in [0, 1]. 1.0 = perfect demographic parity. |

Examples:

>>> import jax.numpy as jnp
>>> preds = jnp.array([1, 1, 0, 0, 1, 1])
>>> groups = jnp.array([0, 0, 0, 1, 1, 1])
>>> demographic_parity_ratio(preds, groups)  # 2/3 / (2/3) = 1.0
...

equalized_odds_difference(predictions: Any, targets: Any, protected_attribute: Any) -> Any

Maximum absolute difference in true positive rate (TPR) or false positive rate (FPR) across groups.

Computed as max(|TPR_a - TPR_b|, |FPR_a - FPR_b|) over all group pairs (a, b).
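A minimal NumPy sketch of this formula, assuming binary 0/1 predictions and targets. The function name is hypothetical and the treatment of a group with no positive (or no negative) targets, which here contributes a rate of 0.0, is an assumption; calibrax may handle that case differently.

```python
from itertools import combinations

import numpy as np


def _rate(values):
    # Mean of a 0/1 array; an empty subset contributes 0.0 (assumption).
    return float(values.mean()) if values.size else 0.0


def equalized_odds_difference_sketch(predictions, targets, protected_attribute):
    """Hypothetical sketch: max over group pairs of the larger TPR or FPR gap."""
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    groups = np.asarray(protected_attribute)
    stats = []
    for g in np.unique(groups):
        in_g = groups == g
        tpr = _rate(predictions[in_g & (targets == 1)])  # P(pred=1 | target=1, group g)
        fpr = _rate(predictions[in_g & (targets == 0)])  # P(pred=1 | target=0, group g)
        stats.append((tpr, fpr))
    gaps = [max(abs(ta - tb), abs(fa - fb))
            for (ta, fa), (tb, fb) in combinations(stats, 2)]
    return max(gaps, default=0.0)
```

With predictions [1, 0, 1, 1, 1, 1], targets [1, 1, 0, 1, 1, 0] and groups [0, 0, 0, 1, 1, 1], group 0 has TPR 0.5 and FPR 1.0 while group 1 has TPR 1.0 and FPR 1.0, so the sketch returns 0.5.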

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Any | Binary predictions, shape (n,). | required |
| targets | Any | Binary ground truth, shape (n,). | required |
| protected_attribute | Any | Group membership labels, shape (n,). | required |

Returns:

| Type | Description |
|---|---|
| Any | EOD in [0, 1]. 0.0 = perfect equalized odds. |

Examples:

>>> import jax.numpy as jnp
>>> preds = jnp.array([1, 1, 0, 1, 1, 0])
>>> targets = jnp.array([1, 1, 0, 1, 1, 0])
>>> groups = jnp.array([0, 0, 0, 1, 1, 1])
>>> equalized_odds_difference(preds, targets, groups)
0.0

equal_opportunity_difference(predictions: Any, targets: Any, protected_attribute: Any) -> Any

Absolute difference in TPR across demographic groups.

A relaxation of equalized odds: it compares only the TPR (recall on examples whose target is positive) and ignores the FPR.
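Because only the TPR is compared, a classifier can achieve perfect equal opportunity while still showing a large FPR gap. The NumPy sketch below illustrates this on hypothetical data; equal_opportunity_difference_sketch is an illustrative name, and skipping groups without positive targets is an assumption rather than documented calibrax behavior.

```python
import numpy as np


def equal_opportunity_difference_sketch(predictions, targets, protected_attribute):
    """Hypothetical sketch: largest absolute TPR gap between any two groups."""
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    groups = np.asarray(protected_attribute)
    tprs = []
    for g in np.unique(groups):
        pos = (groups == g) & (targets == 1)
        if pos.any():  # assumption: groups with no positive targets are skipped
            tprs.append(float(predictions[pos].mean()))
    # max - min equals the largest pairwise absolute difference.
    return max(tprs) - min(tprs) if tprs else 0.0


preds = np.array([1, 1, 0, 1, 1, 1])
targets = np.array([1, 1, 0, 1, 1, 0])
groups = np.array([0, 0, 0, 1, 1, 1])

# TPR is 1.0 in both groups, so equal opportunity holds exactly.
eopd = equal_opportunity_difference_sketch(preds, targets, groups)

# FPR is 0.0 in group 0 but 1.0 in group 1, which equalized odds would flag.
fpr_gap = abs(preds[(groups == 0) & (targets == 0)].mean()
              - preds[(groups == 1) & (targets == 0)].mean())
```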

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Any | Binary predictions, shape (n,). | required |
| targets | Any | Binary ground truth, shape (n,). | required |
| protected_attribute | Any | Group membership labels, shape (n,). | required |

Returns:

| Type | Description |
|---|---|
| Any | Equal opportunity difference in [0, 1]. 0.0 = perfect equal opportunity. |

Examples:

>>> import jax.numpy as jnp
>>> preds = jnp.array([1, 1, 0, 1, 1, 0])
>>> targets = jnp.array([1, 1, 0, 1, 1, 0])
>>> groups = jnp.array([0, 0, 0, 1, 1, 1])
>>> equal_opportunity_difference(preds, targets, groups)
0.0

disparate_impact_ratio(predictions: Any, protected_attribute: Any) -> Any

Disparate impact ratio (same as demographic parity ratio).

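Named after the US legal concept of disparate impact. Under the "four-fifths" (80%) rule of the EEOC Uniform Guidelines on Employee Selection Procedures, a group selection rate below 80% of the highest group's rate is generally regarded as evidence of adverse impact, so values < 0.8 typically indicate disparate impact.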
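The four-fifths check compares the lowest group selection rate with the highest. A minimal NumPy sketch on hypothetical hiring data, assuming binary 0/1 predictions; passes_four_fifths_rule is a hypothetical helper, not part of calibrax.

```python
import numpy as np


def passes_four_fifths_rule(predictions, protected_attribute, threshold=0.8):
    """Hypothetical helper: (min/max selection-rate ratio, whether it meets threshold)."""
    predictions = np.asarray(predictions, dtype=float)
    groups = np.asarray(protected_attribute)
    rates = np.array([predictions[groups == g].mean() for g in np.unique(groups)])
    # If no group is ever selected, treat the rates as equal (assumption).
    ratio = rates.min() / rates.max() if rates.max() > 0 else 1.0
    return float(ratio), bool(ratio >= threshold)


# 10 applicants per group: group 0 selects 5 (rate 0.5), group 1 selects 3 (rate 0.3).
preds = np.array([1] * 5 + [0] * 5 + [1] * 3 + [0] * 7)
groups = np.array([0] * 10 + [1] * 10)
ratio, ok = passes_four_fifths_rule(preds, groups)
# ratio is 0.3 / 0.5 = 0.6, below 0.8, so ok is False.
```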

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Any | Binary predictions or probabilities, shape (n,). | required |
| protected_attribute | Any | Group membership labels, shape (n,). | required |

Returns:

| Type | Description |
|---|---|
| Any | DIR in [0, 1]. Values >= 0.8 generally pass the 80% rule. |

Examples:

>>> import jax.numpy as jnp
>>> preds = jnp.array([1, 1, 0, 0, 1, 1])
>>> groups = jnp.array([0, 0, 0, 1, 1, 1])
>>> disparate_impact_ratio(preds, groups)
...

group_metric_breakdown(metric_fn: Callable[..., float], predictions: Any, targets: Any, protected_attribute: Any) -> dict[str, float]

Apply any metric function separately to each demographic group.

Turns any (predictions, targets) -> float metric into a per-group breakdown. Groups with fewer than 2 samples are skipped.
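The per-group breakdown can be sketched as a loop over unique group labels that slices the inputs and skips small groups. This NumPy version is illustrative only: group_metric_breakdown_sketch is a hypothetical name, and mse is defined locally as a stand-in rather than imported from calibrax.

```python
import numpy as np


def group_metric_breakdown_sketch(metric_fn, predictions, targets, protected_attribute):
    """Hypothetical sketch: apply metric_fn to each group's slice of the data."""
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    groups = np.asarray(protected_attribute)
    results = {}
    for g in np.unique(groups):
        mask = groups == g
        if mask.sum() < 2:  # mirrors the documented "fewer than 2 samples" rule
            continue
        results[str(g)] = float(metric_fn(predictions[mask], targets[mask]))
    return results


def mse(p, t):
    # Local stand-in for a regression metric with signature (predictions, targets).
    return np.mean((p - t) ** 2)


preds = np.array([1.0, 2.0, 3.0, 5.0, 9.0])
targets = np.array([1.0, 2.0, 3.0, 4.0, 0.0])
groups = np.array([0, 0, 1, 1, 2])
result = group_metric_breakdown_sketch(mse, preds, targets, groups)
# Group 0 is exact (0.0), group 1 has MSE (0 + 1) / 2 = 0.5, group 2 has one sample and is skipped.
```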

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| metric_fn | Callable[..., float] | Callable with signature (predictions, targets) -> float. | required |
| predictions | Any | Predicted values, shape (n,). | required |
| targets | Any | Ground truth values, shape (n,). | required |
| protected_attribute | Any | Group membership labels, shape (n,). | required |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dictionary mapping group names (as strings) to metric values. |

Examples:

>>> import jax.numpy as jnp
>>> from calibrax.metrics.functional.regression import mse
>>> preds = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> targets = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> groups = jnp.array([0, 0, 1, 1])
>>> group_metric_breakdown(mse, preds, targets, groups)
{'0': 0.0, '1': 0.0}