Source code for fenn.nn.utils.checkpoint

from pathlib import Path
from typing import List, Optional, Union

import torch

from fenn.logging import Logger
from fenn.nn.utils.state import TrainingState


class Checkpoint:
    """Checkpoint training state at given epochs and/or always the best model.

    Saves full :class:`TrainingState` snapshots (model weights, optimizer
    state, epoch counter, metrics) during training so that training can be
    resumed or the best model restored later.

    Files are written to ``dir`` as ``<name>_epoch_<epoch>.pt`` for periodic
    checkpoints and ``<name>_best.pt`` for the best model.

    Args:
        name: Base filename for checkpoint files (without extension).
        dir: Directory to save checkpoint files in.
        epochs: When to save checkpoints — an ``int`` saves every N epochs,
            a ``list[int]`` saves at specific epochs, or ``None`` to save
            only the best model.
        save_best: If ``True``, save the best model seen so far, updated
            whenever validation/training loss improves.

    Example:
        >>> checkpoint = Checkpoint(dir="checkpoints/", epochs=5, save_best=True)
        >>> trainer = Trainer(model, loss_fn, optimizer, checkpoint_config=checkpoint)
    """

    def __init__(
        self,
        *,
        name: str = "checkpoint",
        dir: Union[Path, str],
        epochs: Optional[Union[int, List[int]]] = None,
        save_best: bool = True,
    ):
        """Initialize the checkpoint configuration.

        Args:
            name: The name of the checkpoint file.
            dir: The directory to save checkpoints to.
            epochs: The epochs at which to save checkpoints.
            save_best: Whether to checkpoint the best model (based on
                validation or training loss).
        """
        self._logger = Logger()
        self.name = name
        self.dir = Path(dir)
        self.epochs = epochs
        self.save_best = save_best

    def _setup(self) -> Optional["Checkpoint"]:
        """Create the checkpoint directory and report what will be saved.

        Returns:
            ``self`` when at least one kind of checkpointing is enabled, or
            ``None`` (after logging a warning) when both ``epochs`` and
            ``save_best`` are unset.
        """
        self.dir.mkdir(parents=True, exist_ok=True)

        if self.epochs is None and not self.save_best:
            self._logger.system_warning(
                "Checkpoint configuration is passed, but both `epochs` and `save_best` are unset.\n"
                "Models will not be checkpointed."
            )
            return None

        if self.epochs is not None:
            # An int is a period ("every N epochs"); anything else is an
            # explicit collection of epochs and must be worded differently.
            if isinstance(self.epochs, int):
                schedule = f"every {self.epochs} epochs"
            else:
                schedule = f"at epochs {list(self.epochs)}"
            self._logger.display_info(
                f"Checkpointing enabled. Checkpoints will be saved to {self.dir} {schedule}."
            )

        if self.save_best:
            self._logger.display_info(
                f"Best model checkpointing enabled. Best model will be saved to {self.dir}."
            )

        return self

    def _write_state(self, state: TrainingState, filepath: Path) -> None:
        """Serialize ``state.to_dict()`` to ``filepath``, creating the directory if needed."""
        # ``_setup`` normally creates the directory, but ``save`` may be
        # called without it (or after the directory was removed).
        self.dir.mkdir(parents=True, exist_ok=True)
        torch.save(state.to_dict(), filepath)

    def save(self, state: TrainingState, is_best: bool = False) -> None:
        """Save a checkpoint of the training state at the current epoch.

        With ``is_best=False`` the state is written to
        ``<name>_epoch_<epoch>.pt``. With ``is_best=True`` it is written to
        ``<name>_best.pt``, but only when ``save_best`` is enabled; otherwise
        the call does nothing.

        Args:
            state: The training state to checkpoint.
            is_best: If true, save as the best model.
        """
        if is_best:
            if not self.save_best:
                return
            filepath = self.dir / f"{self.name}_best.pt"
            self._write_state(state, filepath)
            # NOTE(review): the message reports ``state.acc`` although the
            # class docs describe "best" in terms of loss — confirm intent.
            self._logger.display_info(
                f"Best model checkpoint saved to {filepath} with acc {state.acc:.4f}.",
                display_on_terminal=False,
            )
            return

        epoch = state.epoch
        filepath = self.dir / f"{self.name}_epoch_{epoch}.pt"
        self._write_state(state, filepath)
        self._logger.display_info(
            f"Checkpoint saved at epoch {epoch} to {filepath}.",
            display_on_terminal=False,
        )

    def load(
        self, checkpoint_path: Union[str, Path], device: Optional[torch.device] = None
    ) -> TrainingState:
        """Load a checkpoint from the given path.

        Args:
            checkpoint_path: Path to the checkpoint file.
            device: The device to load the checkpoint onto.

        Returns:
            The training state of the checkpoint.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
            ValueError: If ``checkpoint_path`` does not have a ``.pt`` suffix.
        """
        filepath = Path(checkpoint_path)
        if not filepath.exists():
            raise FileNotFoundError(f"Checkpoint file does not exist: {filepath}")
        if filepath.suffix != ".pt":
            raise ValueError(
                f"Invalid checkpoint path: {filepath}. Checkpoint must be a .pt file."
            )

        # NOTE(security): torch.load deserializes with pickle; only load
        # checkpoint files that come from a trusted source.
        checkpoint = torch.load(filepath, map_location=device)
        state = TrainingState.from_dict(checkpoint)

        self._logger.display_info(
            f"Checkpoint loaded from {checkpoint_path}. Resuming from "
            f"epoch {state.epoch} with training loss {state.train_loss:.4f}."
        )
        return state

    def load_at_epoch(
        self, epoch: int, device: Optional[torch.device] = None
    ) -> TrainingState:
        """Load the checkpoint at the given epoch.

        Args:
            epoch: Epoch to load the checkpoint at.
            device: The device to load the checkpoint onto.

        Returns:
            The training state of the checkpoint.

        Raises:
            FileNotFoundError: If no checkpoint file exists for ``epoch``.
        """
        filepath = self.dir / f"{self.name}_epoch_{epoch}.pt"
        if not filepath.exists():
            raise FileNotFoundError(
                f"Filepath does not exist: {filepath}.\n"
                f"Training state at epoch {epoch} has not been checkpointed."
            )
        return self.load(filepath, device)

    def load_best(self, device: Optional[torch.device] = None) -> TrainingState:
        """Load the best checkpoint.

        Args:
            device: The device to load the checkpoint onto.

        Returns:
            The training state of the checkpoint.

        Raises:
            FileNotFoundError: If the best-model checkpoint file does not exist.
        """
        filepath = self.dir / f"{self.name}_best.pt"
        if not filepath.exists():
            raise FileNotFoundError(
                f"Filepath does not exist: {filepath}.\n"
                f"Best training state has not been checkpointed."
            )
        return self.load(filepath, device)