fenn.nn
- class fenn.nn.Checkpoint(*, name='checkpoint', dir, epochs=None, save_best=True)
Bases: object
Checkpoint the training state at specified epochs and/or keep the best model seen so far.
Saves full TrainingState snapshots (model weights, optimizer state, epoch counter, metrics) during training so that training can be resumed or the best model restored later.
- Parameters:
name (str) – Base filename for checkpoint files (without extension).
dir (Path | str) – Directory to save checkpoint files in.
epochs (int | List[int] | None) – When to save checkpoints: an int N saves every N epochs, a list[int] saves at the listed epochs, and None disables periodic checkpoints (only the best model is saved, if save_best=True).
save_best (bool) – If True, save the best model seen so far, updated whenever the validation loss (or, without a validation set, the training loss) improves.
Example
>>> checkpoint = Checkpoint(dir="checkpoints/", epochs=5, save_best=True)
>>> trainer = RegressionTrainer(model, loss_fn, optimizer, checkpoint_config=checkpoint)
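How the epochs argument translates into save points can be sketched in plain Python. This is a minimal illustration, not fenn's implementation: is_scheduled_epoch is a hypothetical helper, and it assumes epochs are numbered from 1 (as load_checkpoint_at_epoch documents) and that an int N means saving at epochs N, 2N, 3N, and so on.

```python
from typing import List, Optional, Union


def is_scheduled_epoch(epoch: int, epochs: Optional[Union[int, List[int]]]) -> bool:
    """Return True if a periodic checkpoint is due at `epoch` (1-indexed)."""
    if epochs is None:
        # No periodic checkpoints; only the best model is kept (if save_best=True).
        return False
    if isinstance(epochs, int):
        # Assumed meaning of an int N: save at epochs N, 2N, 3N, ...
        return epochs > 0 and epoch % epochs == 0
    # A list names the exact epochs to save.
    return epoch in epochs


print([e for e in range(1, 13) if is_scheduled_epoch(e, 5)])           # [5, 10]
print([e for e in range(1, 13) if is_scheduled_epoch(e, [1, 3, 12])])  # [1, 3, 12]
print([e for e in range(1, 13) if is_scheduled_epoch(e, None)])        # []
```

Best-model saving (save_best) is independent of this schedule, which is why epochs=None still produces a checkpoint when save_best=True.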
- __init__(*, name='checkpoint', dir, epochs=None, save_best=True)
Initialize the checkpoint configuration.
- Parameters:
name (str) – The name of the checkpoint file.
dir (Path | str) – The directory to save checkpoints to.
epochs (int | List[int] | None) – The epochs at which to save checkpoints.
save_best (bool) – Whether to checkpoint the best model (based on validation or training loss).
- load(checkpoint_path, device=None)
Load a checkpoint from the given path.
- Parameters:
checkpoint_path (str | Path) – Path to the checkpoint file.
device (device | None) – The device to load the checkpoint onto.
- Returns:
The training state of the checkpoint.
- Return type:
TrainingState
- load_at_epoch(epoch, device=None)
Load the checkpoint at the given epoch.
- Parameters:
epoch (int) – Epoch to load the checkpoint at.
device (device | None) – The device to load the checkpoint onto.
- Returns:
The training state of the checkpoint.
- Return type:
TrainingState
- load_best(device=None)
Load the best checkpoint.
- Parameters:
device (device | None) – The device to load the checkpoint onto.
- Returns:
The training state of the checkpoint.
- Return type:
TrainingState
- save(state, is_best=False)
Save a checkpoint of the training state at the current epoch.
- Parameters:
state (TrainingState) – The training state to checkpoint.
is_best (bool) – If True, save this state as the best model.
- Return type:
None
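The is_best flag is supplied by the caller. A minimal sketch of how a training loop might derive it, assuming "best" means a strictly lower monitored loss (validation loss when available, otherwise training loss) and one save per epoch for illustration; best_flags is a hypothetical helper, not part of fenn.

```python
from typing import List, Tuple


def best_flags(losses: List[float]) -> List[Tuple[int, bool]]:
    """Return the (epoch, is_best) pairs a training loop might pass to save()."""
    best_loss = float("inf")
    calls = []
    for epoch, loss in enumerate(losses, start=1):
        is_best = loss < best_loss      # strict improvement on the monitored loss
        if is_best:
            best_loss = loss
        # In a real loop (assumed): checkpoint.save(state, is_best=is_best)
        calls.append((epoch, is_best))
    return calls


print(best_flags([0.9, 0.7, 0.8, 0.6, 0.6]))
# [(1, True), (2, True), (3, False), (4, True), (5, False)]
```

A tie (epoch 5) does not count as an improvement here; whether fenn treats ties the same way is not stated.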
- class fenn.nn.ClassificationTrainer(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Bases: Trainer
A trainer for classification tasks with PyTorch models.
Supports binary, multi-class, and multi-label classification by adapting the loss computation and prediction logic based on the task type. Handles both single-label (num_classes == 2 → binary, num_classes > 2 → multiclass) and multi-label (multi_label=True) scenarios.
The automatic task type detection configures:
- Binary: sigmoid activation, BCE loss, threshold at 0.5
- Multiclass: softmax activation, cross-entropy loss
- Multi-label: sigmoid activation, binary cross-entropy per label
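The detection rules above can be written as a small lookup. This sketch only mirrors the documented rules: detect_task and ACTIVATIONS are hypothetical names, and rejecting num_classes == 1 in single-label mode is an assumption, since that case is not described here.

```python
# Activation applied to the logits for each task type, per the list above.
ACTIVATIONS = {"binary": "sigmoid", "multiclass": "softmax", "multilabel": "sigmoid"}


def detect_task(num_classes: int, multi_label: bool = False) -> str:
    """Map ClassificationTrainer-style arguments to a task type."""
    if num_classes < 1:
        raise ValueError("num_classes must be >= 1")
    if multi_label:
        if num_classes < 2:
            raise ValueError("multi_label=True requires num_classes >= 2")
        return "multilabel"
    if num_classes == 2:
        return "binary"
    if num_classes > 2:
        return "multiclass"
    # num_classes == 1 without multi_label is not covered by the documented rules.
    raise ValueError("single-label classification needs num_classes >= 2")


for num_classes, multi_label in [(2, False), (10, False), (5, True)]:
    task = detect_task(num_classes, multi_label)
    print(num_classes, multi_label, task, ACTIVATIONS[task])
```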
- Parameters:
model (Module) – The neural network model, expected to output logits for the classification task.
loss_fn (Module) – Loss function compatible with the task type (e.g. BCEWithLogitsLoss for binary/multi-label, CrossEntropyLoss for multiclass).
optim (Optimizer) – Optimizer for updating trainable parameters.
num_classes (int) – Number of classes (or labels in multi-label mode). Must be >= 1.
multi_label (bool) – Whether this is a multi-label classification problem. Requires num_classes >= 2.
device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').
early_stopping_patience (int | None) – Stop training after this many epochs without improvement in validation/training loss. None disables early stopping.
checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.
Note
For binary classification (num_classes == 2, multi_label=False), labels should be tensors with values in [0, 1]. For multiclass, labels should be class indices. For multi-label, labels should be binary vectors of length num_classes.
- __init__(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Initialize a ClassificationTrainer instance.
- Parameters:
model (Module) – The neural network model to train.
loss_fn (Module) – The loss function to use.
optim (Optimizer) – The optimizer to use.
num_classes (int) – The number of classes to predict.
multi_label (bool) – Whether to use multi-label classification.
device (device | str) – The device on which the data will be loaded.
early_stopping_patience (int | None) – The number of epochs to wait before early stopping.
checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.
- fit(train_loader, epochs, val_loader=None, val_epochs=1)
Train the model with optional validation and early stopping.
The behaviour depends on the combination of val_loader and early_stopping_patience:
No validation loader, no early stopping: run all epochs.
No validation loader, early stopping set: stop on training loss.
Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.
Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.
- Parameters:
train_loader (DataLoader) – DataLoader for training data.
epochs (int) – Total number of epochs to train for.
val_loader (DataLoader | None) – DataLoader for validation data (optional).
val_epochs (int) – How often to evaluate on validation set (in epochs).
- Returns:
The trained model.
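The early-stopping decision described above can be sketched as a patience counter over the monitored loss. This is an illustration under stated assumptions, not fenn's code: stop_epoch is a hypothetical helper, validation is assumed to run every epoch (val_epochs=1), and "improvement" is taken to mean a strictly lower loss.

```python
from typing import List, Optional


def stop_epoch(train_losses: List[float],
               val_losses: Optional[List[float]] = None,
               patience: Optional[int] = None) -> int:
    """Return the 1-indexed epoch at which training would end.

    The monitored loss is the validation loss when available, otherwise the
    training loss; patience=None disables early stopping.
    """
    monitored = val_losses if val_losses is not None else train_losses
    if patience is None:
        return len(train_losses)            # no early stopping: run every epoch
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(monitored, start=1):
        if loss < best:                     # strict improvement resets the counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch                # plateau lasted `patience` epochs
    return len(monitored)


train = [1.00, 0.80, 0.70, 0.65, 0.60, 0.58]
val = [0.90, 0.85, 0.86, 0.87, 0.88, 0.80]
print(stop_epoch(train, val, patience=3))   # 5: no val improvement in epochs 3-5
print(stop_epoch(train, val))               # 6: early stopping disabled
print(stop_epoch(train, patience=3))        # 6: training loss keeps improving
```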
- predict(dataloader_or_batch, return_proba=False)
Predicts the output of the model for a given dataloader or batch.
- Parameters:
dataloader_or_batch (DataLoader | Tensor) – A DataLoader or a torch tensor.
return_proba (bool) – If True, also return the predicted probabilities alongside the predicted labels.
- Returns:
A list of predicted labels. If return_proba=True, a tuple whose first element is the list of predicted labels and whose second element is the list of predicted probabilities.
- Return type:
list, or a tuple of two lists when return_proba=True
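The post-processing that turns logits into the returned labels and probabilities can be sketched without PyTorch. Assumptions, not taken from fenn: a binary model emits a single logit per sample, the 0.5 threshold is inclusive (p >= 0.5), and logits_to_predictions is a hypothetical helper operating on nested Python lists rather than tensors.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    # Numerically stable form: avoids overflow in math.exp for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)                                   # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_predictions(logits: List[List[float]], task: str,
                          threshold: float = 0.5) -> Tuple[list, list]:
    """Return (labels, probabilities) for a batch of raw model outputs."""
    labels, probas = [], []
    for row in logits:
        if task == "binary":                      # one logit per sample
            p = sigmoid(row[0])
            labels.append(int(p >= threshold))
            probas.append(p)
        elif task == "multiclass":                # one logit per class
            ps = softmax(row)
            labels.append(ps.index(max(ps)))      # argmax = class index
            probas.append(ps)
        elif task == "multilabel":                # one independent logit per label
            ps = [sigmoid(x) for x in row]
            labels.append([int(p >= threshold) for p in ps])
            probas.append(ps)
        else:
            raise ValueError(f"unknown task: {task}")
    return labels, probas


print(logits_to_predictions([[2.0], [-1.0]], "binary")[0])             # [1, 0]
print(logits_to_predictions([[0.1, 2.0, -1.0]], "multiclass")[0])      # [1]
print(logits_to_predictions([[3.0, -2.0, 0.0]], "multilabel")[0])      # [[1, 0, 1]]
```

The (labels, probabilities) pair mirrors the tuple returned when return_proba=True.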
- class fenn.nn.RegressionTrainer(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)
Bases: Trainer
A trainer for regression tasks with PyTorch models.
Extends the base Trainer with regression-specific metrics (R² score, MSE) and continuous-value prediction logic. Handles single-target regression with optional validation and early stopping.
- Parameters:
model (Module) – The neural network model, expected to output continuous predictions.
loss_fn (Module) – Loss function suitable for regression (e.g. MSELoss, HuberLoss).
optim (Optimizer) – Optimizer for updating trainable parameters.
return_model (str) – Which model version to return after training: 'last' returns the model as it stands at the end of training, 'best' returns the best model by validation/training loss.
device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').
early_stopping_patience (int | None) – Stop training after this many epochs without improvement in loss. None disables early stopping.
checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.
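For reference, the two metrics named above have standard definitions, sketched here in plain Python. fenn's own implementation may compute them differently (for example on tensors); mse and r2_score below are illustrative helpers, not fenn functions.

```python
from typing import Sequence


def _check(y_true: Sequence[float], y_pred: Sequence[float]) -> None:
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and the same length")


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error: average of squared residuals."""
    _check(y_true, y_pred)
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    _check(y_true, y_pred)
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R² is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.1, 1.9, 3.2, 3.8]
print(round(mse(y_true, y_pred), 4))       # 0.025
print(round(r2_score(y_true, y_pred), 4))  # 0.98
```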
- __init__(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)
Initialize a RegressionTrainer instance.
- Parameters:
model (Module) – The neural network model to train.
loss_fn (Module) – The loss function to use.
optim (Optimizer) – The optimizer to use.
return_model (str) – Whether to return the ‘last’ or ‘best’ model after training.
device (device | str) – The device on which the data will be loaded.
early_stopping_patience (int | None) – The number of epochs to wait before early stopping.
checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.
- fit(train_loader, epochs, val_loader=None, val_epochs=1)
Train the model with optional validation and early stopping.
The behaviour depends on the combination of val_loader and early_stopping_patience:
No validation loader, no early stopping: run all epochs.
No validation loader, early stopping set: stop on training loss.
Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.
Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.
- Parameters:
train_loader (DataLoader) – DataLoader for training data.
epochs (int) – Total number of epochs to train for.
val_loader (DataLoader | None) – DataLoader for validation data (optional).
val_epochs (int) – How often to evaluate on validation set (in epochs).
- Returns:
The trained model (returned according to return_model).
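The difference between the two return_model settings can be sketched as a choice over per-epoch snapshots. This assumes the trainer keeps (or can reload) one snapshot per epoch and that "best" means the lowest monitored loss, with ties going to the earliest epoch; select_model is a hypothetical helper, and fenn may instead restore the best weights from its checkpoint.

```python
from typing import Any, List


def select_model(snapshots: List[Any], losses: List[float], return_model: str = "last") -> Any:
    """Pick the per-epoch snapshot that return_model asks for."""
    if not snapshots or len(snapshots) != len(losses):
        raise ValueError("need one loss per snapshot and at least one epoch")
    if return_model == "last":
        return snapshots[-1]                    # state at the end of training
    if return_model == "best":
        best_idx = min(range(len(losses)), key=losses.__getitem__)  # first lowest loss
        return snapshots[best_idx]
    raise ValueError("return_model must be 'last' or 'best'")


weights = [{"w": 1.0}, {"w": 2.0}, {"w": 3.0}]
val_losses = [0.5, 0.3, 0.4]
print(select_model(weights, val_losses, "last"))  # {'w': 3.0}
print(select_model(weights, val_losses, "best"))  # {'w': 2.0}
```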
- class fenn.nn.Trainer(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Bases: ABC
The base Trainer abstract class for classification and regression tasks.
Provides a common training loop with support for early stopping, checkpointing, and validation monitoring. Subclasses must implement __init__(), fit() to run training, and predict() to generate predictions from a model.
- Subclasses:
ClassificationTrainer for classification tasks.
RegressionTrainer for regression tasks.
LoRATrainer for parameter-efficient fine-tuning.
- Parameters:
model (Module)
loss_fn (Module)
optim (Optimizer)
device (device | str)
early_stopping_patience (int | None)
checkpoint_config (Checkpoint | None)
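Because __init__, fit, and predict are all abstract, Trainer itself cannot be instantiated; a subclass must override all three. The pattern is shown with a stdlib-only stand-in: TrainerLike and ToyTrainer are hypothetical classes that mirror the documented signatures, not fenn code, and a real subclass would also need to implement the training, checkpointing, and early-stopping behaviour described for fit().

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Optional


class TrainerLike(ABC):
    """Stand-in with the same abstract methods as fenn.nn.Trainer."""

    @abstractmethod
    def __init__(self, model: Any, loss_fn: Any, optim: Any, device: str = "cpu",
                 early_stopping_patience: Optional[int] = None,
                 checkpoint_config: Any = None) -> None: ...

    @abstractmethod
    def fit(self, train_loader: Iterable, epochs: int,
            val_loader: Optional[Iterable] = None, val_epochs: int = 1) -> Any: ...

    @abstractmethod
    def predict(self, dataloader_or_batch: Any) -> List[Any]: ...


class ToyTrainer(TrainerLike):
    """Implements all three abstract methods, so it can be instantiated."""

    def __init__(self, model: Callable[[Any], Any], loss_fn: Any = None, optim: Any = None,
                 device: str = "cpu", early_stopping_patience: Optional[int] = None,
                 checkpoint_config: Any = None) -> None:
        self.model = model

    def fit(self, train_loader, epochs, val_loader=None, val_epochs=1):
        return self.model                       # no actual training in this toy example

    def predict(self, dataloader_or_batch):
        return [self.model(x) for x in dataloader_or_batch]


try:
    TrainerLike(None, None, None)
except TypeError as exc:
    print("abstract base cannot be instantiated:", exc)

print(ToyTrainer(lambda x: 2 * x).predict([1, 2, 3]))  # [2, 4, 6]
```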
- abstractmethod __init__(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Initialize a Trainer instance to fit a neural network model.
- Parameters:
model (Module) – The neural network model to train.
loss_fn (Module) – The loss function to use.
optim (Optimizer) – The optimizer to use.
device (device | str) – The device on which the data will be loaded.
early_stopping_patience (int | None) – The number of epochs to wait before early stopping.
checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.
- abstractmethod fit(train_loader, epochs, val_loader=None, val_epochs=1)
Train the model for up to the configured number of epochs.
Runs the full training loop including forward/backward passes, validation evaluation, checkpointing, and early stopping. The exact behavior depends on the validation and early stopping configuration:
No validation loader, no early stopping: run all epochs.
No validation loader, early stopping set: stop on training loss.
Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.
Validation loader and early stopping set: monitor validation loss and stop once it has not improved for early_stopping_patience consecutive epochs.
- Parameters:
train_loader (DataLoader) – PyTorch DataLoader yielding (data, labels) batches for training.
epochs (int) – Total number of training epochs. If resuming from a checkpoint, only the remaining epochs are run.
val_loader (DataLoader | None) – Optional DataLoader for validation evaluation.
val_epochs (int) – How frequently to evaluate on the validation set (e.g. val_epochs=2 means every 2 epochs).
- Returns:
The trained model.
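How epochs, val_epochs, and resuming interact can be sketched as an epoch plan. Assumptions, not confirmed by fenn: epochs are numbered from 1, validation runs on multiples of val_epochs, and the restored epoch counter equals the number of completed epochs; epoch_plan is a hypothetical helper.

```python
from typing import List, Tuple


def epoch_plan(epochs: int, val_epochs: int = 1, start_epoch: int = 0,
               has_val_loader: bool = True) -> List[Tuple[int, bool]]:
    """List (epoch, run_validation) pairs for the epochs still to be run.

    start_epoch is the number of epochs already completed, e.g. as restored
    from a checkpoint.
    """
    if val_epochs < 1:
        raise ValueError("val_epochs must be >= 1")
    return [(e, has_val_loader and e % val_epochs == 0)
            for e in range(start_epoch + 1, epochs + 1)]


# Fresh run of 4 epochs, validating every 2 epochs.
print(epoch_plan(4, val_epochs=2))                 # [(1, False), (2, True), (3, False), (4, True)]
# Resumed after 3 completed epochs: only epochs 4-6 remain.
print(epoch_plan(6, val_epochs=2, start_epoch=3))  # [(4, True), (5, False), (6, True)]
```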
- load_best_checkpoint()
Load the best-performing checkpoint into the trainer’s model.
Restores the model weights from the checkpoint with the lowest validation (or training) loss recorded during training.
- Raises:
ValueError – If no checkpoint configuration was provided at init.
FileNotFoundError – If no best checkpoint file exists.
- Return type:
None
- load_checkpoint(checkpoint_path)
Load a checkpoint from the given file path and restore training state.
Restores the model weights, optimizer state, and epoch counter from a previously saved checkpoint file.
- Parameters:
checkpoint_path (str | Path) – Path to the .pt checkpoint file.
- Raises:
ValueError – If no checkpoint configuration was provided at init.
FileNotFoundError – If the checkpoint file does not exist.
- Return type:
None
- load_checkpoint_at_epoch(epoch)
Load the checkpoint saved at a specific epoch.
Searches the checkpoint directory for the saved state at the requested epoch and restores the model and optimizer.
- Parameters:
epoch (int) – The epoch whose checkpoint to load (1-indexed).
- Raises:
ValueError – If no checkpoint configuration was provided at init.
FileNotFoundError – If no checkpoint exists for the given epoch.
- Return type:
None
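The lookup and its two documented failure modes can be sketched with pathlib. The file naming ({name}_epoch_{epoch}.pt) is purely hypothetical, since the layout fenn uses inside dir is not described here, and epoch_checkpoint_path is not a fenn function.

```python
import tempfile
from pathlib import Path
from typing import Optional, Union


def epoch_checkpoint_path(ckpt_dir: Optional[Union[str, Path]], name: str, epoch: int) -> Path:
    """Resolve the file holding the checkpoint for `epoch` (1-indexed)."""
    if ckpt_dir is None:
        # Mirrors the documented ValueError when no Checkpoint config was given.
        raise ValueError("no checkpoint configuration was provided")
    path = Path(ckpt_dir) / f"{name}_epoch_{epoch}.pt"   # hypothetical naming scheme
    if not path.is_file():
        raise FileNotFoundError(f"no checkpoint for epoch {epoch}: {path}")
    return path


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "checkpoint_epoch_5.pt").touch()       # pretend epoch 5 was saved
    print(epoch_checkpoint_path(tmp, "checkpoint", 5).name)  # checkpoint_epoch_5.pt
    try:
        epoch_checkpoint_path(tmp, "checkpoint", 6)
    except FileNotFoundError as exc:
        print("missing:", exc)
```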
- abstractmethod predict(dataloader_or_batch)
Generate predictions from the trained model.
Runs inference on the provided data without computing gradients, returning model predictions in the same format as the training labels.
- Parameters:
dataloader_or_batch (DataLoader | Tensor) – Either a PyTorch DataLoader yielding data batches, or a single tensor batch.
- Returns:
A list of predictions (one per sample).