from __future__ import annotations
from collections.abc import Sequence
from pathlib import Path
from typing import Union
from neuralogic.core.builder.dataset import BuiltDataset, GroundedDataset
from neuralogic.core.neural_module import NeuralModule
from neuralogic.dataset import Dataset
from neuralogic.nn.trainer.callbacks import (
CheckpointCallback,
EarlyStoppingCallback,
ProgressCallback,
TrainerCallback,
)
from neuralogic.nn.trainer.helpers import _build_logs, _ensure_built, _mean, _unpack_results
from neuralogic.nn.trainer.history import TrainerHistory
from neuralogic.nn.trainer.metrics import Metric, _validate_metrics, compute_metrics
[docs]
class Trainer:
def __init__(self, module: NeuralModule) -> None:
if module._neural_model is None:
raise ValueError("The model must be built before creating a Trainer. Call model.build(settings) first.")
self.model = module
self.stop_training: bool = False
[docs]
def fit(
self,
train_dataset: Dataset | GroundedDataset | BuiltDataset,
val_dataset: Dataset | GroundedDataset | BuiltDataset | None = None,
*,
epochs: int = 1,
batch_size: int = 1,
early_stopping_patience: int | None = None,
min_delta: float = 0.0,
checkpoint_dir: str | Path | None = None,
metrics: Sequence[Union[str, Metric]] | None = None,
silent: bool = False,
callbacks: Sequence[TrainerCallback] | None = None,
) -> TrainerHistory:
"""Run the training loop.
Parameters
----------
train_dataset :
Training data. Raw ``Dataset`` objects are built automatically;
pass a ``BuiltDataset`` to skip repeated grounding.
val_dataset :
Optional validation data. When provided, validation loss (and
any requested metrics) are computed after every epoch. Early
stopping and checkpointing depend on validation loss.
epochs : int
Number of epochs to train. Default 1.
batch_size : int
Batch size when building raw datasets. Default 1.
early_stopping_patience : int or None
Stop after this many epochs without validation-loss improvement.
Requires ``val_dataset``. Default ``None`` (no early stopping).
min_delta : float
Minimum absolute change in validation loss to count as
improvement. Default 0.0.
checkpoint_dir : str, Path, or None
Directory to save the best model (by validation loss). A file
named ``best.pkl`` is written on every improvement. Default
``None`` (no checkpointing).
metrics : Sequence[str or Metric] or None
Extra metrics to compute, e.g. ``[Metric.ACCURACY]`` or
``["mae", "r2"]``. Loss is always tracked. Default ``None``.
silent : bool
If ``True``, suppress the tqdm progress bar. Default ``False``.
callbacks : Sequence[TrainerCallback] or None
Additional callbacks to invoke. Built-in callbacks (early
stopping, checkpoint, progress) are appended automatically
based on the other arguments.
Returns
-------
TrainerHistory
Losses and metrics for every epoch.
"""
metric_names = [str(m) for m in metrics] if metrics else []
_validate_metrics(metric_names)
built_train = _ensure_built(self.model, train_dataset, batch_size)
built_val = _ensure_built(self.model, val_dataset, batch_size) if val_dataset is not None else None
optimizer = self.model._settings.optimizer
lr_decay = optimizer._lr_decay if hasattr(optimizer, "_lr_decay") else None
cb_list: list[TrainerCallback] = []
if early_stopping_patience is not None:
cb_list.append(EarlyStoppingCallback(early_stopping_patience, min_delta))
if checkpoint_dir is not None:
cb_list.append(CheckpointCallback(checkpoint_dir))
if not silent:
cb_list.append(ProgressCallback(epochs))
if callbacks:
cb_list.extend(callbacks)
history = TrainerHistory()
self.stop_training = False
for cb in cb_list:
cb.on_train_begin(self)
for epoch in range(epochs):
if self.stop_training:
history.stopped_early = True
break
train_results = self.model.train(built_train, epochs=1)
train_targets, train_outputs, train_errors = _unpack_results(train_results)
train_loss = _mean(train_errors)
history.train_losses.append(train_loss)
val_loss: float | None = None
if built_val is not None:
state = self.model.state_dict()
val_results = self.model.train(built_val, epochs=1)
self.model.load_state_dict(state)
val_targets, val_outputs, val_errors = _unpack_results(val_results)
val_loss = _mean(val_errors)
history.val_losses.append(val_loss)
if val_loss < history.best_val_loss:
history.best_val_loss = val_loss
history.best_epoch = epoch
current_lr = optimizer.lr
history.learning_rates.append(current_lr)
if lr_decay is not None:
lr_decay.decay(epoch)
logs = _build_logs(train_loss, val_loss, current_lr)
if metric_names and train_outputs:
train_mets = compute_metrics(train_targets, train_outputs, metric_names)
for name, value in train_mets.items():
history.train_metrics.setdefault(name, []).append(value)
logs[f"train_{name}"] = value
if metric_names and built_val is not None and val_outputs:
val_mets = compute_metrics(val_targets, val_outputs, metric_names) # type: ignore[possibly-unbound]
for name, value in val_mets.items():
history.val_metrics.setdefault(name, []).append(value)
logs[f"val_{name}"] = value
for cb in cb_list:
cb.on_epoch_end(self, epoch, logs)
for cb in cb_list:
cb.on_train_end(self)
return history
[docs]
def test(
self,
dataset: Dataset | GroundedDataset | BuiltDataset,
*,
batch_size: int = 1,
) -> list:
"""Evaluate the model on a dataset (no weight updates).
Parameters
----------
dataset :
Test data. Raw ``Dataset`` objects are built automatically.
batch_size : int
Batch size when building raw datasets. Default 1.
Returns
-------
list
Model outputs for every sample.
"""
built = _ensure_built(self.model, dataset, batch_size)
result = self.model.test(built)
if isinstance(result, list):
return result
return [result]