Source code for neuralogic.nn.trainer.callbacks

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from neuralogic.nn.trainer.trainer import Trainer


[docs] class TrainerCallback: """Base class for training callbacks. Override any of the hooks. The trainer calls them in the order they were passed to :meth:`Trainer.fit`. """
[docs] def on_train_begin(self, trainer: Trainer) -> None: """Called once before the first epoch."""
[docs] def on_epoch_end(self, trainer: Trainer, epoch: int, logs: dict[str, Any]) -> None: """Called after every epoch. Parameters ---------- trainer : Trainer The trainer instance (access ``trainer.model``, etc.). epoch : int 0-indexed epoch number that just finished. logs : dict Dictionary with keys ``"train_loss"``, ``"val_loss"`` (if available), ``"lr"``, and per-metric keys like ``"train_accuracy"``, ``"val_mae"``, etc. """
[docs] def on_train_end(self, trainer: Trainer) -> None: """Called once after training finishes (or early-stops)."""
[docs] class EarlyStoppingCallback(TrainerCallback): """Stop training when validation loss stops improving. Parameters ---------- patience : int Number of epochs with no improvement after which training stops. min_delta : float Minimum absolute change to qualify as an improvement. """ def __init__(self, patience: int, min_delta: float = 0.0) -> None: self.patience = patience self.min_delta = min_delta self._counter = 0 self._best: float = float("inf")
[docs] def on_epoch_end(self, trainer: Trainer, epoch: int, logs: dict[str, Any]) -> None: val_loss = logs.get("val_loss") if val_loss is None: return if val_loss < self._best - self.min_delta: self._best = val_loss self._counter = 0 else: self._counter += 1 if self._counter >= self.patience: trainer.stop_training = True
[docs] def on_train_begin(self, trainer: Trainer) -> None: self._counter = 0 self._best = float("inf")
[docs] class CheckpointCallback(TrainerCallback): """Save the model whenever validation loss improves. Parameters ---------- directory : str or Path Directory to save checkpoints into (created if missing). filename : str Filename for the checkpoint file (default ``"best.pkl"``). """ def __init__(self, directory: str | Path, filename: str = "best.pkl") -> None: self.directory = Path(directory) self.filename = filename self._best: float = float("inf") @property def path(self) -> Path: return self.directory / self.filename
[docs] def on_train_begin(self, trainer: Trainer) -> None: self.directory.mkdir(parents=True, exist_ok=True) self._best = float("inf")
[docs] def on_epoch_end(self, trainer: Trainer, epoch: int, logs: dict[str, Any]) -> None: val_loss = logs.get("val_loss") if val_loss is None: return # can't decide without validation loss if val_loss < self._best: self._best = val_loss trainer.model.save(str(self.path))
[docs] class ProgressCallback(TrainerCallback): """Show a tqdm progress bar during training. Parameters ---------- epochs : int Total number of epochs (used for progress bar length). """ def __init__(self, epochs: int) -> None: self._epochs = epochs self._pbar: Any = None
[docs] def on_train_begin(self, trainer: Trainer) -> None: try: from tqdm import tqdm self._pbar = tqdm(total=self._epochs, desc="Training", unit="epoch") except ImportError: self._pbar = None
[docs] def on_epoch_end(self, trainer: Trainer, epoch: int, logs: dict[str, Any]) -> None: if self._pbar is None: return postfix: dict[str, Any] = {} for key in ("train_loss", "val_loss", "lr"): if key in logs: postfix[key] = f"{logs[key]:.4f}" for key, value in logs.items(): if key.startswith("val_") and key not in ("val_loss", "lr"): postfix[key] = f"{value:.4f}" self._pbar.set_postfix(postfix) self._pbar.update(1)
[docs] def on_train_end(self, trainer: Trainer) -> None: if self._pbar is not None: self._pbar.close()