fenn.nn.utils
- class fenn.nn.utils.Checkpoint(*, name='checkpoint', dir, epochs=None, save_best=True)
Bases: object
Checkpoint the training state at given epochs and, optionally, whenever a new best model is found.
Saves full TrainingState snapshots (model weights, optimizer state, epoch counter, metrics) during training so that training can be resumed or the best model restored later.
- Parameters:
name (str) – Base filename for checkpoint files (without extension).
dir (Path | str) – Directory to save checkpoint files in.
epochs (int | List[int] | None) – When to save checkpoints: an int saves every N epochs, a list[int] saves at specific epochs, and None saves only the best model.
save_best (bool) – If True, save the best model seen so far, updated whenever validation/training loss improves.
Example
>>> checkpoint = Checkpoint(dir="checkpoints/", epochs=5, save_best=True)
>>> trainer = Trainer(model, loss_fn, optimizer, checkpoint_config=checkpoint)
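The `epochs` rule described above can be sketched as a small predicate. This is an illustrative stand-in, not fenn's implementation: the helper name `should_save_at` is hypothetical, and the sketch assumes 1-based epoch numbers and that an int value N means "every epoch divisible by N".

```python
from typing import List, Optional, Union


def should_save_at(epoch: int, epochs: Optional[Union[int, List[int]]]) -> bool:
    """Hypothetical helper: decide whether a periodic checkpoint is due.

    Assumes 1-based epoch numbers; fenn's actual convention may differ.
    """
    if epochs is None:
        # No periodic checkpoints; only the best model would be saved.
        return False
    if isinstance(epochs, int):
        # Save every N epochs (assumed to mean epochs N, 2N, 3N, ...).
        return epochs > 0 and epoch % epochs == 0
    # Otherwise treat `epochs` as an explicit list of epoch numbers.
    return epoch in epochs


due = [e for e in range(1, 11) if should_save_at(e, 5)]
print(due)  # [5, 10]
```

Under these assumptions, `epochs=5` over a ten-epoch run would produce periodic checkpoints at epochs 5 and 10, while `epochs=[1, 3]` would produce them only at epochs 1 and 3.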
- __init__(*, name='checkpoint', dir, epochs=None, save_best=True)
Initialize the checkpoint configuration.
- Parameters:
name (str) – The name of the checkpoint file.
dir (Path | str) – The directory to save checkpoints to.
epochs (int | List[int] | None) – The epochs at which to save checkpoints.
save_best (bool) – Whether to checkpoint the best model (based on validation or training loss).
- load(checkpoint_path, device=None)
Load a checkpoint from the given path.
- Parameters:
checkpoint_path (str | Path) – Path to the checkpoint file.
device (device | None) – The device to load the checkpoint onto.
- Returns:
The training state of the checkpoint.
- Return type:
TrainingState
- load_at_epoch(epoch, device=None)
Load the checkpoint at the given epoch.
- Parameters:
epoch (int) – Epoch to load the checkpoint at.
device (device | None) – The device to load the checkpoint onto.
- Returns:
The training state of the checkpoint.
- Return type:
TrainingState
- load_best(device=None)
Load the best checkpoint.
- Parameters:
device (device | None) – The device to load the checkpoint onto.
- Returns:
The training state of the checkpoint.
- Return type:
TrainingState
- save(state, is_best=False)
Save a checkpoint of the training state at the current epoch.
- Parameters:
state (TrainingState) – The training state to checkpoint.
is_best (bool) – If True, save the state as the best model.
- Return type:
None
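How a caller might decide the value of `is_best` can be sketched as follows. This is a minimal illustration, not fenn's logic: `LossSnapshot` is a hypothetical stand-in carrying only the loss fields of TrainingState, and the sketch assumes validation loss is preferred when available, with training loss as the fallback.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LossSnapshot:
    """Hypothetical stand-in with only the loss fields of TrainingState."""
    train_loss: Optional[float] = None
    val_loss: Optional[float] = None
    best_train_loss: float = float("inf")
    best_val_loss: float = float("inf")


def improved(s: LossSnapshot) -> bool:
    """Return True when the monitored loss beats the best seen so far.

    Assumption: validation loss takes priority; training loss is used
    only when no validation loss was recorded.
    """
    if s.val_loss is not None:
        return s.val_loss < s.best_val_loss
    if s.train_loss is not None:
        return s.train_loss < s.best_train_loss
    # Nothing to compare, so this cannot count as an improvement.
    return False


print(improved(LossSnapshot(val_loss=0.4, best_val_loss=0.5)))  # True
```

A caller could then pass the result through the documented signature, for example `checkpoint.save(state, is_best=improved(state))`; whether Trainer performs an equivalent check internally is not stated here.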
- class fenn.nn.utils.ModelPrettyPrinter(model, *, small_model_threshold=25, compact_max_depth=3, compact_max_children=8, compact_max_lines=80)
Bases: object
Render a human-readable model summary for logs.
Produces a tree-style architecture summary with parameter counts. Small models (module count ≤ small_model_threshold) are printed in full; larger models are compacted to avoid overwhelming the log output.
- Parameters:
model (nn.Module) – The PyTorch module to summarise.
small_model_threshold (int) – Module count at or below which the full architecture is printed with no depth or child limits.
compact_max_depth (int) – Maximum nesting depth shown for large models.
compact_max_children (int) – Maximum number of children shown per module for large models.
compact_max_lines (int) – Maximum total lines in the rendered summary for large models.
Example
>>> printer = ModelPrettyPrinter(my_model)
>>> print(printer.render())
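The compaction idea (limiting depth and children per node) can be illustrated on a plain nested dict. This is a simplified sketch, not the ModelPrettyPrinter implementation: the `render` function and the dict-based tree are hypothetical, and parameter counts and the total line cap are left out.

```python
def render(tree: dict, max_depth: int = 3, max_children: int = 8,
           depth: int = 0) -> list:
    """Render a nested dict of {name: children} as indented lines.

    Levels at or beyond `max_depth` are collapsed to "...", and children
    past `max_children` are summarised in a single line.
    """
    lines = []
    items = list(tree.items())
    for name, children in items[:max_children]:
        lines.append("  " * depth + name)
        if children:
            if depth + 1 < max_depth:
                lines.extend(render(children, max_depth, max_children, depth + 1))
            else:
                # Too deep: show a placeholder instead of the subtree.
                lines.append("  " * (depth + 1) + "...")
    hidden = len(items) - max_children
    if hidden > 0:
        lines.append("  " * depth + f"... ({hidden} more)")
    return lines


# Hypothetical module hierarchy, expressed as nested dicts.
model = {
    "encoder": {"layer0": {"attn": {}, "mlp": {}}, "layer1": {}, "layer2": {}},
    "head": {},
}
print("\n".join(render(model, max_depth=2, max_children=2)))
```

With `max_depth=2` and `max_children=2`, the contents of `layer0` collapse to a placeholder and `layer2` is folded into a "1 more" line, which is the kind of trimming the compact parameters describe.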
- class fenn.nn.utils.TrainingState(epoch, acc=None, train_loss=None, val_loss=None, model_state_dict=None, optimizer_state_dict=None, patience_counter=0, best_acc=-inf, best_train_loss=inf, best_val_loss=inf)
Bases: object
Training state for a neural network model.
- Parameters:
epoch (int)
acc (float | None)
train_loss (float | None)
val_loss (float | None)
model_state_dict (dict[str, Any] | None)
optimizer_state_dict (dict[str, Any] | None)
patience_counter (int)
best_acc (float)
best_train_loss (float)
best_val_loss (float)
- __init__(epoch, acc=None, train_loss=None, val_loss=None, model_state_dict=None, optimizer_state_dict=None, patience_counter=0, best_acc=-inf, best_train_loss=inf, best_val_loss=inf)
- Parameters:
epoch (int)
acc (float | None)
train_loss (float | None)
val_loss (float | None)
model_state_dict (dict[str, Any] | None)
optimizer_state_dict (dict[str, Any] | None)
patience_counter (int)
best_acc (float)
best_train_loss (float)
best_val_loss (float)
- Return type:
None
- acc: float | None = None
Accuracy on the validation set (if provided)
- best_acc: float = -inf
Best validation accuracy achieved up to this epoch
- best_train_loss: float = inf
Best train loss achieved up to this epoch
- best_val_loss: float = inf
Best validation loss achieved up to this epoch
- clone(**kwargs)
Clone the training state with optional updated fields.
- Returns:
A new TrainingState instance.
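The copy-with-overrides behaviour of clone can be sketched with a standard dataclass. This is an assumption-laden illustration: `StateSketch` is a hypothetical stand-in with a subset of TrainingState's fields, and using `dataclasses.replace` is one possible way to implement it, not necessarily what fenn does.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass
class StateSketch:
    """Hypothetical stand-in with a subset of TrainingState's fields."""
    epoch: int
    val_loss: Optional[float] = None
    best_val_loss: float = float("inf")
    patience_counter: int = 0

    def clone(self, **kwargs: Any) -> "StateSketch":
        # Copy every field, overriding those passed as keyword arguments.
        return replace(self, **kwargs)


s0 = StateSketch(epoch=3, val_loss=0.42)
s1 = s0.clone(epoch=4, patience_counter=1)
print(s0.epoch, s1.epoch, s1.val_loss)  # 3 4 0.42
```

Note that `dataclasses.replace` makes a shallow copy, so in this sketch any dict-valued fields would be shared between the original and the clone.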
- epoch: int
- classmethod from_dict(data)
Deserialize a dictionary to a TrainingState instance.
- Parameters:
data (dict[str, Any]) – The serialized training state.
- Returns:
A new TrainingState instance.
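A dictionary round trip of this kind can be sketched with a dataclass. The class `SerializableState` is a hypothetical stand-in, and ignoring unknown keys is an assumption of this sketch; how TrainingState.from_dict treats extra or missing keys is not documented here.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class SerializableState:
    """Hypothetical stand-in with a subset of TrainingState's fields."""
    epoch: int
    train_loss: Optional[float] = None
    val_loss: Optional[float] = None
    patience_counter: int = 0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SerializableState":
        # Keep only keys that match declared fields (assumed behaviour).
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})


original = SerializableState(epoch=7, val_loss=0.3)
restored = SerializableState.from_dict(asdict(original))
print(restored == original)  # True
```

Fields absent from the dictionary fall back to their dataclass defaults in this sketch, which mirrors the default values listed in the TrainingState signature.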
- model_state_dict: dict[str, Any] | None = None
- optimizer_state_dict: dict[str, Any] | None = None
- patience_counter: int = 0
Early-stopping patience counter up to this epoch.
- train_loss: float | None = None
Mean training loss over all batches
- val_loss: float | None = None
Mean validation loss over all batches