fenn.nn.trainers
- class fenn.nn.trainers.ClassificationTrainer(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Bases: Trainer
A trainer for classification tasks with PyTorch models.
Supports binary, multi-class, and multi-label classification by adapting the loss computation and prediction logic based on the task type. Handles both single-label (num_classes == 2 → binary, > 2 → multiclass) and multi-label (multi_label=True) scenarios.
The automatic task type detection configures:
Binary: sigmoid activation, BCE loss, threshold at 0.5
Multiclass: softmax activation, cross-entropy loss
Multi-label: sigmoid activation, binary cross-entropy per label
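The task type detection described above can be sketched in plain Python. This is an illustrative sketch, not fenn's implementation: the helper names infer_task_type and logits_to_prediction are hypothetical, the binary case is assumed to use a single output logit, and num_classes == 1 is assumed to behave like the binary case because the source does not say how it is handled.

```python
import math


def infer_task_type(num_classes, multi_label=False):
    """Map the constructor arguments to a task type (hypothetical helper)."""
    if num_classes < 1:
        raise ValueError("num_classes must be >= 1")
    if multi_label:
        if num_classes < 2:
            raise ValueError("multi_label=True requires num_classes >= 2")
        return "multilabel"
    # Assumption: num_classes == 1 is treated like the binary case.
    return "binary" if num_classes <= 2 else "multiclass"


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits):
    # Subtract the maximum for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_prediction(logits, task_type, threshold=0.5):
    """Turn the logits of one sample into a prediction."""
    if task_type == "binary":
        # Assumption: a single logit, sigmoid, then threshold.
        return int(sigmoid(logits[0]) >= threshold)
    if task_type == "multiclass":
        probs = softmax(logits)
        return probs.index(max(probs))
    if task_type == "multilabel":
        # One independent sigmoid decision per label.
        return [int(sigmoid(v) >= threshold) for v in logits]
    raise ValueError(f"unknown task type: {task_type}")
```

With real models the same steps would normally run on tensors in batches; lists are used here only to keep the sketch dependency-free.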
- Parameters:
model (Module) – The neural network model, expected to output logits for the classification task.
loss_fn (Module) – Loss function compatible with the task type (e.g. BCEWithLogitsLoss for binary/multi-label, CrossEntropyLoss for multiclass).
optim (Optimizer) – Optimizer for updating trainable parameters.
num_classes (int) – Number of classes (or labels in multi-label mode). Must be >= 1.
multi_label (bool) – Whether this is a multi-label classification problem. Requires num_classes >= 2.
device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').
early_stopping_patience (int | None) – Stop training after this many epochs without improvement in validation/training loss. None disables.
checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.
Note
For binary classification (num_classes == 2, multi_label=False), labels should be [0, 1] shaped tensors. For multiclass, labels should be class indices. For multi-label, labels should be binary vectors of length num_classes.
- __init__(model, loss_fn, optim, num_classes, multi_label=False, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Initialize a ClassificationTrainer instance.
- Parameters:
model (Module) – The neural network model to train.
loss_fn (Module) – The loss function to use.
optim (Optimizer) – The optimizer to use.
num_classes (int) – The number of classes to predict.
multi_label (bool) – Whether to use multi-label classification.
device (device | str) – The device on which the data will be loaded.
early_stopping_patience (int | None) – The number of epochs to wait before early stopping.
checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.
- fit(train_loader, epochs, val_loader=None, val_epochs=1)
Train the model with optional validation and early stopping.
The behaviour depends on the combination of val_loader and early_stopping_patience:
No validation loader, no early stopping: run full epochs.
No validation loader, early stopping set: stop on training loss.
Validation loader provided, no early stopping: evaluate every epoch but continue regardless of metrics.
Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.
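The four combinations above reduce to two choices: which loss is monitored, and whether a patience counter is active. A minimal sketch of that logic follows. The function epochs_until_stop is hypothetical, and it assumes one validation loss per epoch (val_epochs=1) and that "improvement" means a strictly lower loss; fenn's actual criterion may differ.

```python
def epochs_until_stop(train_losses, val_losses=None, patience=None):
    """Return how many epochs would run, given per-epoch losses.

    Illustrative only: the losses are supplied up front instead of being
    produced by a real training loop.
    """
    # Without early stopping, every epoch runs regardless of the metrics.
    if patience is None:
        return len(train_losses)

    # With a validation loader, monitor validation loss; otherwise training loss.
    monitored = val_losses if val_losses is not None else train_losses

    best = float("inf")
    stale = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(monitored, start=1):
        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return len(train_losses)


train = [1.0, 0.8, 0.7, 0.6, 0.5, 0.4]
val = [1.0, 0.9, 0.95, 0.97, 0.99, 0.5]

full_run = epochs_until_stop(train)                         # no val, no patience
on_train = epochs_until_stop(train, patience=2)             # stop on training loss
val_only = epochs_until_stop(train, val)                    # evaluate, never stop
on_val = epochs_until_stop(train, val, patience=2)          # stop on validation loss
```

In this example only the last call stops early, because the validation loss fails to improve for two consecutive epochs while the training loss keeps falling.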
- Parameters:
train_loader (DataLoader) – DataLoader for training data.
epochs (int) – Total number of epochs to train for.
val_loader (DataLoader | None) – DataLoader for validation data (optional).
val_epochs (int) – How often to evaluate on validation set (in epochs).
- Returns:
The trained model (returned according to return_model).
- predict(dataloader_or_batch, return_proba=False)
Predicts the output of the model for a given dataloader or batch.
- Parameters:
dataloader_or_batch (DataLoader | Tensor) – A DataLoader or a torch tensor.
return_proba (bool) – If True, also returns the predicted probabilities alongside the predicted labels.
- Returns:
A list of predictions.
list[list]: If return_proba=True, returns a tuple where:
first element is the list of predicted labels
second element is the list of predicted probabilities
- Return type:
list
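The two return shapes of predict can be illustrated with a small stand-in for the multiclass case. The function predict_multiclass is hypothetical and works on nested lists of logits rather than on a model or DataLoader; it only shows how a caller would unpack the result with and without return_proba.

```python
import math


def predict_multiclass(batch_logits, return_proba=False):
    """Stand-in for a multiclass predict(): labels, optionally with probabilities."""
    labels, probas = [], []
    for logits in batch_logits:
        # Softmax with the maximum subtracted for numerical stability.
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        probas.append(probs)
        labels.append(probs.index(max(probs)))
    # A plain list by default, a (labels, probabilities) tuple on request.
    return (labels, probas) if return_proba else labels


logits = [[0.0, 1.0], [3.0, -1.0]]
only_labels = predict_multiclass(logits)
labels, probas = predict_multiclass(logits, return_proba=True)
```

Because the return type changes with the flag, callers that sometimes pass return_proba=True need to unpack the tuple explicitly, as in the last line.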
- class fenn.nn.trainers.LoRATrainer(model, optim, task_type='SEQ_CLS', r=8, lora_alpha=16, lora_dropout=0.1, target_modules=None, bias='none', loss_fn=None, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Bases: Trainer
LoRATrainer extends the base Trainer to support Parameter-Efficient Fine-Tuning (PEFT) using LoRA (Low-Rank Adaptation).
Designed for HuggingFace transformer models. DataLoaders must yield dicts whose keys match the model’s forward signature (e.g. input_ids, attention_mask, labels). Loss is taken from outputs.loss when labels are present in the batch; loss_fn is used as a fallback if the model does not return a loss.
- Parameters:
model (Module)
optim (Optimizer)
task_type (str)
r (int)
lora_alpha (int)
lora_dropout (float)
target_modules (List[str] | None)
bias (str)
loss_fn (Module | None)
device (device | str)
early_stopping_patience (int | None)
checkpoint_config (Checkpoint | None)
- __init__(model, optim, task_type='SEQ_CLS', r=8, lora_alpha=16, lora_dropout=0.1, target_modules=None, bias='none', loss_fn=None, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Initialize the LoRATrainer.
- Parameters:
model (Module) – The base HuggingFace model to fine-tune.
optim (Optimizer) – The optimizer.
task_type (str) – LoRA task type. One of "SEQ_CLS", "CAUSAL_LM", "SEQ_2_SEQ_LM", "TOKEN_CLS", "QUESTION_ANS". Defaults to "SEQ_CLS".
r (int) – LoRA rank (number of low-rank dimensions). Defaults to 8.
lora_alpha (int) – LoRA scaling factor. Defaults to 16.
lora_dropout (float) – Dropout applied to LoRA layers. Defaults to 0.1.
target_modules (List[str] | None) – Module names to apply LoRA to (e.g. ["q_proj", "v_proj"]). If None, peft auto-detects based on the architecture.
bias (str) – Which biases to train. One of "none", "all", "lora_only". Defaults to "none".
loss_fn (Module | None) – Optional external loss function. Used when the model does not return a loss (i.e. labels are absent from the batch). Ignored otherwise.
device (device | str) – Device to train on. Defaults to "cpu".
early_stopping_patience (int | None) – Epochs without improvement before early stopping. Disabled when None.
checkpoint_config (Checkpoint | None) – Checkpoint configuration. Disabled when None.
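To give a feel for what r and lora_alpha control, the arithmetic below follows the standard LoRA formulation, in which a frozen weight of shape (d_out, d_in) receives an update (lora_alpha / r) * B @ A with A of shape (r, d_in) and B of shape (d_out, r). The helper names and the 768-dimensional layer are illustrative assumptions, not values taken from fenn.

```python
def lora_trainable_params(d_in, d_out, r=8):
    """Parameters added by one LoRA adapter: A is (r, d_in), B is (d_out, r)."""
    return r * (d_in + d_out)


def lora_scaling(lora_alpha=16, r=8):
    """Factor applied to the low-rank update B @ A in standard LoRA."""
    return lora_alpha / r


# Hypothetical 768 x 768 projection layer, using the documented defaults.
d_in = d_out = 768
full_params = d_in * d_out
adapter_params = lora_trainable_params(d_in, d_out, r=8)
fraction = adapter_params / full_params
scale = lora_scaling(lora_alpha=16, r=8)
```

Doubling r doubles the adapter size, while the update scale depends only on the ratio lora_alpha / r, which is why the two are often adjusted together.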
- fit(train_loader, epochs, val_loader=None, val_epochs=1)
Train the model with optional validation and early stopping.
DataLoaders must yield dicts with at minimum input_ids and attention_mask. Include labels to have the model (or loss_fn) compute the loss.
- Parameters:
train_loader (DataLoader) – DataLoader for training data.
epochs (int) – Total number of epochs to train for.
val_loader (DataLoader | None) – DataLoader for validation data (optional).
val_epochs (int) – How often to run validation (in epochs).
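The dict-batch contract can be sketched with a small collate function. Everything here is an assumption for illustration: the name collate_batch, the per-example "label" key, right-padding with pad_id=0, and the use of plain lists where a real pipeline would produce tensors. Only the output keys (input_ids, attention_mask, labels) come from the description above.

```python
def collate_batch(examples, pad_id=0):
    """Pad token id sequences to a common length and build a dict batch."""
    max_len = max(len(e["input_ids"]) for e in examples)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for e in examples:
        ids = e["input_ids"]
        pad = max_len - len(ids)
        batch["input_ids"].append(ids + [pad_id] * pad)
        # 1 marks real tokens, 0 marks padding.
        batch["attention_mask"].append([1] * len(ids) + [0] * pad)
        batch["labels"].append(e["label"])
    return batch


examples = [
    {"input_ids": [101, 7592, 102], "label": 1},
    {"input_ids": [101, 102], "label": 0},
]
batch = collate_batch(examples)
```

A function of this shape would typically be passed as collate_fn to a PyTorch DataLoader, with the lists converted to tensors before they reach the model.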
- predict(dataloader_or_batch)
Generate predictions for a dataloader or a single batch.
Labels are stripped from dict batches before inference so the model does not compute a loss during prediction.
For classification tasks (SEQ_CLS, TOKEN_CLS, QUESTION_ANS), returns a flat list of predicted class indices.
For generative tasks (CAUSAL_LM, SEQ_2_SEQ_LM), returns a list of logit tensors (one per batch).
- Parameters:
dataloader_or_batch (DataLoader | Dict | Tensor) – A DataLoader, a dict batch, or a raw tensor.
- Returns:
Predicted class indices (classification) or logit tensors (generative).
- Return type:
list
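The label stripping and the classification output described for predict can be sketched as two small helpers. Both names are hypothetical, and logits are represented as nested lists rather than tensors; the sketch only mirrors the documented behaviour for dict batches and classification tasks.

```python
def strip_labels(batch):
    """Drop the labels key from dict batches; pass other inputs through unchanged."""
    if isinstance(batch, dict):
        return {k: v for k, v in batch.items() if k != "labels"}
    return batch


def argmax_rows(logits):
    """Flat list of predicted class indices, one per row of logits."""
    return [row.index(max(row)) for row in logits]


batch = {
    "input_ids": [[101, 102]],
    "attention_mask": [[1, 1]],
    "labels": [1],
}
inference_batch = strip_labels(batch)          # labels removed, original dict untouched
class_indices = argmax_rows([[0.2, 1.5], [2.0, -0.3]])
```

Building a new dict instead of deleting the key in place keeps the caller's batch intact, which matters if the same batch is reused for evaluation afterwards.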
- class fenn.nn.trainers.RegressionTrainer(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)
Bases: Trainer
A trainer for regression tasks with PyTorch models.
Extends the base Trainer with regression-specific metrics (R² score, MSE) and continuous-value prediction logic. Handles single-target regression with optional validation and early stopping.
- Parameters:
model (Module) – The neural network model, expected to output continuous predictions.
loss_fn (Module) – Loss function suitable for regression (e.g. MSELoss, HuberLoss).
optim (Optimizer) – Optimizer for updating trainable parameters.
return_model (str) – Which model version to return after training. 'last' returns the final checkpoint, 'best' returns the best model by validation/training loss.
device (device | str) – Device to run training on ('cpu', 'cuda', or 'mps').
early_stopping_patience (int | None) – Stop training after this many epochs without improvement in loss. None disables.
checkpoint_config (Checkpoint | None) – Optional Checkpoint for saving training state to disk.
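The two regression metrics named in the class description, MSE and the R² score, have standard definitions that can be written out directly. The functions below are a plain-Python sketch of those definitions for a single target; how fenn computes or reports them internally is not stated in the source.

```python
def mse(y_true, y_pred):
    """Mean squared error: average of squared residuals."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        # Constant targets make R² undefined; this sketch chooses to raise.
        raise ValueError("R² is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]
perfect = r2_score(y_true, y_true)              # exact predictions
baseline = r2_score(y_true, [2.5] * 4)          # always predicting the mean
baseline_mse = mse(y_true, [2.5] * 4)
```

An R² of 1 means the predictions match the targets exactly, and an R² of 0 means the model does no better than predicting the mean of the targets.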
- __init__(model, loss_fn, optim, return_model='last', device='cpu', early_stopping_patience=None, checkpoint_config=None)
Initialize a RegressionTrainer instance.
- Parameters:
model (Module) – The neural network model to train.
loss_fn (Module) – The loss function to use.
optim (Optimizer) – The optimizer to use.
return_model (str) – Whether to return the ‘last’ or ‘best’ model after training.
device (device | str) – The device on which the data will be loaded.
early_stopping_patience (int | None) – The number of epochs to wait before early stopping.
checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.
- fit(train_loader, epochs, val_loader=None, val_epochs=1)
Train the model with optional validation and early stopping.
The behaviour depends on the combination of val_loader and early_stopping_patience:
No validation loader, no early stopping: run full epochs.
No validation loader, early stopping set: stop on training loss.
Validation loader provided, no early stopping: evaluate every epoch but continue regardless of metrics.
Validation loader provided and early stopping set: monitor validation loss and stop when it plateaus.
- Parameters:
train_loader (DataLoader) – DataLoader for training data.
epochs (int) – Total number of epochs to train for.
val_loader (DataLoader | None) – DataLoader for validation data (optional).
val_epochs (int) – How often to evaluate on validation set (in epochs).
- Returns:
The trained model (returned according to return_model).
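The difference between return_model='last' and return_model='best' can be sketched by tracking a copy of the model state whenever the monitored loss improves. The function select_model and the dict-based "states" are hypothetical stand-ins for a real model and its weights; the sketch is not fenn's implementation.

```python
import copy


def select_model(history, return_model="last"):
    """Pick the final or the lowest-loss state from (state, loss) pairs."""
    if return_model not in ("last", "best"):
        raise ValueError("return_model must be 'last' or 'best'")

    best_loss = float("inf")
    best_state = None
    last_state = None
    for state, loss in history:
        last_state = state
        if loss < best_loss:
            best_loss = loss
            # Copy, because a real training loop keeps mutating the live model.
            best_state = copy.deepcopy(state)
    return best_state if return_model == "best" else last_state


# One (state, monitored loss) pair per epoch.
history = [({"w": 1}, 0.9), ({"w": 2}, 0.4), ({"w": 3}, 0.6)]
last = select_model(history, "last")
best = select_model(history, "best")
```

Here the loss rises again in the final epoch, so 'best' returns the state from the second epoch while 'last' returns the state from the third.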
- class fenn.nn.trainers.Trainer(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Bases: ABC
The base Trainer abstract class for classification and regression tasks.
Provides a common training loop with support for early stopping, checkpointing, and validation monitoring. Subclasses must implement fit() to define the per-epoch training logic and predict() to generate predictions from a model.
- Subclasses:
ClassificationTrainer for classification tasks.
RegressionTrainer for regression tasks.
LoRATrainer for parameter-efficient fine-tuning.
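The abstract-base pattern described here can be shown with Python's abc module. The classes BaseTrainer and MeanTrainer below are hypothetical stand-ins, not fenn classes: they mirror only the documented fit and predict signatures, and the toy subclass simply predicts the mean training label.

```python
from abc import ABC, abstractmethod


class BaseTrainer(ABC):
    """Stand-in for an abstract trainer: subclasses must supply fit and predict."""

    @abstractmethod
    def fit(self, train_loader, epochs, val_loader=None, val_epochs=1):
        ...

    @abstractmethod
    def predict(self, dataloader_or_batch):
        ...


class MeanTrainer(BaseTrainer):
    """Toy subclass that predicts the mean of the training labels."""

    def __init__(self):
        self.mean = 0.0

    def fit(self, train_loader, epochs, val_loader=None, val_epochs=1):
        # Each batch is a (data, labels) pair of plain lists in this sketch.
        labels = [y for _, batch_labels in train_loader for y in batch_labels]
        self.mean = sum(labels) / len(labels)
        return self

    def predict(self, dataloader_or_batch):
        return [self.mean for _ in dataloader_or_batch]


loader = [([0, 1], [2.0, 4.0]), ([2], [6.0])]
trainer = MeanTrainer().fit(loader, epochs=1)
predictions = trainer.predict([9, 9])
```

Instantiating BaseTrainer directly raises TypeError, which is how the abstract base enforces that every concrete trainer defines both methods.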
- Parameters:
model (Module)
loss_fn (Module)
optim (Optimizer)
device (device | str)
early_stopping_patience (int | None)
checkpoint_config (Checkpoint | None)
- abstractmethod __init__(model, loss_fn, optim, device='cpu', early_stopping_patience=None, checkpoint_config=None)
Initialize a Trainer instance to fit a neural network model.
- Parameters:
model (Module) – The neural network model to train.
loss_fn (Module) – The loss function to use.
optim (Optimizer) – The optimizer to use.
num_classes – The number of classes to predict.
device (device | str) – The device on which the data will be loaded.
early_stopping_patience (int | None) – The number of epochs to wait before early stopping.
checkpoint_config (Checkpoint | None) – The checkpoint configuration. If None, checkpointing is disabled.
- abstractmethod fit(train_loader, epochs, val_loader=None, val_epochs=1)
Train the model for a fixed number of epochs.
Runs the full training loop including forward/backward passes, validation evaluation, checkpointing, and early stopping. The exact behavior depends on the validation and early stopping configuration:
No validation loader, no early stopping: run full epochs.
No validation loader, early stopping set: stop on training loss.
Validation loader provided, no early stopping: evaluate every epoch but continue regardless of metrics.
Validation loader and early stopping set: monitor validation loss and stop when it plateaus for early_stopping_patience epochs.
- Parameters:
train_loader (DataLoader) – PyTorch DataLoader yielding (data, labels) batches for training.
epochs (int) – Total number of training epochs. If resuming from a checkpoint, only the remaining epochs are run.
val_loader (DataLoader | None) – Optional DataLoader for validation evaluation.
val_epochs (int) – How frequently to evaluate on the validation set (e.g. val_epochs=2 means every 2 epochs).
- Returns:
The trained model.
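The interaction between epochs, val_epochs, and resuming from a checkpoint can be sketched as a schedule. This is an assumption-laden illustration: the helper validation_epochs is hypothetical, epochs are taken to be 1-indexed, validation is assumed to run when the epoch number is a multiple of val_epochs, and start_epoch stands for the last epoch already completed before resuming.

```python
def validation_epochs(epochs, val_epochs=1, start_epoch=0):
    """Epoch numbers (1-indexed) at which validation would run.

    start_epoch is the number of epochs already completed, e.g. as restored
    from a checkpoint; only the remaining epochs are scheduled.
    """
    return [
        epoch
        for epoch in range(start_epoch + 1, epochs + 1)
        if epoch % val_epochs == 0
    ]


fresh = validation_epochs(epochs=6, val_epochs=2)
resumed = validation_epochs(epochs=6, val_epochs=2, start_epoch=3)
every_epoch = validation_epochs(epochs=3)
```

With a total of 6 epochs and 3 already completed, only epochs 4 to 6 run, and under the assumed schedule validation happens at epochs 4 and 6.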
- load_best_checkpoint()
Load the best-performing checkpoint into the trainer’s model.
Restores the model weights from the checkpoint with the lowest validation (or training) loss recorded during training.
- Raises:
ValueError – If no checkpoint configuration was provided at init.
FileNotFoundError – If no best checkpoint file exists.
- Return type:
None
- load_checkpoint(checkpoint_path)
Load a checkpoint from the given file path and restore training state.
Restores the model weights, optimizer state, and epoch counter from a previously saved checkpoint file.
- Parameters:
checkpoint_path (str | Path) – Path to the .pt checkpoint file.
- Raises:
ValueError – If no checkpoint configuration was provided at init.
FileNotFoundError – If the checkpoint file does not exist.
- Return type:
None
- load_checkpoint_at_epoch(epoch)
Load the checkpoint saved at a specific epoch.
Searches the checkpoint directory for the saved state at the requested epoch and restores the model and optimizer.
- Parameters:
epoch (int) – The epoch whose checkpoint to load (1-indexed).
- Raises:
ValueError – If no checkpoint configuration was provided at init.
FileNotFoundError – If no checkpoint exists for the given epoch.
- Return type:
None
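The error contract shared by the checkpoint loaders (ValueError without a checkpoint configuration, FileNotFoundError when the file is missing) can be sketched as a lookup function. The name find_epoch_checkpoint and the epoch_{N}.pt file naming are hypothetical; the source does not say how fenn names or locates checkpoint files.

```python
import tempfile
from pathlib import Path


def find_epoch_checkpoint(checkpoint_dir, epoch):
    """Return the checkpoint path for an epoch, mirroring the documented errors."""
    if checkpoint_dir is None:
        raise ValueError("no checkpoint configuration was provided")
    # Hypothetical naming scheme, chosen only for this sketch.
    path = Path(checkpoint_dir) / f"epoch_{epoch}.pt"
    if not path.is_file():
        raise FileNotFoundError(f"no checkpoint for epoch {epoch}: {path}")
    return path


# Small demonstration in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "epoch_3.pt").touch()
    found_name = find_epoch_checkpoint(tmp, 3).name
    try:
        find_epoch_checkpoint(tmp, 7)
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True
```

Callers that resume training opportunistically can catch FileNotFoundError and start from scratch, while a ValueError points to a trainer that was built without checkpointing enabled.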
- abstractmethod predict(dataloader_or_batch)
Generate predictions from the trained model.
Runs inference on the provided data without computing gradients, returning model predictions in the same format as the training labels.
- Parameters:
dataloader_or_batch (DataLoader | Tensor) – Either a PyTorch DataLoader yielding data batches, or a single tensor batch.
- Returns:
A list of predictions (one per sample).