fenn.nn

class fenn.nn.Checkpoint(*, name='checkpoint', dir, epochs=None, save_best=True)

Bases: object

Checkpoint the training state at given epochs, and/or keep the best model seen so far.

Saves full TrainingState snapshots (model weights, optimizer state, epoch counter, metrics) during training so that training can be resumed or the best model restored later.

Parameters:
  • name (str) – Base filename for checkpoint files (without extension).

  • dir (Path | str) – Directory to save checkpoint files in.

  • epochs (int | List[int] | None) – When to save checkpoints — an int saves every N epochs, a list[int] saves at specific epochs, or None to save only the best model.

  • save_best (bool) – If True, save the best model seen so far, updated whenever validation/training loss improves.

Example

>>> checkpoint = Checkpoint(dir="checkpoints/", epochs=5, save_best=True)
>>> trainer = RegressionTrainer(model, loss_fn, optimizer, checkpoint_config=checkpoint)
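The `epochs` rule can be read as a simple predicate deciding whether a periodic checkpoint is due. The sketch below only illustrates the documented semantics and is not fenn's code: the helper name `should_save_periodic` is hypothetical, and it assumes 1-indexed epochs (as in load_checkpoint_at_epoch) and that an int N means epochs N, 2N, 3N, and so on.

```python
from typing import List, Optional, Union


def should_save_periodic(epoch: int, epochs: Optional[Union[int, List[int]]]) -> bool:
    """Return True if a periodic checkpoint is due at this 1-indexed epoch.

    Hypothetical helper illustrating the documented `epochs` semantics.
    """
    if epochs is None:
        # No periodic snapshots; only the best model (if save_best=True) is kept.
        return False
    if isinstance(epochs, int):
        # Every N epochs: N, 2N, 3N, ...
        return epoch % epochs == 0
    # An explicit list of epochs.
    return epoch in epochs


print([e for e in range(1, 13) if should_save_periodic(e, 5)])  # [5, 10]
print([e for e in range(1, 13) if should_save_periodic(e, [1, 7])])  # [1, 7]
```

The best-model snapshot is independent of this predicate: it is controlled by save_best and updated whenever the monitored loss improves.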
__init__(*, name='checkpoint', dir, epochs=None, save_best=True)

Initialize the checkpoint configuration.

Parameters:
  • name (str) – The name of the checkpoint file.

  • dir (Path | str) – The directory to save checkpoints to.

  • epochs (int | List[int] | None) – The epochs at which to save checkpoints.

  • save_best (bool) – Whether to checkpoint the best model (based on validation or training loss).

load(checkpoint_path, device=None)

Load a checkpoint from the given path.

Parameters:
  • checkpoint_path (str | Path) – Path to the checkpoint file.

  • device (device | None) – The device to load the checkpoint onto.

Returns:

The training state of the checkpoint.

Return type:

TrainingState

load_at_epoch(epoch, device=None)

Load the checkpoint at the given epoch.

Parameters:
  • epoch (int) – Epoch to load the checkpoint at.

  • device (device | None) – The device to load the checkpoint onto.

Returns:

The training state of the checkpoint.

Return type:

TrainingState

load_best(device=None)

Load the best checkpoint.

Parameters:

device (device | None) – The device to load the checkpoint onto.

Returns:

The training state of the checkpoint.

Return type:

TrainingState

save(state, is_best=False)

Save a checkpoint of the training state at the current epoch.

Parameters:
  • state (TrainingState) – The training state to checkpoint.

  • is_best (bool) – If True, save the checkpoint as the best model.

Return type:

None
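To make the save / load_at_epoch / load_best flow concrete, here is a minimal file-based sketch. It is a hypothetical stand-in, not fenn's implementation: a plain dict replaces TrainingState, pickle replaces the PyTorch serialization the real class presumably uses, the `device` argument is omitted, and the file naming scheme (`<name>_epoch<N>.pkl`, `<name>_best.pkl`) is an assumption.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


class FileCheckpointSketch:
    """Hypothetical stand-in for the Checkpoint save/load flow (not fenn's code).

    Assumptions: a plain dict replaces TrainingState, pickle replaces torch
    serialization, and files are named <name>_epoch<N>.pkl and <name>_best.pkl.
    """

    def __init__(self, dir: Union[str, Path], name: str = "checkpoint") -> None:
        self.dir = Path(dir)
        self.name = name
        self.dir.mkdir(parents=True, exist_ok=True)

    def _epoch_path(self, epoch: int) -> Path:
        return self.dir / f"{self.name}_epoch{epoch}.pkl"

    def _best_path(self) -> Path:
        return self.dir / f"{self.name}_best.pkl"

    def save(self, state: Dict[str, Any], is_best: bool = False) -> None:
        # Assumption: is_best=True writes the "best" file instead of a per-epoch file.
        path = self._best_path() if is_best else self._epoch_path(state["epoch"])
        with path.open("wb") as f:
            pickle.dump(state, f)

    def load(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"No checkpoint at {path}")
        # Only unpickle files you trust: unpickling can execute arbitrary code.
        with path.open("rb") as f:
            return pickle.load(f)

    def load_at_epoch(self, epoch: int) -> Dict[str, Any]:
        return self.load(self._epoch_path(epoch))

    def load_best(self) -> Dict[str, Any]:
        return self.load(self._best_path())


with tempfile.TemporaryDirectory() as tmp:
    ckpt = FileCheckpointSketch(tmp)
    state = {"epoch": 5, "model": {"w": 0.3}, "val_loss": 0.42}
    ckpt.save(state)
    ckpt.save(state, is_best=True)
    print(ckpt.load_at_epoch(5)["val_loss"])  # 0.42
    print(ckpt.load_best()["epoch"])  # 5
```

Keeping the best snapshot in a fixed filename means it is overwritten on every improvement, while per-epoch files accumulate.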

class fenn.nn.ClassificationTrainer(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)

Bases: Trainer

A trainer for classification tasks with PyTorch models.

Supports binary, multi-class, and multi-label classification by adapting the loss computation and prediction logic based on the task type. Handles both single-label (num_classes == 2 → binary, > 2 → multiclass) and multi-label (multi_label=True) scenarios.

The automatic task type detection configures:

  • Binary: sigmoid activation, BCE loss, threshold at 0.5.

  • Multiclass: softmax activation, cross-entropy loss.

  • Multi-label: sigmoid activation, binary cross-entropy per label.
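The task-type rules above amount to a small decision on `num_classes` and `multi_label`. The sketch below is a hypothetical helper (`detect_task` is not a fenn function); fenn's internal checks and error messages may differ. The document does not say how `num_classes == 1` is treated in single-label mode, so this sketch simply rejects it.

```python
def detect_task(num_classes: int, multi_label: bool = False) -> str:
    """Map (num_classes, multi_label) to a task type per the documented rules.

    Hypothetical helper; not part of fenn's public API.
    """
    if multi_label:
        if num_classes < 2:
            raise ValueError("multi_label=True requires num_classes >= 2")
        return "multilabel"
    if num_classes == 2:
        return "binary"
    if num_classes > 2:
        return "multiclass"
    # Assumption: single-label mode with fewer than 2 classes is rejected here.
    raise ValueError(f"unsupported num_classes={num_classes} for single-label mode")


print(detect_task(2))  # binary
print(detect_task(10))  # multiclass
print(detect_task(5, multi_label=True))  # multilabel
```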

Parameters:
  • model (Module) – The neural network model, expected to output logits for the classification task.

  • loss_fn (Module) – Loss function compatible with the task type (e.g. BCEWithLogitsLoss for binary/multi-label, CrossEntropyLoss for multiclass).

  • optim (Optimizer) – Optimizer for updating trainable parameters.

  • num_classes (int) – Number of classes (or labels in multi-label mode). Must be >= 1.

  • multi_label (bool) – Whether this is a multi-label classification problem. Requires num_classes >= 2.

  • device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').

  • early_stopping_patience (int | None) – Stop training after this many epochs without improvement in validation/training loss. None disables.

  • checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.

Note

For binary classification (num_classes == 2, multi_label=False), labels should be tensors of 0/1 values. For multiclass classification, labels should be integer class indices. For multi-label classification, labels should be binary vectors of length num_classes.
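If your multi-label annotations are stored as lists of active class indices, they need converting to binary vectors of length num_classes. A plain-Python sketch follows; `to_multi_hot` is a hypothetical helper, and in practice you would wrap the result in a float tensor (for example `torch.tensor(vectors, dtype=torch.float32)`), since BCEWithLogitsLoss expects float targets with the same shape as the logits.

```python
from typing import List, Sequence


def to_multi_hot(label_sets: Sequence[Sequence[int]], num_classes: int) -> List[List[float]]:
    """Convert per-sample lists of active class indices into binary vectors."""
    vectors: List[List[float]] = []
    for active in label_sets:
        vec = [0.0] * num_classes
        for idx in active:
            if not 0 <= idx < num_classes:
                raise ValueError(f"class index {idx} out of range for num_classes={num_classes}")
            vec[idx] = 1.0
        vectors.append(vec)
    return vectors


print(to_multi_hot([[0, 2], [], [1]], num_classes=3))
# [[1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```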

__init__(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)

Initialize a ClassificationTrainer instance.

Parameters:
  • model (Module) – The neural network model to train.

  • loss_fn (Module) – The loss function to use.

  • optim (Optimizer) – The optimizer to use.

  • num_classes (int) – The number of classes to predict.

  • multi_label (bool) – Whether to use multi-label classification.

  • device (device | str) – The device on which the data will be loaded.

  • early_stopping_patience (int | None) – The number of epochs to wait before early stopping.

  • checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.

fit(train_loader, epochs, val_loader=None, val_epochs=1)

Train the model with optional validation and early stopping.

The behaviour depends on the combination of val_loader and early_stopping_patience:

  • No validation loader, no early stopping: run full epochs.

  • No validation loader, early stopping set: stop on training loss.

  • Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.

  • Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.

Parameters:
  • train_loader (DataLoader) – DataLoader for training data.

  • epochs (int) – Total number of epochs to train for.

  • val_loader (DataLoader | None) – DataLoader for validation data (optional).

  • val_epochs (int) – How often to evaluate on validation set (in epochs).

Returns:

The trained model.
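The early-stopping cases listed for fit can be modelled as a small stateful monitor: feed it the monitored loss once per epoch (validation loss when a val_loader is given, otherwise training loss) and stop after `patience` consecutive epochs without improvement. This is a hypothetical sketch (`EarlyStoppingSketch` is not a fenn class); it assumes a strict "lower is better" comparison with no minimum delta, which fenn may or may not use.

```python
from typing import Optional


class EarlyStoppingSketch:
    """Stop after `patience` consecutive epochs without improvement in a monitored loss.

    Hypothetical illustration; patience=None disables stopping, as in the docs.
    """

    def __init__(self, patience: Optional[int]) -> None:
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss: float) -> bool:
        """Record one epoch's loss; return True if training should stop."""
        if loss < self.best:
            # Improvement: remember it and reset the counter.
            self.best = loss
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.patience is not None and self.bad_epochs >= self.patience


stopper = EarlyStoppingSketch(patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7], start=1):
    if stopper.step(loss):
        print(f"stopping at epoch {epoch}")  # stopping at epoch 4
        break
```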

predict(dataloader_or_batch, return_proba=False)

Predicts the output of the model for a given dataloader or batch.

Parameters:
  • dataloader_or_batch (DataLoader | Tensor) – A DataLoader or a torch tensor.

  • return_proba (bool) – If True, also return the predicted probabilities alongside the predicted labels.

Returns:

A list of predicted labels. If return_proba=True, returns a tuple where:

  • the first element is the list of predicted labels;

  • the second element is the list of predicted probabilities.

Return type:

list, or tuple of two lists if return_proba=True
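The activation and threshold rules listed for ClassificationTrainer determine how logits become labels and probabilities. The sketch below applies them to plain Python lists; `predict_from_logits` is a hypothetical function working on lists rather than tensors, and it assumes the binary model emits a single logit per sample, which the document implies (sigmoid, BCE) but does not state outright.

```python
import math
from typing import List, Sequence, Tuple, Union

Label = Union[int, List[int]]
Proba = Union[float, List[float]]


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |x|).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_from_logits(
    logits: Sequence[Sequence[float]], task: str, threshold: float = 0.5
) -> Tuple[List[Label], List[Proba]]:
    """Turn a batch of logits into (labels, probabilities) for the given task type."""
    labels: List[Label] = []
    probas: List[Proba] = []
    for row in logits:
        if task == "binary":
            # Assumption: one logit per sample.
            p = sigmoid(row[0])
            labels.append(int(p >= threshold))
            probas.append(p)
        elif task == "multiclass":
            ps = softmax(row)
            labels.append(max(range(len(ps)), key=ps.__getitem__))  # argmax
            probas.append(ps)
        elif task == "multilabel":
            ps = [sigmoid(z) for z in row]
            labels.append([int(q >= threshold) for q in ps])
            probas.append(ps)
        else:
            raise ValueError(f"unknown task type: {task!r}")
    return labels, probas


print(predict_from_logits([[2.0], [-1.0]], "binary")[0])  # [1, 0]
print(predict_from_logits([[0.1, 2.0, -1.0]], "multiclass")[0])  # [1]
print(predict_from_logits([[3.0, -3.0, 0.5]], "multilabel")[0])  # [[1, 0, 1]]
```

Returning labels and probabilities as a pair mirrors the tuple that predict documents for return_proba=True.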

class fenn.nn.RegressionTrainer(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)

Bases: Trainer

A trainer for regression tasks with PyTorch models.

Extends the base Trainer with regression-specific metrics (R² score, MSE) and continuous-value prediction logic. Handles single-target regression with optional validation and early stopping.
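The R² and MSE metrics mentioned above have standard definitions: MSE is the mean squared residual, and R² is one minus the ratio of the residual sum of squares to the total sum of squares around the mean target. A plain-Python sketch of those formulas follows for reference; the function names are hypothetical, and the document does not say whether fenn computes them per batch or over the whole dataset.

```python
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        # R² is undefined for constant targets; this sketch returns NaN.
        return float("nan")
    return 1.0 - ss_res / ss_tot


y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]
print(mse(y_true, y_pred))  # 0.375
print(round(r2_score(y_true, y_pred), 4))  # 0.9486
```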

Parameters:
  • model (Module) – The neural network model, expected to output continuous predictions.

  • loss_fn (Module) – Loss function suitable for regression (e.g. MSELoss, HuberLoss).

  • optim (Optimizer) – Optimizer for updating trainable parameters.

  • return_model (str) – Which model version to return after training. 'last' returns the model as it is after the final epoch, 'best' returns the best model by validation/training loss.

  • device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').

  • early_stopping_patience (int | None) – Stop training after this many epochs without improvement in loss. None disables.

  • checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.
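The difference between return_model='last' and 'best' comes down to keeping a separate snapshot of the weights whenever the loss improves. The simulation below is a hypothetical sketch, not fenn's code: a dict stands in for a model's state_dict and is updated in place each epoch, as an optimizer updates parameters. The deep copy matters in real PyTorch code too, because the tensors returned by `state_dict()` share storage with the live parameters.

```python
import copy
from typing import Any, Dict, Optional, Sequence


def select_returned_weights(
    epoch_losses: Sequence[float], return_model: str = "last"
) -> Optional[Dict[str, Any]]:
    """Simulate training and return the weights selected by return_model.

    Hypothetical: a dict stands in for a state_dict that each epoch updates in place.
    """
    if return_model not in ("last", "best"):
        raise ValueError("return_model must be 'last' or 'best'")
    weights: Dict[str, Any] = {"epoch": 0}
    best_loss = float("inf")
    best_weights: Optional[Dict[str, Any]] = None
    for epoch, loss in enumerate(epoch_losses, start=1):
        weights["epoch"] = epoch  # stand-in for an in-place optimizer step
        if loss < best_loss:
            best_loss = loss
            # Deep-copy the snapshot; otherwise later in-place updates would overwrite it.
            best_weights = copy.deepcopy(weights)
    return best_weights if return_model == "best" else weights


losses = [0.9, 0.5, 0.7]
print(select_returned_weights(losses, "last"))  # {'epoch': 3}
print(select_returned_weights(losses, "best"))  # {'epoch': 2}
```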

__init__(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)

Initialize a RegressionTrainer instance.

Parameters:
  • model (Module) – The neural network model to train.

  • loss_fn (Module) – The loss function to use.

  • optim (Optimizer) – The optimizer to use.

  • return_model (str) – Whether to return the ‘last’ or ‘best’ model after training.

  • device (device | str) – The device on which the data will be loaded.

  • early_stopping_patience (int | None) – The number of epochs to wait before early stopping.

  • checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.

fit(train_loader, epochs, val_loader=None, val_epochs=1)

Train the model with optional validation and early stopping.

The behaviour depends on the combination of val_loader and early_stopping_patience:

  • No validation loader, no early stopping: run full epochs.

  • No validation loader, early stopping set: stop on training loss.

  • Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.

  • Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.

Parameters:
  • train_loader (DataLoader) – DataLoader for training data.

  • epochs (int) – Total number of epochs to train for.

  • val_loader (DataLoader | None) – DataLoader for validation data (optional).

  • val_epochs (int) – How often to evaluate on validation set (in epochs).

Returns:

The trained model (returned according to return_model).
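The val_epochs parameter controls how often validation runs. Assuming validation happens on every val_epochs-th epoch counted from 1, the schedule can be computed as below. This is a hypothetical helper; the document does not say whether the final epoch is always validated when `epochs` is not a multiple of val_epochs.

```python
from typing import List


def validation_epochs(epochs: int, val_epochs: int = 1) -> List[int]:
    """1-indexed epochs on which validation would run, assuming every val_epochs-th epoch."""
    if val_epochs < 1:
        raise ValueError("val_epochs must be >= 1")
    return [epoch for epoch in range(1, epochs + 1) if epoch % val_epochs == 0]


print(validation_epochs(10, 3))  # [3, 6, 9]
print(validation_epochs(4))  # [1, 2, 3, 4]
```

With early stopping on validation loss and val_epochs > 1, the loss is only observed on these epochs, which affects how quickly patience is used up.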

predict(dataloader_or_batch)

Predicts the output of the model for a given dataloader or batch.

Parameters:

dataloader_or_batch (DataLoader | Tensor) – A DataLoader or a torch tensor.

Returns:

A list of predictions.

Return type:

list

class fenn.nn.Trainer(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)

Bases: ABC

Abstract base class for classification and regression trainers.

Provides a common training loop with support for early stopping, checkpointing, and validation monitoring. Subclasses must implement fit() to define the per-epoch training logic and predict() to generate predictions from a model.

Subclasses include ClassificationTrainer and RegressionTrainer.

Parameters:
  • model (Module)

  • loss_fn (Module)

  • optim (Optimizer)

  • device (device | str)

  • early_stopping_patience (int | None)

  • checkpoint_config (Checkpoint | None)

abstractmethod __init__(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)

Initialize a Trainer instance to fit a neural network model.

Parameters:
  • model (Module) – The neural network model to train.

  • loss_fn (Module) – The loss function to use.

  • optim (Optimizer) – The optimizer to use.

  • device (device | str) – The device on which the data will be loaded.

  • early_stopping_patience (int | None) – The number of epochs to wait before early stopping.

  • checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.

abstractmethod fit(train_loader, epochs, val_loader=None, val_epochs=1)

Train the model for up to the given number of epochs.

Runs the full training loop including forward/backward passes, validation evaluation, checkpointing, and early stopping. The exact behavior depends on the validation and early stopping configuration:

  • No validation loader, no early stopping: run full epochs.

  • No validation loader, early stopping set: stop on training loss.

  • Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.

  • Validation loader and early stopping set: monitor validation loss and stop when it plateaus for early_stopping_patience epochs.

Parameters:
  • train_loader (DataLoader) – PyTorch DataLoader yielding (data, labels) batches for training.

  • epochs (int) – Total number of training epochs. If resuming from a checkpoint, only the remaining epochs are run.

  • val_loader (DataLoader | None) – Optional DataLoader for validation evaluation.

  • val_epochs (int) – How frequently to evaluate on the validation set (e.g. val_epochs=2 means every 2 epochs).

Returns:

The trained model.
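The note on `epochs` says that when resuming from a checkpoint only the remaining epochs are run. Assuming the restored epoch counter records the number of fully completed epochs, the remaining 1-indexed epochs form a simple range. `epochs_to_run` is a hypothetical helper, not part of fenn.

```python
def epochs_to_run(total_epochs: int, completed_epochs: int = 0) -> range:
    """1-indexed epochs still to run after `completed_epochs` finished epochs.

    Hypothetical helper; assumes the checkpoint's epoch counter counts completed epochs.
    """
    if completed_epochs < 0:
        raise ValueError("completed_epochs must be >= 0")
    return range(completed_epochs + 1, total_epochs + 1)


print(list(epochs_to_run(10, completed_epochs=7)))  # [8, 9, 10]
print(list(epochs_to_run(10, completed_epochs=10)))  # []
```

Under this reading, calling fit again with the same `epochs` after a full run does no further training; pass a larger total to continue.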

load_best_checkpoint()

Load the best-performing checkpoint into the trainer’s model.

Restores the model weights from the checkpoint with the lowest validation (or training) loss recorded during training.

Raises:
  • ValueError – If no checkpoint configuration was provided at init.

  • FileNotFoundError – If no best checkpoint file exists.

Return type:

None

load_checkpoint(checkpoint_path)

Load a checkpoint from the given file path and restore training state.

Restores the model weights, optimizer state, and epoch counter from a previously saved checkpoint file.

Parameters:

checkpoint_path (str | Path) – Path to the .pt checkpoint file.

Raises:
  • ValueError – If no checkpoint configuration was provided at init.

  • FileNotFoundError – If the checkpoint file does not exist.

Return type:

None

load_checkpoint_at_epoch(epoch)

Load the checkpoint saved at a specific epoch.

Searches the checkpoint directory for the saved state at the requested epoch and restores the model and optimizer.

Parameters:

epoch (int) – The epoch whose checkpoint to load (1-indexed).

Raises:
  • ValueError – If no checkpoint configuration was provided at init.

  • FileNotFoundError – If no checkpoint exists for the given epoch.

Return type:

None

abstractmethod predict(dataloader_or_batch)

Generate predictions from the trained model.

Runs inference on the provided data without computing gradients, returning model predictions in the same format as the training labels.

Parameters:

dataloader_or_batch (DataLoader | Tensor) – Either a PyTorch DataLoader yielding data batches, or a single tensor batch.

Returns:

A list of predictions (one per sample).

save_model(model_name='model.pth')
Parameters:

model_name (str)