fenn.nn.utils

class fenn.nn.utils.Checkpoint(*, name='checkpoint', dir, epochs=None, save_best=True)[source]

Bases: object

Checkpoint the training state at given epochs and/or whenever the model is the best seen so far.

Saves full TrainingState snapshots (model weights, optimizer state, epoch counter, metrics) during training so that training can be resumed or the best model restored later.

Parameters:
  • name (str) – Base filename for checkpoint files (without extension).

  • dir (Path | str) – Directory to save checkpoint files in.

  • epochs (int | List[int] | None) – When to save checkpoints: an int saves every N epochs, a List[int] saves at those specific epochs, and None saves only the best model.

  • save_best (bool) – If True, save the best model seen so far, updated whenever the validation or training loss improves.
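The epochs rule described above can be sketched as a small predicate. This is only an illustration of the documented behaviour, not fenn's implementation: the helper name should_save is hypothetical, and the sketch assumes that epochs are counted from 1 and that an int N means "every epoch divisible by N".

```python
from typing import List, Optional, Union


def should_save(epoch: int, epochs: Optional[Union[int, List[int]]]) -> bool:
    """Hypothetical helper: is a periodic checkpoint due at this epoch?

    Assumes 1-based epoch numbers; fenn's own counting may differ.
    """
    if epochs is None:
        # No periodic checkpoints; only the best model would be saved.
        return False
    if isinstance(epochs, int):
        # Every N epochs: 5 -> epochs 5, 10, 15, ...
        return epochs > 0 and epoch % epochs == 0
    # Explicit list of epochs.
    return epoch in epochs


# Epochs in 1..12 at which a checkpoint would be written for epochs=5.
due = [e for e in range(1, 13) if should_save(e, 5)]
```

Under these assumptions, epochs=5 over a 12-epoch run yields periodic checkpoints at epochs 5 and 10, while epochs=None yields none.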

Example

>>> checkpoint = Checkpoint(dir="checkpoints/", epochs=5, save_best=True)
>>> trainer = Trainer(model, loss_fn, optimizer, checkpoint_config=checkpoint)

__init__(*, name='checkpoint', dir, epochs=None, save_best=True)[source]

Initialize the checkpoint configuration.

Parameters:
  • name (str) – Base filename for checkpoint files (without extension).

  • dir (Path | str) – The directory to save checkpoints to.

  • epochs (int | List[int] | None) – The epochs at which to save checkpoints.

  • save_best (bool) – Whether to checkpoint the best model (based on validation or training loss).

load(checkpoint_path, device=None)[source]

Load a checkpoint from the given path.

Parameters:
  • checkpoint_path (str | Path) – Path to the checkpoint file.

  • device (device | None) – The device to load the checkpoint onto.

Returns:

The training state of the checkpoint.

Return type:

TrainingState

load_at_epoch(epoch, device=None)[source]

Load the checkpoint at the given epoch.

Parameters:
  • epoch (int) – Epoch whose checkpoint should be loaded.

  • device (device | None) – The device to load the checkpoint onto.

Returns:

The training state of the checkpoint.

Return type:

TrainingState

load_best(device=None)[source]

Load the best checkpoint.

Parameters:

device (device | None) – The device to load the checkpoint onto.

Returns:

The training state of the checkpoint.

Return type:

TrainingState

save(state, is_best=False)[source]

Save a checkpoint of the training state at the current epoch.

Parameters:
  • state (TrainingState) – The training state to checkpoint.

  • is_best (bool) – If True, also save the state as the best model.

Return type:

None
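When save is called directly, the caller has to decide the value of is_best. A minimal sketch of one way to do that follows, assuming the validation loss is preferred and the training loss is used only when no validation loss is available. The helper names monitored_loss and is_new_best are hypothetical and not part of fenn.

```python
import math
from typing import Optional


def monitored_loss(train_loss: Optional[float], val_loss: Optional[float]) -> Optional[float]:
    """Prefer the validation loss; fall back to the training loss (assumption)."""
    return val_loss if val_loss is not None else train_loss


def is_new_best(current: Optional[float], best_so_far: float) -> bool:
    """True when the monitored loss strictly improves on the best seen so far."""
    return current is not None and current < best_so_far


best = math.inf
flags = []
# Made-up (train_loss, val_loss) pairs for three epochs.
for train_loss, val_loss in [(0.9, 0.8), (0.7, 0.85), (0.6, 0.75)]:
    loss = monitored_loss(train_loss, val_loss)
    improved = is_new_best(loss, best)
    if improved:
        best = loss
    flags.append(improved)  # the value one would pass as is_best
```

In this run the first and third epochs improve the validation loss, so only those two would be saved as the best model.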

class fenn.nn.utils.ModelPrettyPrinter(model, *, small_model_threshold=25, compact_max_depth=3, compact_max_children=8, compact_max_lines=80)[source]

Bases: object

Render a human-readable model summary for logs.

Produces a tree-style architecture summary with parameter counts. Small models (module count ≤ small_model_threshold) are printed in full; larger models are compacted to avoid overwhelming the log output.

Parameters:
  • model (nn.Module) – The PyTorch module to summarise.

  • small_model_threshold (int) – Module count at or below which the full architecture is printed, with no depth or child limits.

  • compact_max_depth (int) – Maximum nesting depth shown for large models.

  • compact_max_children (int) – Maximum number of children shown per module for large models.

  • compact_max_lines (int) – Maximum total lines in the rendered summary for large models.

Example

>>> printer = ModelPrettyPrinter(my_model)
>>> print(printer.render())

__init__(model, *, small_model_threshold=25, compact_max_depth=3, compact_max_children=8, compact_max_lines=80)[source]

Parameters:
  • model (Module)

  • small_model_threshold (int)

  • compact_max_depth (int)

  • compact_max_children (int)

  • compact_max_lines (int)

Return type:

None

render()[source]

Build and return the formatted model summary string.

Returns:

A multi-line string containing the model class name, parameter counts, and a tree view of the module hierarchy.

Return type:

str
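To make the interplay of the four limits concrete, here is a rough sketch of how a full versus compact tree summary could be produced. It is not ModelPrettyPrinter's actual algorithm or output format: the module tree is a plain stand-in data structure rather than a torch model, and the names count_nodes, render_lines and summarise are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

# Stand-in for a module tree: (parameter count, {child name: subtree}).
# A plain data structure is used here instead of a real torch model.
Node = Tuple[int, Dict[str, tuple]]


def count_nodes(node: Node) -> int:
    """Count this node plus all of its descendants."""
    return 1 + sum(count_nodes(child) for child in node[1].values())


def render_lines(name: str, node: Node, depth: int = 0,
                 max_depth: Optional[int] = None,
                 max_children: Optional[int] = None) -> List[str]:
    """Render one node and its subtree, honouring the optional limits."""
    params, children = node
    indent = "  " * depth
    lines = [f"{indent}{name} ({params:,} params)"]
    if not children:
        return lines
    if max_depth is not None and depth >= max_depth:
        # Depth limit reached: summarise the hidden subtree in one line.
        lines.append(f"{indent}  ... ({len(children)} children hidden)")
        return lines
    items = list(children.items())
    shown = items if max_children is None else items[:max_children]
    for child_name, child in shown:
        lines.extend(render_lines(child_name, child, depth + 1, max_depth, max_children))
    if len(shown) < len(items):
        lines.append(f"{indent}  ... ({len(items) - len(shown)} more)")
    return lines


def summarise(name: str, node: Node, *, small_model_threshold: int = 25,
              compact_max_depth: int = 3, compact_max_children: int = 8,
              compact_max_lines: int = 80) -> str:
    """Print small trees in full; apply the compact limits to larger ones."""
    if count_nodes(node) <= small_model_threshold:
        lines = render_lines(name, node)
    else:
        lines = render_lines(name, node, 0, compact_max_depth, compact_max_children)
        if len(lines) > compact_max_lines:
            lines = lines[:compact_max_lines - 1] + ["... (output truncated)"]
    return "\n".join(lines)


encoder = (600, {f"layer{i}": (100, {}) for i in range(6)})
model = (1000, {"encoder": encoder, "head": (400, {})})

full = summarise("model", model)  # 9 nodes, under the default threshold
compact = summarise("model", model, small_model_threshold=5, compact_max_children=2)
```

With the default threshold the nine-node stand-in tree is printed in full; lowering the threshold to 5 and the child limit to 2 collapses four of the encoder layers into a single "... (4 more)" line.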

class fenn.nn.utils.TrainingState(epoch, acc=None, train_loss=None, val_loss=None, model_state_dict=None, optimizer_state_dict=None, patience_counter=0, best_acc=-inf, best_train_loss=inf, best_val_loss=inf)[source]

Bases: object

Training state for a neural network model.

Parameters:
  • epoch (int)

  • acc (float | None)

  • train_loss (float | None)

  • val_loss (float | None)

  • model_state_dict (dict[str, Any] | None)

  • optimizer_state_dict (dict[str, Any] | None)

  • patience_counter (int)

  • best_acc (float)

  • best_train_loss (float)

  • best_val_loss (float)

__init__(epoch, acc=None, train_loss=None, val_loss=None, model_state_dict=None, optimizer_state_dict=None, patience_counter=0, best_acc=-inf, best_train_loss=inf, best_val_loss=inf)

Parameters:
  • epoch (int)

  • acc (float | None)

  • train_loss (float | None)

  • val_loss (float | None)

  • model_state_dict (dict[str, Any] | None)

  • optimizer_state_dict (dict[str, Any] | None)

  • patience_counter (int)

  • best_acc (float)

  • best_train_loss (float)

  • best_val_loss (float)

Return type:

None

acc: float | None = None

Accuracy on the validation set (if provided)

best_acc: float = -inf

Best validation accuracy achieved up to this epoch

best_train_loss: float = inf

Best training loss achieved up to this epoch

best_val_loss: float = inf

Best validation loss achieved up to this epoch

clone(**kwargs)[source]

Clone the training state with optional updated fields.

Returns:

A new TrainingState instance.
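The copy-with-overrides pattern behind clone can be sketched with the standard library. The State class below is a hypothetical stand-in carrying only a few of TrainingState's fields, and using dataclasses.replace is an assumption about how such a method might be written, not a statement about fenn's source.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class State:  # hypothetical stand-in for TrainingState
    epoch: int
    val_loss: Optional[float] = None
    patience_counter: int = 0

    def clone(self, **kwargs) -> "State":
        # Copy every field, overriding only those passed as keyword arguments.
        return replace(self, **kwargs)


first = State(epoch=1, val_loss=0.8)
second = first.clone(epoch=2, val_loss=0.7)  # first is left unchanged
```

Note that dataclasses.replace makes a shallow copy, so mutable fields such as state dicts would be shared between the two instances in this sketch; whether fenn copies them deeply is not stated here.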

epoch: int

classmethod from_dict(data)[source]

Deserialize a dictionary to a TrainingState instance.

Parameters:

data (dict[str, Any]) – The serialized training state.

Returns:

A new TrainingState instance.

model_state_dict: dict[str, Any] | None = None

optimizer_state_dict: dict[str, Any] | None = None

patience_counter: int = 0

Early-stopping patience counter as of this epoch.
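For readers unfamiliar with patience-based early stopping, the usual bookkeeping is sketched below: the counter resets when the validation loss improves and increments otherwise, and training stops once it reaches a limit. The function update_patience, the patience limit and the loss values are made up for illustration, and fenn's trainer may update the counter differently.

```python
from typing import Tuple


def update_patience(val_loss: float, best_val_loss: float,
                    patience_counter: int) -> Tuple[float, int]:
    """Return the updated (best_val_loss, patience_counter) pair."""
    if val_loss < best_val_loss:
        return val_loss, 0  # improvement: record it and reset the counter
    return best_val_loss, patience_counter + 1


patience = 2  # hypothetical limit on epochs without improvement
best, counter = float("inf"), 0
stopped_at = None
for epoch, val_loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.69], start=1):
    best, counter = update_patience(val_loss, best, counter)
    if counter >= patience:
        stopped_at = epoch
        break
```

Here the loss stops improving after epoch 2, so the counter reaches the limit at epoch 4 and the final value of 0.69 is never evaluated.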

to_dict()[source]

Serialize the training state to a dictionary.
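A to_dict / from_dict pair is expected to round-trip. The sketch below shows that property on a hypothetical stand-in dataclass with a subset of the fields; the use of dataclasses.asdict is an assumption and the real dictionary layout may differ.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class State:  # hypothetical stand-in for TrainingState
    epoch: int
    train_loss: Optional[float] = None
    val_loss: Optional[float] = None
    best_val_loss: float = float("inf")

    def to_dict(self) -> Dict[str, Any]:
        # One key per field, suitable for writing to a checkpoint file.
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "State":
        # Rebuild an instance from the serialized fields.
        return cls(**data)


state = State(epoch=3, train_loss=0.42, val_loss=0.5, best_val_loss=0.5)
payload = state.to_dict()
restored = State.from_dict(payload)
```

In this sketch, restored compares equal to state, which is the property a checkpoint loader relies on when resuming training.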

train_loss: float | None = None

Mean training loss over all batches

val_loss: float | None = None

Mean validation loss over all batches