fenn.nn.trainers

class fenn.nn.trainers.ClassificationTrainer(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Bases: Trainer

A trainer for classification tasks with PyTorch models.

Supports binary, multi-class, and multi-label classification by adapting the loss computation and prediction logic based on the task type. Handles both single-label (num_classes == 2 → binary, > 2 → multiclass) and multi-label (multi_label=True) scenarios.

The automatic task type detection configures:

  • Binary: sigmoid activation, BCE loss, threshold at 0.5

  • Multiclass: softmax activation, cross-entropy loss

  • Multi-label: sigmoid activation, binary cross-entropy per label
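The selection rule described above can be sketched in plain Python. This is only an illustration of the documented behaviour, not fenn's actual code: the function name `infer_task_type` and the returned strings are hypothetical, and the single-label case with fewer than two classes, which the description does not cover, is rejected here purely as an assumption.

```python
def infer_task_type(num_classes: int, multi_label: bool = False) -> str:
    """Return the task type implied by the documented rules (illustrative only)."""
    if multi_label:
        # Multi-label mode is documented as requiring at least two labels.
        if num_classes < 2:
            raise ValueError("multi_label=True requires num_classes >= 2")
        return "multilabel"  # sigmoid + binary cross-entropy per label
    if num_classes == 2:
        return "binary"      # sigmoid + BCE, threshold at 0.5
    if num_classes > 2:
        return "multiclass"  # softmax + cross-entropy
    # Assumption: the description does not say what happens below two classes.
    raise ValueError("single-label mode with num_classes < 2 is not covered by this sketch")


print(infer_task_type(2))                    # binary
print(infer_task_type(5))                    # multiclass
print(infer_task_type(5, multi_label=True))  # multilabel
```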

Parameters:
  • model (Module) – The neural network model, expected to output logits for the classification task.

  • loss_fn (Module) – Loss function compatible with the task type (e.g. BCEWithLogitsLoss for binary/multi-label, CrossEntropyLoss for multiclass).

  • optim (Optimizer) – Optimizer for updating trainable parameters.

  • num_classes (int) – Number of classes (or labels in multi-label mode). Must be >= 1.

  • multi_label (bool) – Whether this is a multi-label classification problem. Requires num_classes >= 2.

  • device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').

  • early_stopping_patience (int | None) – Stop training after this many epochs without improvement in validation/training loss. None disables.

  • checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.

Note

For binary classification (num_classes == 2, multi_label=False), labels should be tensors with values in {0, 1}. For multiclass, labels should be class indices. For multi-label, labels should be binary vectors of length num_classes.
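To make the three label layouts concrete, the sketch below writes them as plain Python lists (real training code would use tensors) and checks them with a hypothetical `validate_labels` helper that is not part of fenn. It assumes binary labels are single 0/1 values per sample.

```python
def validate_labels(labels, task: str, num_classes: int) -> bool:
    """Check that labels follow the layout described in the note (illustrative)."""
    if task == "binary":
        # One 0/1 value per sample.
        return all(y in (0, 1) for y in labels)
    if task == "multiclass":
        # One class index per sample, in the range [0, num_classes).
        return all(isinstance(y, int) and 0 <= y < num_classes for y in labels)
    if task == "multilabel":
        # One binary vector of length num_classes per sample.
        return all(
            len(row) == num_classes and all(v in (0, 1) for v in row)
            for row in labels
        )
    raise ValueError(f"unknown task: {task}")


binary_labels = [0, 1, 1, 0]
multiclass_labels = [2, 0, 1, 2]            # num_classes = 3
multilabel_labels = [[1, 0, 1], [0, 0, 1]]  # num_classes = 3

print(validate_labels(binary_labels, "binary", 2))          # True
print(validate_labels(multiclass_labels, "multiclass", 3))  # True
print(validate_labels(multilabel_labels, "multilabel", 3))  # True
print(validate_labels([3], "multiclass", 3))                # False
```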

__init__(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Initialize a ClassificationTrainer instance.

Parameters:
  • model (Module) – The neural network model to train.

  • loss_fn (Module) – The loss function to use.

  • optim (Optimizer) – The optimizer to use.

  • num_classes (int) – The number of classes to predict.

  • multi_label (bool) – Whether to use multi-label classification.

  • device (device | str) – The device on which the data will be loaded.

  • early_stopping_patience (int | None) – The number of epochs to wait before early stopping.

  • checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.

fit(train_loader, epochs, val_loader=None, val_epochs=1)[source]

Train the model with optional validation and early stopping.

The behaviour depends on the combination of val_loader and early_stopping_patience:

  • No validation loader, no early stopping: run full epochs.

  • No validation loader, early stopping set: stop on training loss.

  • Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.

  • Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.

Parameters:
  • train_loader (DataLoader) – DataLoader for training data.

  • epochs (int) – Total number of epochs to train for.

  • val_loader (DataLoader | None) – DataLoader for validation data (optional).

  • val_epochs (int) – How often to evaluate on validation set (in epochs).

Returns:

The trained model.
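The interaction between the monitored loss and early_stopping_patience can be simulated without a model. The sketch below is a plain-Python illustration, not fenn's implementation: the helper name `epochs_until_stop` is hypothetical, and it assumes that "improvement" means a strictly lower loss than the best seen so far.

```python
def epochs_until_stop(losses, patience=None):
    """Return how many epochs run before patience-based early stopping triggers.

    `losses` holds the monitored loss per epoch: the validation loss when a
    validation loader is given, the training loss otherwise.
    """
    best = float("inf")
    stale = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best, stale = loss, 0
        else:
            stale += 1
        if patience is not None and stale >= patience:
            return epoch  # stopped early
    return len(losses)  # ran every epoch


val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.60]
print(epochs_until_stop(val_losses, patience=None))  # 7
print(epochs_until_stop(val_losses, patience=2))     # 5
print(epochs_until_stop(val_losses, patience=3))     # 6
```

With patience=2 the run stops after epoch 5 and never reaches the better loss at epoch 7, which is the usual trade-off when choosing a small patience value.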

predict(dataloader_or_batch, return_proba=False)[source]

Predicts the output of the model for a given dataloader or batch.

Parameters:
  • dataloader_or_batch (DataLoader | Tensor) – A DataLoader or a torch tensor.

  • return_proba (bool) – If True, also returns the predicted probabilities alongside the predicted labels.

Returns:

A list of predictions (list[list]).

If return_proba=True, returns a tuple where:

  • the first element is the list of predicted labels

  • the second element is the list of predicted probabilities

Return type:

list
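How logits might be turned into labels and probabilities can be sketched with the standard library alone. The helpers below are hypothetical and are not fenn's code. They assume one logit per sample in the binary case, a threshold that counts a probability of exactly 0.5 as positive, and argmax over softmax outputs in the multiclass case.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_binary(logits, threshold=0.5):
    """One logit per sample -> (labels, probabilities)."""
    probs = [sigmoid(z) for z in logits]
    labels = [int(p >= threshold) for p in probs]
    return labels, probs


def predict_multiclass(logit_rows):
    """One row of logits per sample -> (labels, probabilities)."""
    probs = [softmax(row) for row in logit_rows]
    labels = [max(range(len(p)), key=p.__getitem__) for p in probs]
    return labels, probs


labels, probs = predict_binary([-2.0, 0.3, 4.0])
print(labels)  # [0, 1, 1]

labels, probs = predict_multiclass([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
print(labels)  # [1, 0]
```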

class fenn.nn.trainers.LoRATrainer(model, optim, task_type='SEQ_CLS', r=8, lora_alpha=16, lora_dropout=0.1, target_modules=None, bias='none', loss_fn=None, device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Bases: Trainer

LoRATrainer extends the base Trainer to support Parameter-Efficient Fine-Tuning (PEFT) using LoRA (Low-Rank Adaptation).

Designed for HuggingFace transformer models. DataLoaders must yield dicts whose keys match the model’s forward signature (e.g. input_ids, attention_mask, labels). Loss is taken from outputs.loss when labels are present in the batch; loss_fn is used as a fallback if the model does not return a loss.

Parameters:
  • model (Module)

  • optim (Optimizer)

  • task_type (str)

  • r (int)

  • lora_alpha (int)

  • lora_dropout (float)

  • target_modules (List[str] | None)

  • bias (str)

  • loss_fn (Module | None)

  • device (device | str)

  • early_stopping_patience (int | None)

  • checkpoint_config (Checkpoint | None)

__init__(model, optim, task_type='SEQ_CLS', r=8, lora_alpha=16, lora_dropout=0.1, target_modules=None, bias='none', loss_fn=None, device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Initialize the LoRATrainer.

Parameters:
  • model (Module) – The base HuggingFace model to fine-tune.

  • optim (Optimizer) – The optimizer.

  • task_type (str) – LoRA task type. One of "SEQ_CLS", "CAUSAL_LM", "SEQ_2_SEQ_LM", "TOKEN_CLS", "QUESTION_ANS". Defaults to "SEQ_CLS".

  • r (int) – LoRA rank — number of low-rank dimensions. Defaults to 8.

  • lora_alpha (int) – LoRA scaling factor. Defaults to 16.

  • lora_dropout (float) – Dropout applied to LoRA layers. Defaults to 0.1.

  • target_modules (List[str] | None) – Module names to apply LoRA to (e.g. ["q_proj", "v_proj"]). If None, peft auto-detects based on the architecture.

  • bias (str) – Which biases to train. One of "none", "all", "lora_only". Defaults to "none".

  • loss_fn (Module | None) – Optional external loss function. Used when the model does not return a loss (i.e. labels are absent from the batch). Ignored otherwise.

  • device (device | str) – Device to train on. Defaults to "cpu".

  • early_stopping_patience (int | None) – Epochs without improvement before early stopping. Disabled when None.

  • checkpoint_config (Checkpoint | None) – Checkpoint configuration. Disabled when None.
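To give a feel for what r and lora_alpha control: in the standard LoRA formulation a frozen weight matrix of shape (d_out, d_in) is augmented with a product of two small matrices of shapes (d_out, r) and (r, d_in), scaled by lora_alpha / r. The sketch below only does the arithmetic. The helper names and the 768-wide layer are made up for illustration, and the scaling shown is the conventional default, which may differ if other options are enabled in the underlying library.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in the two low-rank factors: (d_out x r) plus (r x d_in)."""
    return r * (d_in + d_out)


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Conventional LoRA scaling applied to the low-rank update."""
    return lora_alpha / r


d_in = d_out = 768  # hypothetical projection size
full = d_in * d_out
lora = lora_trainable_params(d_in, d_out, r=8)

print(full)                 # 589824
print(lora)                 # 12288
print(lora_scaling(16, 8))  # 2.0
```

With the defaults r=8 and lora_alpha=16, each adapted 768 x 768 layer in this example trains about 2% of the parameters of the full matrix.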

fit(train_loader, epochs, val_loader=None, val_epochs=1)[source]

Train the model with optional validation and early stopping.

DataLoaders must yield dicts with at minimum input_ids and attention_mask. Include labels to have the model (or loss_fn) compute the loss.

Parameters:
  • train_loader (DataLoader) – DataLoader for training data.

  • epochs (int) – Total number of epochs to train for.

  • val_loader (DataLoader | None) – DataLoader for validation data (optional).

  • val_epochs (int) – How often to run validation (in epochs).

predict(dataloader_or_batch)[source]

Generate predictions for a dataloader or a single batch.

Labels are stripped from dict batches before inference so the model does not compute a loss during prediction.

For classification tasks (SEQ_CLS, TOKEN_CLS, QUESTION_ANS), returns a flat list of predicted class indices.

For generative tasks (CAUSAL_LM, SEQ_2_SEQ_LM), returns a list of logit tensors (one per batch).

Parameters:

dataloader_or_batch (DataLoader | Dict | Tensor) – A DataLoader, a dict batch, or a raw tensor.

Returns:

Predicted class indices (classification) or logit tensors (generative).

Return type:

list
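The two steps described for classification tasks, dropping labels from a dict batch and reducing logits to class indices, can be sketched as follows. Both helpers are hypothetical stand-ins rather than fenn's code, and the logits are hard-coded where a real run would call the model.

```python
def strip_labels(batch: dict) -> dict:
    """Return a copy of the batch without the 'labels' key (input left untouched)."""
    return {k: v for k, v in batch.items() if k != "labels"}


def argmax_rows(logit_rows):
    """Turn per-sample logits into a flat list of class indices."""
    return [max(range(len(row)), key=row.__getitem__) for row in logit_rows]


batch = {
    "input_ids": [[101, 2023, 102], [101, 2008, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
    "labels": [1, 0],
}
inputs = strip_labels(batch)
print(sorted(inputs))  # ['attention_mask', 'input_ids']

fake_logits = [[0.2, 1.5], [2.1, -0.3]]  # stand-in for the model's output logits
print(argmax_rows(fake_logits))  # [1, 0]
```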

class fenn.nn.trainers.RegressionTrainer(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Bases: Trainer

A trainer for regression tasks with PyTorch models.

Extends the base Trainer with regression-specific metrics (R² score, MSE) and continuous-value prediction logic. Handles single-target regression with optional validation and early stopping.
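For reference, the two metrics named above have simple closed forms. The sketch below computes them on plain Python lists; it is not the trainer's own implementation, and the function names are chosen for illustration.

```python
def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot.

    Undefined (division by zero) when all targets are identical.
    """
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.0, 2.0, 3.0, 5.0]
print(mse(y_true, y_pred))       # 0.25
print(r2_score(y_true, y_pred))  # 0.8
```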

Parameters:
  • model (Module) – The neural network model, expected to output continuous predictions.

  • loss_fn (Module) – Loss function suitable for regression (e.g. MSELoss, HuberLoss).

  • optim (Optimizer) – Optimizer for updating trainable parameters.

  • return_model (str) – Which model version to return after training. 'last' returns the final checkpoint, 'best' returns the best model by validation/training loss.

  • device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').

  • early_stopping_patience (int | None) – Stop training after this many epochs without improvement in loss. None disables.

  • checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.

__init__(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Initialize a RegressionTrainer instance.

Parameters:
  • model (Module) – The neural network model to train.

  • loss_fn (Module) – The loss function to use.

  • optim (Optimizer) – The optimizer to use.

  • return_model (str) – Whether to return the ‘last’ or ‘best’ model after training.

  • device (device | str) – The device on which the data will be loaded.

  • early_stopping_patience (int | None) – The number of epochs to wait before early stopping.

  • checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.

fit(train_loader, epochs, val_loader=None, val_epochs=1)[source]

Train the model with optional validation and early stopping.

The behaviour depends on the combination of val_loader and early_stopping_patience:

  • No validation loader, no early stopping: run full epochs.

  • No validation loader, early stopping set: stop on training loss.

  • Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.

  • Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.

Parameters:
  • train_loader (DataLoader) – DataLoader for training data.

  • epochs (int) – Total number of epochs to train for.

  • val_loader (DataLoader | None) – DataLoader for validation data (optional).

  • val_epochs (int) – How often to evaluate on validation set (in epochs).

Returns:

The trained model (returned according to return_model).
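The difference between return_model='last' and return_model='best' can be illustrated with a small simulation. Everything here is hypothetical: dictionaries stand in for model state, `select_returned_state` is not a fenn function, and "best" is assumed to mean the lowest monitored loss.

```python
import copy


def select_returned_state(epoch_states, epoch_losses, return_model="last"):
    """Pick the state to hand back after training (illustrative only)."""
    if return_model not in ("last", "best"):
        raise ValueError("return_model must be 'last' or 'best'")
    best_loss = float("inf")
    best_state = None
    for state, loss in zip(epoch_states, epoch_losses):
        if loss < best_loss:
            best_loss = loss
            # Snapshot the state so later updates cannot alter it.
            best_state = copy.deepcopy(state)
    return best_state if return_model == "best" else epoch_states[-1]


states = [{"w": 0.1}, {"w": 0.4}, {"w": 0.9}]  # stand-ins for per-epoch weights
losses = [0.50, 0.20, 0.35]

print(select_returned_state(states, losses, "last"))  # {'w': 0.9}
print(select_returned_state(states, losses, "best"))  # {'w': 0.4}
```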

predict(dataloader_or_batch)[source]

Predicts the output of the model for a given dataloader or batch.

Parameters:

dataloader_or_batch (DataLoader | Tensor) – A DataLoader or a torch tensor.

Returns:

A list of predictions.

Return type:

list

class fenn.nn.trainers.Trainer(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Bases: ABC

The base Trainer abstract class for classification and regression tasks.

Provides a common training loop with support for early stopping, checkpointing, and validation monitoring. Subclasses must implement fit() to define the per-epoch training logic and predict() to generate predictions from a model.

Subclasses: ClassificationTrainer, LoRATrainer, RegressionTrainer

Parameters:
  • model (Module)

  • loss_fn (Module)

  • optim (Optimizer)

  • device (device | str)

  • early_stopping_patience (int | None)

  • checkpoint_config (Checkpoint | None)

abstractmethod __init__(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)[source]

Initialize a Trainer instance to fit a neural network model.

Parameters:
  • model (Module) – The neural network model to train.

  • loss_fn (Module) – The loss function to use.

  • optim (Optimizer) – The optimizer to use.

  • device (device | str) – The device on which the data will be loaded.

  • early_stopping_patience (int | None) – The number of epochs to wait before early stopping.

  • checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.

abstractmethod fit(train_loader, epochs, val_loader=None, val_epochs=1)[source]

Train the model for a fixed number of epochs.

Runs the full training loop including forward/backward passes, validation evaluation, checkpointing, and early stopping. The exact behavior depends on the validation and early stopping configuration:

  • No validation loader, no early stopping: run full epochs.

  • No validation loader, early stopping set: stop on training loss.

  • Validation loader provided, no early stopping: evaluate every val_epochs epochs but continue regardless of metrics.

  • Validation loader and early stopping set: monitor validation loss and stop when it plateaus for early_stopping_patience epochs.

Parameters:
  • train_loader (DataLoader) – PyTorch DataLoader yielding (data, labels) batches for training.

  • epochs (int) – Total number of training epochs. If resuming from a checkpoint, only the remaining epochs are run.

  • val_loader (DataLoader | None) – Optional DataLoader for validation evaluation.

  • val_epochs (int) – How frequently to evaluate on the validation set (e.g. val_epochs=2 means every 2 epochs).

Returns:

The trained model.
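How epochs, val_epochs, and checkpoint resumption combine can be sketched as a schedule. This is an illustration under stated assumptions, not the trainer's code: epochs are taken to be 1-indexed, validation is assumed to run on epochs that are multiples of val_epochs, and `plan_epochs` with its `last_completed_epoch` argument is a hypothetical helper.

```python
def plan_epochs(epochs, val_epochs=1, last_completed_epoch=0, has_val_loader=True):
    """Return [(epoch, run_validation)] for the epochs that still need to run."""
    plan = []
    for epoch in range(last_completed_epoch + 1, epochs + 1):
        validate = has_val_loader and epoch % val_epochs == 0
        plan.append((epoch, validate))
    return plan


# Fresh run: validate every second epoch.
print(plan_epochs(epochs=4, val_epochs=2))
# [(1, False), (2, True), (3, False), (4, True)]

# Resumed run: four epochs already done, so only epochs 5 and 6 remain.
print(plan_epochs(epochs=6, val_epochs=2, last_completed_epoch=4))
# [(5, False), (6, True)]
```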

load_best_checkpoint()[source]

Load the best-performing checkpoint into the trainer’s model.

Restores the model weights from the checkpoint with the lowest validation (or training) loss recorded during training.

Raises:
  • ValueError – If no checkpoint configuration was provided at init.

  • FileNotFoundError – If no best checkpoint file exists.

Return type:

None

load_checkpoint(checkpoint_path)[source]

Load a checkpoint from the given file path and restore training state.

Restores the model weights, optimizer state, and epoch counter from a previously saved checkpoint file.

Parameters:

checkpoint_path (str | Path) – Path to the .pt checkpoint file.

Raises:
  • ValueError – If no checkpoint configuration was provided at init.

  • FileNotFoundError – If the checkpoint file does not exist.

Return type:

None

load_checkpoint_at_epoch(epoch)[source]

Load the checkpoint saved at a specific epoch.

Searches the checkpoint directory for the saved state at the requested epoch and restores the model and optimizer.

Parameters:

epoch (int) – The epoch whose checkpoint to load (1-indexed).

Raises:
  • ValueError – If no checkpoint configuration was provided at init.

  • FileNotFoundError – If no checkpoint exists for the given epoch.

Return type:

None
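The three loading methods share the same error contract, so calling code can handle them uniformly. The sketch below shows one possible pattern. `StubTrainer` is a hypothetical stand-in that only mimics the documented Raises behaviour so the example runs on its own; its file naming scheme is invented and says nothing about how fenn stores checkpoints.

```python
import tempfile
from pathlib import Path


class StubTrainer:
    """Hypothetical stand-in that mimics only the documented Raises contract."""

    def __init__(self, checkpoint_dir=None):
        self.checkpoint_dir = checkpoint_dir

    def load_checkpoint_at_epoch(self, epoch):
        if self.checkpoint_dir is None:
            raise ValueError("no checkpoint configuration was provided")
        path = Path(self.checkpoint_dir) / f"epoch_{epoch}.pt"  # made-up naming
        if not path.exists():
            raise FileNotFoundError(path)
        # A real trainer would restore model and optimizer state here.


def try_resume(trainer, epoch):
    """Attempt to restore a checkpoint and report what happened."""
    try:
        trainer.load_checkpoint_at_epoch(epoch)
    except ValueError:
        return "checkpointing disabled"
    except FileNotFoundError:
        return "no checkpoint for that epoch"
    return "restored"


print(try_resume(StubTrainer(), 3))  # checkpointing disabled
with tempfile.TemporaryDirectory() as tmp:
    print(try_resume(StubTrainer(tmp), 3))  # no checkpoint for that epoch
    (Path(tmp) / "epoch_3.pt").touch()
    print(try_resume(StubTrainer(tmp), 3))  # restored
```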

abstractmethod predict(dataloader_or_batch)[source]

Generate predictions from the trained model.

Runs inference on the provided data without computing gradients, returning model predictions in the same format as the training labels.

Parameters:

dataloader_or_batch (DataLoader | Tensor) – Either a PyTorch DataLoader yielding data batches, or a single tensor batch.

Returns:

A list of predictions (one per sample).

save_model(model_name='model.pth')[source]
Parameters:

model_name (str)