# Source code for neuralogic.nn.trainer.metrics

from __future__ import annotations

import math
import warnings
from collections.abc import Sequence
from enum import Enum
from typing import TYPE_CHECKING, Callable, Union

if TYPE_CHECKING:
    import numpy as np

# Maps a metric name (e.g. "mae") to the function that computes it from
# ``(targets, outputs)``. Entries are added by the ``_register`` decorator as
# each metric function is defined.
_METRIC_REGISTRY: dict[str, Callable[[list, list], float]] = {}


def _register(name: str):
    """Build a decorator that files a batch-level metric under ``name``.

    Parameters
    ----------
    name : str
        Key under which the decorated function is stored in
        ``_METRIC_REGISTRY``.

    Returns
    -------
    Callable
        A decorator that records the function and returns it unchanged.
    """

    def _store(metric_fn: Callable[[list, list], float]) -> Callable[[list, list], float]:
        # Registration is purely a side effect; the function is handed back
        # untouched so the decorated name still refers to the original.
        _METRIC_REGISTRY[name] = metric_fn
        return metric_fn

    return _store


class Metric(str, Enum):
    """Enum of available metric names.

    Inherits ``str`` so members can be used directly where a string is
    expected::

        >>> Metric.ACCURACY == "accuracy"
        True
        >>> str(Metric.ACCURACY)
        'accuracy'
        >>> trainer.fit(..., metrics=[Metric.ACCURACY, Metric.F1_MACRO])
    """

    # Regression
    MAE = "mae"
    MSE = "mse"
    RMSE = "rmse"
    R2 = "r2"

    # Classification
    ACCURACY = "accuracy"
    PRECISION_MACRO = "precision_macro"
    RECALL_MACRO = "recall_macro"
    F1_MACRO = "f1_macro"

    def __str__(self) -> str:
        """Return the plain metric name (the member's value).

        ``Enum.__str__`` would otherwise produce ``"Metric.ACCURACY"``, which
        does not match any key in the metric registry when a member is
        converted with ``str()`` before the lookup.
        """
        return str(self.value)
def _to_arrays(targets: list, outputs: list):
    """Convert parallel target/output collections to float numpy arrays.

    Parameters
    ----------
    targets : list
        Target values.
    outputs : list
        Output values.

    Returns
    -------
    tuple
        ``(targets_array, outputs_array)``, each built with ``dtype=float``.
    """
    import numpy as np

    return np.asarray(targets, dtype=float), np.asarray(outputs, dtype=float)


def _class_indices(arr) -> np.ndarray:
    """Convert a batch of raw values to integer class indices.

    The leading axis is treated as the batch axis, matching the batch-level
    contract of the registered metric functions:

    * 0-D (a bare scalar) or 1-D (a batch of per-sample scalars): every
      element is thresholded at 0.5, giving class 0 or 1 per sample.
    * Trailing axis of length 1 (single-output vectors such as ``[[0.7]]``):
      thresholded at 0.5 as well, because an argmax over one element would
      always be 0.
    * Anything else: argmax over the last axis.

    NOTE(review): targets supplied as integer class indices (e.g. ``[0, 2, 1]``)
    alongside vector outputs are not supported by this helper -- confirm
    whether callers need that.

    Parameters
    ----------
    arr : array-like
        Raw target or output values.

    Returns
    -------
    np.ndarray
        Integer class indices with at least one dimension.
    """
    import numpy as np

    arr = np.asarray(arr, dtype=float)
    if arr.ndim <= 1:
        # One scalar per sample: decide each sample on its own instead of
        # collapsing the whole batch into a single argmax.
        return (np.atleast_1d(arr) >= 0.5).astype(int)
    if arr.shape[-1] == 1:
        return (arr[..., 0] >= 0.5).astype(int)
    return np.argmax(arr, axis=-1)


def _macro_score(t_cls: np.ndarray, o_cls: np.ndarray, mode: str) -> float:
    """Compute macro-averaged precision, recall, or F1.

    Parameters
    ----------
    t_cls : np.ndarray
        Target class indices.
    o_cls : np.ndarray
        Predicted class indices.
    mode : str
        One of ``"precision"``, ``"recall"`` or ``"f1"``.

    Returns
    -------
    float
        Unweighted mean of the per-class score over every class present in
        either array; ``0.0`` when there are no classes. A class with a zero
        denominator contributes ``0.0``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported names.
    """
    import numpy as np

    if mode not in ("precision", "recall", "f1"):
        # An unknown mode previously produced a silent 0.0.
        raise ValueError(f"Unknown macro score mode: {mode!r}")

    scores: list[float] = []
    for c in np.unique(np.concatenate([t_cls, o_cls])):
        is_true = t_cls == c
        is_pred = o_cls == c
        tp = int(np.sum(is_true & is_pred))
        fp = int(np.sum(~is_true & is_pred))
        fn = int(np.sum(is_true & ~is_pred))

        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if mode == "precision":
            scores.append(prec)
        elif mode == "recall":
            scores.append(rec)
        else:
            scores.append(2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0)

    return float(np.mean(scores)) if scores else 0.0


