Source code for fenn.nn.utils.state
from dataclasses import asdict, dataclass, replace
from typing import Any, Optional, TypeAlias
# Type alias for the string-keyed dictionaries stored in
# TrainingState.model_state_dict and TrainingState.optimizer_state_dict.
# NOTE(review): presumably the output of a framework `state_dict()` call
# (e.g. PyTorch) — confirm against callers.
StateDict: TypeAlias = dict[str, Any]
[docs]
@dataclass
class TrainingState:
    """Training state for a neural network model.

    Bundles the per-epoch metrics, the running "best" metrics and patience
    counter used for early stopping, and optionally the model and optimizer
    state dictionaries. Instances round-trip through :meth:`to_dict` and
    :meth:`from_dict`.
    """

    # NOTE: could be migrated to pydantic.BaseModel instead of a dataclass.

    epoch: int
    """Epoch this state refers to."""
    acc: Optional[float] = None
    """Accuracy on the validation set (if provided)"""
    train_loss: Optional[float] = None
    """Train mean loss over all batches"""
    val_loss: Optional[float] = None
    """Validation mean loss over all batches"""
    model_state_dict: Optional[StateDict] = None
    """Model state dictionary; must be set before calling ``to_dict``."""
    optimizer_state_dict: Optional[StateDict] = None
    """Optimizer state dictionary; must be set before calling ``to_dict``."""
    patience_counter: int = 0
    """Patience counter up to this epoch for early stopping."""
    best_acc: float = float("-inf")
    """Best validation accuracy achieved up to this epoch"""
    best_train_loss: float = float("inf")
    """Best train loss achieved up to this epoch"""
    best_val_loss: float = float("inf")
    """Best validation loss achieved up to this epoch"""

    def to_dict(self) -> dict[str, Any]:
        """Serialize the training state to a dictionary.

        Conversion goes through ``dataclasses.asdict``. That function rebuilds
        nested dicts/lists/tuples and ``copy.deepcopy``s every other value.
        The result is therefore an independent snapshot that later mutation
        of this instance (or of the objects it references) does not affect.

        Returns:
            A dict mapping each field name to its value, in field declaration
            order.

        Raises:
            ValueError: If ``model_state_dict`` or ``optimizer_state_dict``
                is ``None``.
        """
        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let an incomplete state be serialized.
        if self.model_state_dict is None:
            raise ValueError(
                "Cannot serialize TrainingState: model_state_dict is None"
            )
        if self.optimizer_state_dict is None:
            raise ValueError(
                "Cannot serialize TrainingState: optimizer_state_dict is None"
            )
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TrainingState":
        """Deserialize a dictionary to a TrainingState instance.

        Args:
            data: The serialized training state, e.g. the output of
                :meth:`to_dict`. Keys must be field names of this class.

        Returns:
            A new TrainingState instance.

        Raises:
            TypeError: If ``data`` has a key that is not a field, or lacks
                the required ``epoch`` key.
        """
        return cls(**data)

    def clone(self, **kwargs: Any) -> "TrainingState":
        """Clone the training state with optional updated fields.

        This is a shallow copy via ``dataclasses.replace``: fields that are
        not overridden share the same objects (e.g. the state dicts) with
        this instance.

        Args:
            **kwargs: Field values to override in the clone.

        Returns:
            A new TrainingState instance.

        Raises:
            TypeError: If a keyword is not a field of this class.
        """
        return replace(self, **kwargs)