Source code for fenn.nn.trainers.lora_trainer

import inspect
from copy import deepcopy
from typing import Dict, List, Optional, Union, cast

import torch
import torch.nn
import torch.optim
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
)
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader

from fenn.logging import Logger
from fenn.nn.utils import Checkpoint

from .trainer import Trainer

# ``peft`` is an optional dependency (installed via the ``[transformers]``
# extra). When it is missing, the names used in annotations/lookups are bound
# to ``None`` and ``PEFT_AVAILABLE`` is False so LoRATrainer can raise a
# helpful ImportError at construction time instead of at import time.
try:
    from peft import LoraConfig, TaskType, get_peft_model
except ImportError:
    LoraConfig = None  # type: ignore
    TaskType = None  # type: ignore
    PEFT_AVAILABLE = False
else:
    PEFT_AVAILABLE = True

# Task types whose model outputs are token logits over a vocabulary; accuracy
# is not computed for these and ``predict`` returns raw logit tensors.
_GENERATIVE_TASK_TYPES = {"CAUSAL_LM", "SEQ_2_SEQ_LM"}

# Every ``task_type`` accepted by LoRATrainer (upper-cased names of peft's
# ``TaskType`` members): the generative ones plus the classification-style ones.
_SUPPORTED_TASK_TYPES = _GENERATIVE_TASK_TYPES | {
    "SEQ_CLS",
    "TOKEN_CLS",
    "QUESTION_ANS",
}