@_register("mae")
def _mae(targets, outputs) -> float:
    """Mean absolute error."""
    import numpy as np

    t, o = _to_arrays(targets, outputs)
    return float(np.mean(np.abs(t - o)))


@_register("mse")
def _mse(targets, outputs) -> float:
    """Mean squared error."""
    import numpy as np

    t, o = _to_arrays(targets, outputs)
    return float(np.mean((t - o) ** 2))


@_register("rmse")
def _rmse(targets, outputs) -> float:
    """Root mean squared error."""
    import numpy as np

    t, o = _to_arrays(targets, outputs)
    return float(math.sqrt(np.mean((t - o) ** 2)))


@_register("r2")
def _r2(targets, outputs) -> float:
    """R^2 coefficient of determination.

    When the targets have zero variance the usual ratio is undefined; in that
    case ``1.0`` is returned for a perfect fit and ``0.0`` otherwise.
    """
    import numpy as np

    t, o = _to_arrays(targets, outputs)
    ss_res = np.sum((t - o) ** 2)
    ss_tot = np.sum((t - np.mean(t)) ** 2)
    if ss_tot == 0:
        return 1.0 if ss_res == 0 else 0.0
    return float(1.0 - ss_res / ss_tot)


@_register("accuracy")
def _accuracy(targets, outputs) -> float:
    """Fraction of samples where predicted class equals target class.

    Per-sample scalars are thresholded at 0.5; per-sample vectors use argmax
    over their last axis. An empty batch scores ``0.0``.
    """
    import numpy as np

    t_arr, o_arr = _to_arrays(targets, outputs)
    t_cls = _class_indices(t_arr)
    o_cls = _class_indices(o_arr)
    if t_cls.size == 0:
        # Avoid the NaN (and RuntimeWarning) of averaging an empty array.
        return 0.0
    return float(np.mean(t_cls == o_cls))


@_register("precision_macro")
def _precision_macro(targets, outputs) -> float:
    """Macro-averaged precision (unweighted mean of per-class precision)."""
    t_arr, o_arr = _to_arrays(targets, outputs)
    return _macro_score(_class_indices(t_arr), _class_indices(o_arr), "precision")


@_register("recall_macro")
def _recall_macro(targets, outputs) -> float:
    """Macro-averaged recall (unweighted mean of per-class recall)."""
    t_arr, o_arr = _to_arrays(targets, outputs)
    return _macro_score(_class_indices(t_arr), _class_indices(o_arr), "recall")


@_register("f1_macro")
def _f1_macro(targets, outputs) -> float:
    """Macro-averaged F1 score (unweighted mean of per-class F1)."""
    t_arr, o_arr = _to_arrays(targets, outputs)
    return _macro_score(_class_indices(t_arr), _class_indices(o_arr), "f1")
def compute_metrics(
    targets: list,
    outputs: list,
    names: Sequence[Union[str, Metric]],
) -> dict[str, float]:
    """Compute named metrics over a batch of (target, output) pairs.

    Each metric receives the full batch and returns a single float.

    Parameters
    ----------
    targets : list
        Per-sample target values (floats, lists, or 2D lists).
    outputs : list
        Per-sample output values (same shapes as targets).
    names : Sequence[str or Metric]
        Metric names to compute, e.g. ``["accuracy"]`` or
        ``[Metric.MAE, Metric.R2]``.

    Returns
    -------
    dict[str, float]
        Mapping from metric name to its value across the batch. Unknown
        names are skipped with a warning and do not appear in the result.
    """
    result: dict[str, float] = {}
    for name in names:
        # str() on an Enum member is not guaranteed to be its value (the
        # default is "ClassName.MEMBER"), so unwrap members explicitly before
        # looking them up in the registry.
        key = str(name.value) if isinstance(name, Enum) else str(name)
        fn = _METRIC_REGISTRY.get(key)
        if fn is None:
            warnings.warn(
                f"Unknown metric '{key}'. Available: {sorted(_METRIC_REGISTRY)}",
                stacklevel=2,
            )
            continue
        result[key] = fn(targets, outputs)
    return result
def _validate_metrics(metrics: Sequence[Union[str, Metric]]) -> None:
    """Warn about metric names that are not in the registry.

    Parameters
    ----------
    metrics : Sequence[str or Metric]
        Requested metric names. Enum members are compared by their value,
        plain strings as-is.
    """
    # str() on an Enum member is not guaranteed to be its value (the default
    # is "ClassName.MEMBER"), so unwrap members before the registry lookup;
    # otherwise valid enum members would be reported as unknown.
    unknown = [
        m
        for m in metrics
        if (str(m.value) if isinstance(m, Enum) else str(m)) not in _METRIC_REGISTRY
    ]
    if unknown:
        warnings.warn(
            f"Unknown metric(s): {unknown}. Available: {sorted(_METRIC_REGISTRY)}",
            stacklevel=2,
        )