from copy import deepcopy
from typing import Optional, Union, cast
import torch
import torch.nn
import torch.optim
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
TextColumn,
TimeElapsedColumn,
)
from sklearn.metrics import r2_score
from torch.utils.data import DataLoader
from fenn.logging import Logger
from fenn.nn.utils import Checkpoint
from .trainer import Trainer
[docs]
class RegressionTrainer(Trainer):
"""A trainer for regression tasks with PyTorch models.
Extends the base :class:`Trainer` with regression-specific metrics (R² score,
MSE) and continuous-value prediction logic. Handles single-target regression
with optional validation and early stopping.
Args:
model: The neural network model, expected to output continuous predictions.
loss_fn: Loss function suitable for regression (e.g. MSELoss, HuberLoss).
optim: Optimizer for updating trainable parameters.
return_model: Which model version to return after training. ``'last'``
returns the final checkpoint, ``'best'`` returns the best model
by validation/training loss.
device: Device to run training on (``'cpu'``, ``'cuda'``, or ``'mps'``).
early_stopping_patience: Stop training after this many epochs without
improvement in loss. ``None`` disables.
checkpoint_config: Optional :class:`~fenn.nn.utils.Checkpoint` for
saving training state to disk.
"""
[docs]
def __init__(
self,
model: torch.nn.Module,
loss_fn: torch.nn.Module,
optim: torch.optim.Optimizer,
return_model: str = "last",
device: Union[torch.device, str] = "cpu",
early_stopping_patience: Optional[int] = None,
checkpoint_config: Optional[Checkpoint] = None,
):
"""Initialize a RegressionTrainer instance.
Args:
model: The neural network model to train.
loss_fn: The loss function to use.
optim: The optimizer to use.
return_model: Whether to return the 'last' or 'best' model after training.
device: The device on which the data will be loaded.
early_stopping_patience: The number of epochs to wait before early stopping.
checkpoint_config: The checkpoint configuration. If `None`, checkpointing is disabled.
"""
super().__init__(
model=model,
loss_fn=loss_fn,
optim=optim,
device=device,
early_stopping_patience=early_stopping_patience,
checkpoint_config=checkpoint_config,
)
self._logger = Logger()
self._return_model = return_model.lower()
if self._return_model not in {"last", "best"}:
raise ValueError("return_model must be 'last' or 'best'")
[docs]
def fit(
self,
train_loader: DataLoader,
epochs: int,
val_loader: Optional[DataLoader] = None,
val_epochs: int = 1,
):
"""Train the model with optional validation and early stopping.
The behaviour depends on the combination of `val_loader` and
``early_stopping_patience``:
* No validation loader, no early stopping: run full ``epochs``.
* No validation loader, early stopping set: stop on training loss.
* Validation loader provided, no early stopping: evaluate every epoch
but continue regardless of metrics.
* Validation loader provided and early stopping set: monitor
validation loss and stop when it plateaus.
Args:
train_loader: DataLoader for training data.
epochs: Total number of epochs to train for.
val_loader: DataLoader for validation data (optional).
val_epochs: How often to evaluate on validation set (in epochs).
Returns:
The trained model (returned according to ``return_model``).
"""
state = self._state
progress = Progress(
TextColumn(
"[bold blue]Epoch {task.fields[epoch]}/{task.fields[total_epochs]}"
),
BarColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
)
progress.start()
epoch_task = progress.add_task(
"Training",
total=epochs - state.epoch,
epoch=state.epoch,
total_epochs=epochs,
info="",
)
for epoch in range(state.epoch + 1, epochs + 1):
state.epoch = epoch
state.model_state_dict = None
state.optimizer_state_dict = None
# --- TRAIN ---
self._model.train()
total_loss = 0.0
n_batches = 0
for data, labels in train_loader:
data = self._move_to_device(data, self._device)
labels = labels.to(self._device)
outputs = self._model(data)
loss = self._loss_fn(outputs, labels)
self._optimizer.zero_grad(set_to_none=True)
loss.backward()
self._optimizer.step()
total_loss += float(loss.item())
n_batches += 1
if n_batches == 0:
raise ValueError("train_loader produced 0 batches; cannot train.")
state.train_loss = total_loss / n_batches
progress.update(
epoch_task, # pyright: ignore[reportArgumentType]
advance=1,
epoch=epoch,
info=f"Train Mean Loss : {state.train_loss:.4f}",
)
# --- NO VALIDATION ---
if val_loader is None:
state.val_loss = None
progress.console.print(
f"[bold blue]Epoch {epoch}/{epochs}[/bold blue] Train Loss: {state.train_loss:.4f}"
)
Logger().display_info(
f"Epoch {epoch}/{epochs} - Train Loss: {state.train_loss:.4f}",
display_on_terminal=False,
)
if state.train_loss < state.best_train_loss:
state.best_train_loss = state.train_loss
state.patience_counter = 0
else:
state.patience_counter += 1
# --- VALIDATION ---
elif epoch % val_epochs == 0 or epoch == epochs:
self._model.eval()
val_labels = []
val_predictions = []
val_total_loss = 0.0
val_n_batches = 0
with torch.no_grad():
for data, labels in val_loader:
data = self._move_to_device(data, self._device)
labels = labels.to(self._device)
outputs = self._model(data)
val_batch_loss = self._loss_fn(outputs, labels)
val_total_loss += float(val_batch_loss.item())
val_n_batches += 1
logits = outputs
preds = logits.squeeze(-1)
val_predictions.extend(preds.cpu().tolist())
val_labels.extend(labels.cpu().tolist())
if val_n_batches == 0:
raise ValueError("val_loader produced 0 batches; cannot validate.")
if val_n_batches > 0:
val_mean_loss = val_total_loss / val_n_batches
val_r2 = r2_score(val_labels, val_predictions)
progress.console.print(
f"[bold blue]Epoch {epoch}/{epochs}[/bold blue] Train Loss: {state.train_loss:.4f} | Val Loss: {val_mean_loss:.4f} | Val R2: {val_r2:.4f}"
)
Logger().display_info(
f"Epoch {epoch}/{epochs} - Train Loss: {state.train_loss:.4f} | Val Loss: {val_mean_loss:.4f} | Val R2: {val_r2:.4f}",
display_on_terminal=False,
)
state.val_loss = val_total_loss / val_n_batches
state.acc = r2_score(val_labels, val_predictions)
progress.update(
epoch_task, # pyright: ignore[reportArgumentType]
info=f"Train Mean Loss: {state.train_loss:.4f} | Val Loss: {state.val_loss:.4f} | Val MSE: {state.acc:.4f}",
)
if state.val_loss < state.best_val_loss:
state.best_val_loss = state.val_loss
# Update best state for improved val_loss
self._best_state = state.clone(
model_state_dict=self._model.state_dict(),
optimizer_state_dict=self._optimizer.state_dict(),
)
self._best_model = deepcopy(self._model)
self._best_model.load_state_dict(self._best_state.model_state_dict) # pyright: ignore[reportArgumentType]
if self._checkpoint is not None:
self._checkpoint.save(self._best_state, is_best=True)
state.patience_counter = 0
else:
state.patience_counter += 1
if state.acc > state.best_acc:
state.best_acc = state.acc
# --- CHECKPOINTING ---
if self._should_save_checkpoint(epoch, is_last_epoch=(epoch == epochs)):
state.model_state_dict = self._model.state_dict()
state.optimizer_state_dict = self._optimizer.state_dict()
cast(Checkpoint, self._checkpoint).save(state)
# --- EARLY STOPPING ---
if (
self._early_stopping_patience is not None
and state.patience_counter >= self._early_stopping_patience
):
if val_loader is None:
_reason = "training loss"
else:
_reason = "validation loss"
self._logger.display_info(
f"Early stopping triggered. No improvement in {_reason} for {self._early_stopping_patience} epochs."
)
break
progress.stop()
if self._return_model == "best" and self._best_model is not None:
return self._best_model
return self._model
[docs]
def predict(self, dataloader_or_batch: Union[DataLoader, torch.Tensor]):
"""Predicts the output of the model for a given dataloader or batch.
Args:
dataloader_or_batch: A DataLoader or a torch tensor.
Returns:
list: A list of predictions.
"""
self._model.eval()
predictions = []
def predict_batch(batch):
batch = self._move_to_device(batch, self._device)
logits = self._model(batch)
preds = logits.squeeze(-1)
predictions.extend(preds.cpu().tolist())
with torch.no_grad():
if isinstance(dataloader_or_batch, DataLoader):
for data, _ in dataloader_or_batch:
predict_batch(data)
else:
predict_batch(dataloader_or_batch)
return predictions