class LoRATrainer(Trainer):
    """Trainer for Parameter-Efficient Fine-Tuning (PEFT) with LoRA.

    LoRATrainer extends the base Trainer to wrap a HuggingFace transformer
    model with LoRA (Low-Rank Adaptation) adapters via ``peft``.

    DataLoaders must yield dicts whose keys match the model's forward
    signature (e.g. ``input_ids``, ``attention_mask``, ``labels``). Loss is
    taken from ``outputs.loss`` when the model returns one; ``loss_fn`` is
    used as a fallback when the model returns no loss but the batch contains
    ``labels``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        task_type: str = "SEQ_CLS",
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        target_modules: Optional[List[str]] = None,
        bias: str = "none",
        loss_fn: Optional[torch.nn.Module] = None,
        device: Union[torch.device, str] = "cpu",
        early_stopping_patience: Optional[int] = None,
        checkpoint_config: Optional[Checkpoint] = None,
    ):
        """Initialize the LoRATrainer.

        Args:
            model: The base HuggingFace model to fine-tune. It is wrapped with
                ``get_peft_model`` before being handed to the base Trainer.
            optim: The optimizer. Because it was built before the LoRA
                parameters existed, it is rebuilt (same class, same
                constructor hyperparameters from ``optim.defaults``) over the
                model's trainable parameters only.
            task_type: LoRA task type (case-insensitive). One of
                ``"SEQ_CLS"``, ``"CAUSAL_LM"``, ``"SEQ_2_SEQ_LM"``,
                ``"TOKEN_CLS"``, ``"QUESTION_ANS"``. Defaults to ``"SEQ_CLS"``.
            r: LoRA rank — number of low-rank dimensions. Defaults to ``8``.
            lora_alpha: LoRA scaling factor. Defaults to ``16``.
            lora_dropout: Dropout applied to LoRA layers. Defaults to ``0.1``.
            target_modules: Module names to apply LoRA to
                (e.g. ``["q_proj", "v_proj"]``). If ``None``, peft
                auto-detects based on the architecture.
            bias: Which biases to train. One of ``"none"``, ``"all"``,
                ``"lora_only"``. Defaults to ``"none"``.
            loss_fn: Optional external loss function, called as
                ``loss_fn(logits, labels)``. Used only when the model returns
                no loss but the batch contains ``labels``; ignored otherwise.
            device: Device to train on. Defaults to ``"cpu"``.
            early_stopping_patience: Epochs without improvement before early
                stopping. Disabled when ``None``.
            checkpoint_config: Checkpoint configuration. Disabled when ``None``.

        Raises:
            ImportError: If ``peft`` is not installed.
            ValueError: If ``task_type`` is not a supported task type.
        """
        if not PEFT_AVAILABLE:
            raise ImportError(
                "peft is not installed. Install fenn with the [transformers] extra: "
                "pip install fenn[transformers]"
            )

        task_type_upper = task_type.upper()
        if task_type_upper not in _SUPPORTED_TASK_TYPES:
            raise ValueError(
                f"Unsupported task_type '{task_type}'. "
                f"Choose from: {sorted(_SUPPORTED_TASK_TYPES)}"
            )

        lora_config = LoraConfig(
            task_type=getattr(TaskType, task_type_upper),
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=target_modules,
            bias=bias,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        optim = self._rebuild_optimizer(optim, model)

        super().__init__(
            model=model,
            loss_fn=loss_fn,  # type: ignore[arg-type]
            optim=optim,
            device=device,
            early_stopping_patience=early_stopping_patience,
            checkpoint_config=checkpoint_config,
        )

        self._logger = Logger()
        self._task_type = task_type_upper
        self._is_generative = task_type_upper in _GENERATIVE_TASK_TYPES

        self._logger.display_info(
            f"LoRA applied — task: {task_type_upper} | r={r} | alpha={lora_alpha} | dropout={lora_dropout}"
        )

    @staticmethod
    def _rebuild_optimizer(
        optim: torch.optim.Optimizer, model: torch.nn.Module
    ) -> torch.optim.Optimizer:
        """Return a new optimizer of ``type(optim)`` over ``model``'s trainable params.

        The optimizer was built before ``get_peft_model`` injected the LoRA
        parameters, so those new tensors are not tracked by it. The original
        optimizer class and hyperparameters are preserved.
        """
        optimizer_class = type(optim)
        # Filter optim.defaults to valid constructor args only — PyTorch adds
        # internal keys (e.g. decoupled_weight_decay) that are not accepted
        # by __init__.
        valid_init_params = set(
            inspect.signature(optimizer_class.__init__).parameters
        ) - {"self", "params"}
        optimizer_defaults = {
            k: v for k, v in optim.defaults.items() if k in valid_init_params
        }
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        return optimizer_class(trainable_params, **optimizer_defaults)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _forward(
        self, batch: Dict
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one forward pass.

        The whole batch (including ``labels`` when present) is passed to the
        model as keyword arguments. If the model returns no ``loss`` and a
        ``loss_fn`` was configured, the loss is computed as
        ``loss_fn(logits, batch["labels"])`` when both are available.

        Returns:
            (loss, logits) — either may be ``None`` if the model / batch does
            not produce them.
        """
        outputs = self._model(**batch)
        loss: Optional[torch.Tensor] = getattr(outputs, "loss", None)
        logits: Optional[torch.Tensor] = getattr(outputs, "logits", None)

        if loss is None and self._loss_fn is not None and logits is not None:
            labels = batch.get("labels")
            if labels is not None:
                loss = self._loss_fn(logits, labels)

        return loss, logits

    def _train_epoch(self, train_loader: DataLoader) -> float:
        """Run one optimisation pass over ``train_loader``.

        Returns:
            The mean per-batch training loss.

        Raises:
            ValueError: If a batch yields no loss, or the loader is empty.
        """
        self._model.train()
        total_loss = 0.0
        n_batches = 0

        for batch in train_loader:
            batch = self._move_to_device(batch, self._device)
            loss, _ = self._forward(batch)
            if loss is None:
                raise ValueError(
                    "No loss was returned from the model. Ensure that `labels` are "
                    "included in the batch or pass a `loss_fn` to LoRATrainer."
                )

            self._optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self._optimizer.step()

            total_loss += float(loss.item())
            n_batches += 1

        if n_batches == 0:
            raise ValueError("train_loader produced 0 batches; cannot train.")
        return total_loss / n_batches

    def _evaluate(self, val_loader: DataLoader) -> tuple[float, Optional[float]]:
        """Evaluate on ``val_loader`` without gradient tracking.

        Accuracy is computed only for non-generative tasks, only from batches
        that carry integer ``labels`` with the same shape as the argmax
        predictions. Predictions and labels are flattened (so per-token
        outputs of shape ``(batch, seq)`` work) and positions labelled
        ``-100`` — the default ``ignore_index`` of
        ``torch.nn.CrossEntropyLoss`` — are excluded.

        Returns:
            (mean validation loss over batches that produced a loss,
            accuracy or ``None`` when it could not be computed).

        Raises:
            ValueError: If the loader is empty, or no batch produced a loss.
        """
        self._model.eval()
        val_total_loss = 0.0
        n_batches = 0
        n_loss_batches = 0
        all_preds: list = []
        all_labels: list = []

        with torch.no_grad():
            for batch in val_loader:
                batch = self._move_to_device(batch, self._device)
                n_batches += 1
                val_loss, logits = self._forward(batch)

                if val_loss is not None:
                    val_total_loss += float(val_loss.item())
                    n_loss_batches += 1

                if logits is None or self._is_generative:
                    continue
                labels = batch.get("labels")
                # Float labels mean regression; accuracy is meaningless there.
                if labels is None or labels.is_floating_point():
                    continue
                preds = torch.argmax(logits, dim=-1)
                if preds.shape != labels.shape:
                    continue
                keep = labels != -100
                all_preds.extend(preds[keep].cpu().tolist())
                all_labels.extend(labels[keep].cpu().tolist())

        if n_batches == 0:
            raise ValueError("val_loader produced 0 batches; cannot validate.")
        if n_loss_batches == 0:
            raise ValueError(
                "No validation loss was returned from the model. Ensure that "
                "`labels` are included in validation batches or pass a "
                "`loss_fn` to LoRATrainer."
            )

        val_acc = accuracy_score(all_labels, all_preds) if all_preds else None
        return val_total_loss / n_loss_batches, val_acc

    def _report_epoch(
        self, progress: Progress, epoch: int, epochs: int, metrics: str
    ) -> None:
        """Print ``metrics`` on the progress console and log them off-terminal."""
        progress.console.print(
            f"[bold blue]Epoch {epoch}/{epochs}[/bold blue] {metrics}"
        )
        self._logger.display_info(
            f"Epoch {epoch}/{epochs} - {metrics}",
            display_on_terminal=False,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(
        self,
        train_loader: DataLoader,
        epochs: int,
        val_loader: Optional[DataLoader] = None,
        val_epochs: int = 1,
    ):
        """Train the model with optional validation and early stopping.

        DataLoaders must yield dicts with at minimum ``input_ids`` and
        ``attention_mask``. Include ``labels`` to have the model (or
        ``loss_fn``) compute the loss. Training resumes from ``state.epoch``
        of the trainer state.

        Without ``val_loader``, early stopping tracks the training loss.
        With it, validation runs every ``val_epochs`` epochs and on the last
        epoch, and early stopping tracks the validation loss (patience is
        only updated on validation epochs).

        Args:
            train_loader: DataLoader for training data.
            epochs: Total number of epochs to train for.
            val_loader: DataLoader for validation data (optional).
            val_epochs: How often to run validation (in epochs).

        Raises:
            ValueError: If a loader is empty or the model produces no loss.
        """
        state = self._state

        progress = Progress(
            TextColumn(
                "[bold blue]Epoch {task.fields[epoch]}/{task.fields[total_epochs]}"
            ),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
        )

        # Used as a context manager so the live display is stopped even when
        # an epoch raises, instead of leaving the terminal in live mode.
        with progress:
            epoch_task = progress.add_task(
                "Training",
                total=epochs - state.epoch,
                epoch=state.epoch,
                total_epochs=epochs,
                info="",
            )

            for epoch in range(state.epoch + 1, epochs + 1):
                state.epoch = epoch
                state.model_state_dict = None
                state.optimizer_state_dict = None

                # --- TRAIN ---
                state.train_loss = self._train_epoch(train_loader)

                progress.update(
                    epoch_task,  # pyright: ignore[reportArgumentType]
                    advance=1,
                    epoch=epoch,
                    info=f"Train Loss: {state.train_loss:.4f}",
                )

                # --- NO VALIDATION ---
                if val_loader is None:
                    state.val_loss = None
                    self._report_epoch(
                        progress, epoch, epochs, f"Train Loss: {state.train_loss:.4f}"
                    )

                    if state.train_loss < state.best_train_loss:
                        state.best_train_loss = state.train_loss
                        state.patience_counter = 0
                    else:
                        state.patience_counter += 1

                # --- VALIDATION ---
                elif epoch % val_epochs == 0 or epoch == epochs:
                    state.val_loss, val_acc = self._evaluate(val_loader)

                    metrics = (
                        f"Train Loss: {state.train_loss:.4f} | "
                        f"Val Loss: {state.val_loss:.4f}"
                    )
                    if val_acc is not None:
                        state.acc = val_acc
                        metrics += f" | Val Acc: {val_acc:.4f}"
                    self._report_epoch(progress, epoch, epochs, metrics)

                    if state.val_loss < state.best_val_loss:
                        state.best_val_loss = state.val_loss
                        # NOTE(review): state_dict() returns references to the
                        # live tensors; whether state.clone copies them is
                        # defined in the base Trainer — confirm the best
                        # snapshot is not mutated by later optimiser steps.
                        self._best_state = state.clone(
                            model_state_dict=self._model.state_dict(),
                            optimizer_state_dict=self._optimizer.state_dict(),
                        )
                        self._best_model = deepcopy(self._model)
                        self._best_model.load_state_dict(self._best_state.model_state_dict)  # pyright: ignore[reportOptionalMemberAccess, reportArgumentType]
                        if self._checkpoint is not None:
                            self._checkpoint.save(self._best_state, is_best=True)  # pyright: ignore[reportArgumentType]
                        state.patience_counter = 0
                    else:
                        state.patience_counter += 1

                    if state.acc > state.best_acc:
                        state.best_acc = state.acc

                # --- CHECKPOINTING ---
                if self._should_save_checkpoint(epoch, is_last_epoch=(epoch == epochs)):
                    state.model_state_dict = self._model.state_dict()
                    state.optimizer_state_dict = self._optimizer.state_dict()
                    cast(Checkpoint, self._checkpoint).save(state)

                # --- EARLY STOPPING ---
                if (
                    self._early_stopping_patience is not None
                    and state.patience_counter >= self._early_stopping_patience
                ):
                    _reason = (
                        "validation loss" if val_loader is not None else "training loss"
                    )
                    self._logger.display_info(
                        f"Early stopping triggered. No improvement in {_reason} "
                        f"for {self._early_stopping_patience} epochs.",
                        display_on_terminal=False,
                    )
                    break

    def predict(self, dataloader_or_batch: Union[DataLoader, Dict, torch.Tensor]):
        """Generate predictions for a dataloader or a single batch.

        Labels are stripped from dict batches before inference so the model
        does not compute a loss during prediction. Dict batches are passed to
        the model as keyword arguments and must produce an output with a
        ``logits`` attribute. Raw tensors are passed positionally; if the
        model returns a non-tensor output (e.g. a HuggingFace
        ``ModelOutput``), its ``logits`` attribute is used.

        For non-generative tasks, returns a flat list of argmax class indices
        (one entry per sample, or a nested per-token list when logits carry a
        sequence dimension). For generative tasks (``CAUSAL_LM``,
        ``SEQ_2_SEQ_LM``), returns a list of logit tensors (one per batch).

        Args:
            dataloader_or_batch: A DataLoader, a dict batch, or a raw tensor.

        Returns:
            list: Predicted class indices (classification) or logit tensors
            (generative).
        """
        self._model.eval()
        predictions: list = []

        def predict_batch(raw_batch):
            if isinstance(raw_batch, dict):
                raw_batch = self._move_to_device(raw_batch, self._device)
                inference_batch = {k: v for k, v in raw_batch.items() if k != "labels"}
                logits = self._model(**inference_batch).logits
            else:
                raw_batch = self._move_to_device(raw_batch, self._device)
                outputs = self._model(raw_batch)
                # Plain modules may return a tensor directly; HF/peft models
                # return a ModelOutput that exposes ``logits``.
                logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits

            if self._is_generative:
                predictions.append(logits.cpu())
            else:
                preds = torch.argmax(logits, dim=-1)
                predictions.extend(preds.cpu().tolist())

        with torch.no_grad():
            if isinstance(dataloader_or_batch, DataLoader):
                for raw in dataloader_or_batch:
                    predict_batch(raw)
            else:
                predict_batch(dataloader_or_batch)

        return predictions