Source code for fenn.nn.utils.model_pretty_printer

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch.nn as nn


@dataclass(frozen=True)
class _RenderLimits:
    """Immutable bundle of limits applied while rendering a module tree.

    A field set to ``None`` disables that particular limit.
    """

    # Deepest nesting level whose modules are still listed.
    max_depth: Optional[int]
    # Largest number of children listed under a single module.
    max_children: Optional[int]
    # Line count at which the renderer stops adding further lines.
    max_lines: Optional[int]


class ModelPrettyPrinter:
    """Render a human-readable model summary for logs.

    Produces a tree-style architecture summary with parameter counts. Small
    models (module count ≤ ``small_model_threshold``) are printed in full;
    larger models are compacted to avoid overwhelming the log output.

    Args:
        model: The PyTorch module to summarise.
        small_model_threshold: Module count at or below which the full
            architecture is printed with no depth, child, or line limits.
        compact_max_depth: Maximum nesting depth shown for large models.
        compact_max_children: Maximum number of children shown per module
            for large models.
        compact_max_lines: Line budget for the rendered summary of large
            models. No module or marker line is added once the summary has
            reached this many lines; a single trailing truncation notice
            may follow.

    Example:
        >>> printer = ModelPrettyPrinter(my_model)
        >>> print(printer.render())
    """

    # Longest ``extra_repr`` text (ellipsis included) kept in a header line.
    _EXTRA_REPR_MAX_CHARS = 90

    def __init__(
        self,
        model: nn.Module,
        *,
        small_model_threshold: int = 25,
        compact_max_depth: int = 3,
        compact_max_children: int = 8,
        compact_max_lines: int = 80,
    ) -> None:
        self._model = model
        self._small_model_threshold = small_model_threshold
        self._compact_limits = _RenderLimits(
            max_depth=compact_max_depth,
            max_children=compact_max_children,
            max_lines=compact_max_lines,
        )

    def render(self) -> str:
        """Build and return the formatted model summary string.

        Returns:
            A multi-line string containing the model class name, parameter
            counts, and a tree view of the module hierarchy. When the line
            budget is reached and modules are missing from the tree, a final
            ``... output truncated (N modules omitted)`` line is appended,
            where ``N`` is the total module count minus the module lines
            actually rendered (root included).
        """
        module_count = sum(1 for _ in self._model.modules())

        # Single pass over the parameters for both totals.
        total_params = 0
        trainable_params = 0
        for param in self._model.parameters():
            size = param.numel()
            total_params += size
            if param.requires_grad:
                trainable_params += size
        frozen_params = total_params - trainable_params

        limits = self._select_limits(module_count)
        lines = [
            "Model Summary",
            "-------------",
            f"Class: {self._model.__class__.__name__}",
            f"Modules: {module_count}",
            (
                "Parameters: "
                f"total={total_params:,}, "
                f"trainable={trainable_params:,}, "
                f"frozen={frozen_params:,}"
            ),
            "",
            "Architecture:",
            self._format_module_header(self._model, root=True),
        ]

        # The root header above is one rendered module; the helper reports
        # how many descendant module lines it added.
        rendered_modules = 1 + self._append_children(
            lines=lines,
            children=list(self._model.named_children()),
            prefix="",
            depth=1,
            limits=limits,
        )

        if self._budget_exhausted(lines, limits):
            # A module instance registered under several parents is listed
            # once per parent in the tree but counted once by ``modules()``,
            # so the difference can dip below zero; only report a positive
            # count.
            omitted = module_count - rendered_modules
            if omitted > 0:
                lines.append(f"... output truncated ({omitted} modules omitted)")

        return "\n".join(lines)

    def __str__(self) -> str:
        """Return the model summary string (delegates to :meth:`render`)."""
        return self.render()

    def _select_limits(self, module_count: int) -> _RenderLimits:
        """Choose render limits based on model size.

        Args:
            module_count: Total number of modules in the model.

        Returns:
            A :class:`_RenderLimits` with ``None`` fields (no limits) for
            small models, or the compact limits for larger ones.
        """
        if module_count <= self._small_model_threshold:
            return _RenderLimits(max_depth=None, max_children=None, max_lines=None)
        return self._compact_limits

    @staticmethod
    def _budget_exhausted(lines: list[str], limits: _RenderLimits) -> bool:
        """Return ``True`` once ``lines`` has reached the ``max_lines`` budget."""
        return limits.max_lines is not None and len(lines) >= limits.max_lines

    def _append_children(
        self,
        *,
        lines: list[str],
        children: list[tuple[str, nn.Module]],
        prefix: str,
        depth: int,
        limits: _RenderLimits,
    ) -> int:
        """Recursively append child module lines to the summary.

        Args:
            lines: Accumulator list of rendered lines (mutated in place).
            children: Named children of the current module.
            prefix: Indentation string for the current nesting level.
            depth: Current recursion depth (1 for the root's children).
            limits: Active render limits controlling depth, children, and
                line count.

        Returns:
            The number of module lines appended by this call, including
            those added by nested calls. Marker lines (``... N more
            modules`` and ``... N nested modules omitted``) are not counted.
        """
        if not children:
            return 0

        # Checked up front so the depth marker below obeys the same line
        # budget as module lines and the "more modules" marker.
        if self._budget_exhausted(lines, limits):
            return 0

        if limits.max_depth is not None and depth > limits.max_depth:
            lines.append(f"{prefix}... {len(children)} nested modules omitted")
            return 0

        display_children = children
        if limits.max_children is not None:
            display_children = children[: limits.max_children]
        hidden_children = len(children) - len(display_children)

        rendered = 0
        next_prefix = f"{prefix} "
        for name, module in display_children:
            if self._budget_exhausted(lines, limits):
                return rendered
            lines.append(f"{prefix}{name}: {self._format_module_header(module)}")
            rendered += 1
            rendered += self._append_children(
                lines=lines,
                children=list(module.named_children()),
                prefix=next_prefix,
                depth=depth + 1,
                limits=limits,
            )

        if hidden_children > 0 and not self._budget_exhausted(lines, limits):
            lines.append(f"{prefix}... {hidden_children} more modules")
        return rendered

    def _format_module_header(self, module: nn.Module, *, root: bool = False) -> str:
        """Format a single module's summary line.

        Includes the class name and, where non-zero, the direct parameter
        count, ``extra_repr`` string, and child count.

        Args:
            module: The module to format.
            root: If ``True``, the child count is left out of the header.

        Returns:
            A formatted string, e.g.
            ``"Linear (params=9, in_features=2, out_features=3, bias=True)"``.
        """
        class_name = module.__class__.__name__
        details = []

        # Only parameters registered directly on this module, not on children.
        direct_params = sum(param.numel() for param in module.parameters(recurse=False))
        if direct_params:
            details.append(f"params={direct_params:,}")

        extra = self._normalize_extra_repr(module.extra_repr())
        if extra:
            details.append(extra)

        child_count = sum(1 for _ in module.children())
        if child_count and not root:
            details.append(f"children={child_count}")

        if not details:
            return class_name
        return f"{class_name} ({', '.join(details)})"

    @staticmethod
    def _normalize_extra_repr(extra_repr: str) -> str:
        """Collapse and truncate a module's ``extra_repr`` to a single line.

        Joins multi-line repr strings into one line and truncates to 90
        characters (ellipsis included) if longer.

        Args:
            extra_repr: Raw string from ``module.extra_repr()``.

        Returns:
            A collapsed single-line string, or an empty string if the input
            was empty.
        """
        if not extra_repr:
            return ""
        collapsed = " ".join(
            part.strip() for part in extra_repr.splitlines() if part.strip()
        )
        limit = ModelPrettyPrinter._EXTRA_REPR_MAX_CHARS
        if len(collapsed) <= limit:
            return collapsed
        # Reserve three characters for the ellipsis.
        return f"{collapsed[: limit - 3]}..."