# Source code for fenn.nn.trainers.trainer

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional, Union

import torch
import torch.nn
import torch.optim
from sklearn.metrics import (  # noqa: F401
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import DataLoader

from fenn.export import Exporter
from fenn.logging import Logger
from fenn.nn.utils import Checkpoint, ModelPrettyPrinter, TrainingState


class Trainer(ABC):
    """The base Trainer abstract class for classification and regression tasks.

    Holds the state shared by all trainers (model, loss, optimizer, device,
    checkpoint configuration, early-stopping patience) together with helpers
    for device placement, checkpoint scheduling and checkpoint restoration.

    Subclasses must implement :meth:`fit` to define the training loop and
    :meth:`predict` to generate predictions from a model.

    Subclasses:
        - :class:`ClassificationTrainer` for classification tasks.
        - :class:`RegressionTrainer` for regression tasks.
        - :class:`LoRATrainer` for parameter-efficient fine-tuning.
    """

    @abstractmethod
    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: torch.nn.Module,
        optim: torch.optim.Optimizer,
        device: Union[torch.device, str] = "cpu",
        early_stopping_patience: Optional[int] = None,
        checkpoint_config: Optional[Checkpoint] = None,
    ) -> None:
        """Initialize a Trainer instance to fit a neural network model.

        Note:
            This base constructor has no ``num_classes`` parameter; it sets
            ``self._num_classes`` to ``2``.

        Args:
            model: The neural network model to train. It is moved to
                ``device``.
            loss_fn: The loss function to use.
            optim: The optimizer to use.
            device: The device on which the model (and data) will be placed.
            early_stopping_patience: The number of epochs to wait before
                early stopping. If ``None``, no early-stopping message is
                logged and the attribute stays ``None``.
            checkpoint_config: The checkpoint configuration. If ``None``,
                checkpointing is disabled.
        """
        self._logger = Logger()
        self._exporter = Exporter()

        self._loss_fn = loss_fn
        self._num_classes = 2
        self._device = torch.device(device)

        # Training state at epoch 0.
        self._model = model.to(self._device)
        self._optimizer = optim
        self._state = TrainingState(epoch=0)
        self._log_model_summary()

        # Best training state based on validation loss.
        self._best_state: Optional[TrainingState] = None
        self._best_model: Optional[torch.nn.Module] = None

        # Checkpoint setup.
        self._checkpoint = checkpoint_config
        if self._checkpoint is not None:
            self._checkpoint._setup()

        # Early stopping setup.
        self._early_stopping_patience = early_stopping_patience
        if self._early_stopping_patience is not None:
            self._logger.display_info(
                f"Early stopping enabled with patience of {self._early_stopping_patience} epochs."
            )

    def _log_model_summary(self) -> None:
        """Render a summary of the model and send it to the logger.

        The summary is logged with ``display_on_terminal=False``.
        """
        summary = ModelPrettyPrinter(self._model).render()
        self._logger.display_info(summary, display_on_terminal=False)

    def _move_to_device(self, batch: Any, device: Union[torch.device, str]) -> Any:
        """Recursively move tensor data to the specified device.

        Handles tensors, lists, tuples (including namedtuples), and
        dictionaries of tensors. Non-tensor values are passed through
        unchanged.

        Args:
            batch: Input data — a tensor, list, tuple, or dict of tensors.
            device: Target device (``torch.device`` or string like
                ``'cuda'``).

        Returns:
            The input data with all tensors moved to the target device.
            Lists and tuples keep their concrete type; dictionaries are
            returned as plain ``dict`` objects.
        """
        if torch.is_tensor(batch):
            return batch.to(device)
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so the moved items must be unpacked.
            return type(batch)(*(self._move_to_device(x, device) for x in batch))
        if isinstance(batch, (list, tuple)):
            return type(batch)(self._move_to_device(x, device) for x in batch)
        if isinstance(batch, dict):
            return {k: self._move_to_device(v, device) for k, v in batch.items()}
        return batch

    def _should_save_checkpoint(self, epoch: int, is_last_epoch: bool = False) -> bool:
        """Check whether a checkpoint should be saved at the given epoch.

        Evaluates the checkpoint configuration to determine if the current
        epoch qualifies for a checkpoint save based on fixed intervals,
        explicit epoch lists, or being the final training epoch.

        Args:
            epoch: The current epoch number (1-indexed).
            is_last_epoch: Whether this is the final training epoch.

        Returns:
            ``True`` if a checkpoint should be saved, ``False`` otherwise.
            Always ``False`` when checkpointing is disabled.
        """
        if self._checkpoint is None:
            return False
        if is_last_epoch:
            # Always save at the last epoch.
            return True

        schedule = self._checkpoint.epochs
        if isinstance(schedule, list):
            # Save at specific epochs.
            return epoch in schedule
        if isinstance(schedule, int):
            # Save every N epochs; an interval of 0 can never match and
            # would otherwise raise ZeroDivisionError.
            return schedule != 0 and epoch % schedule == 0
        return False

    @abstractmethod
    def fit(
        self,
        train_loader: DataLoader,
        epochs: int,
        val_loader: Optional[DataLoader] = None,
        val_epochs: int = 1,
    ):
        """Train the model for a fixed number of epochs.

        Runs the full training loop including forward/backward passes,
        validation evaluation, checkpointing, and early stopping. The exact
        behavior depends on the validation and early stopping configuration:

        * No validation loader, no early stopping: run full ``epochs``.
        * No validation loader, early stopping set: stop on training loss.
        * Validation loader provided, no early stopping: evaluate every
          epoch but continue regardless of metrics.
        * Validation loader and early stopping set: monitor validation loss
          and stop when it plateaus for ``early_stopping_patience`` epochs.

        Args:
            train_loader: PyTorch ``DataLoader`` yielding ``(data, labels)``
                batches for training.
            epochs: Total number of training epochs. If resuming from a
                checkpoint, only the remaining epochs are run.
            val_loader: Optional ``DataLoader`` for validation evaluation.
            val_epochs: How frequently to evaluate on the validation set
                (e.g. ``val_epochs=2`` means every 2 epochs).

        Returns:
            The trained model.
        """

    def _replace_state(self, new_state: TrainingState) -> None:
        """Replace the current training state with a previously saved state.

        Stores ``new_state`` as the current state and loads its model and
        optimizer state dicts into the trainer's model and optimizer.

        Args:
            new_state: A :class:`~fenn.nn.utils.TrainingState` containing the
                model and optimizer state dicts to restore.
        """
        self._state = new_state
        # A missing (None) or empty state dict is skipped, leaving the
        # current weights / optimizer state untouched.
        if new_state.model_state_dict:
            self._model.load_state_dict(new_state.model_state_dict)
        if new_state.optimizer_state_dict:
            self._optimizer.load_state_dict(new_state.optimizer_state_dict)

    def _require_checkpoint(self) -> Checkpoint:
        """Return the checkpoint configuration, or raise if it is missing.

        Raises:
            ValueError: If no checkpoint configuration was provided at init.
        """
        if self._checkpoint is None:
            raise ValueError("Cannot load checkpoint: checkpoint_config is missing.")
        return self._checkpoint

    def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
        """Load a checkpoint from the given file path and restore training state.

        Restores the training state returned by the checkpoint loader,
        including model weights and optimizer state when present.

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.

        Raises:
            ValueError: If no checkpoint configuration was provided at init.
            FileNotFoundError: If the checkpoint file does not exist.
        """
        new_state = self._require_checkpoint().load(checkpoint_path)
        self._replace_state(new_state)

    def load_checkpoint_at_epoch(self, epoch: int) -> None:
        """Load the checkpoint saved at a specific epoch.

        Asks the checkpoint configuration for the state saved at the
        requested epoch and restores the model and optimizer from it.

        Args:
            epoch: The epoch whose checkpoint to load (1-indexed).

        Raises:
            ValueError: If no checkpoint configuration was provided at init.
            FileNotFoundError: If no checkpoint exists for the given epoch.
        """
        new_state = self._require_checkpoint().load_at_epoch(epoch)
        self._replace_state(new_state)

    def load_best_checkpoint(self) -> None:
        """Load the best-performing checkpoint into the trainer's model.

        Restores the state the checkpoint configuration reports as best,
        i.e. the one with the lowest validation (or training) loss recorded
        during training.

        Raises:
            ValueError: If no checkpoint configuration was provided at init.
            FileNotFoundError: If no best checkpoint file exists.
        """
        new_state = self._require_checkpoint().load_best()
        self._replace_state(new_state)

    def save_model(self, model_name: str = "model.pth") -> None:
        """Save the model's ``state_dict`` into the exporter's directory.

        Args:
            model_name: File name to write inside the export directory.
        """
        # export_dir is joined with "/", so it is presumably a pathlib.Path
        # -- confirm against Exporter.
        target = self._exporter.export_dir / model_name
        torch.save(self._model.state_dict(), target)

    @abstractmethod
    def predict(self, dataloader_or_batch: Union[DataLoader, torch.Tensor]):
        """Generate predictions from the trained model.

        Runs inference on the provided data without computing gradients,
        returning model predictions in the same format as the training
        labels.

        Args:
            dataloader_or_batch: Either a PyTorch ``DataLoader`` yielding
                data batches, or a single tensor batch.

        Returns:
            A list of predictions (one per sample).
        """