ModuleTreeWidget API
Bases: AnyWidget
Interactive tree viewer for PyTorch `nn.Module` architectures.
Displays the full module hierarchy with parameter counts, shapes, trainable/frozen/buffer badges, and a density indicator.
Examples:
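```python
import torch.nn as nn
from wigglystuff import ModuleTreeWidget

# A small two-layer MLP: 784 -> 256 -> 10.
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# Show the tree with the first two levels expanded; in a notebook,
# the last expression in a cell is rendered as the widget.
ModuleTreeWidget(model, initial_expand_depth=2)
```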
Create a ModuleTreeWidget.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `module` | `nn.Module` | A PyTorch `nn.Module` to visualize. | `None` |
| `initial_expand_depth` | `int` | Number of tree levels to expand initially. | `1` |
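How "levels" are counted is not spelled out above. The sketch below assumes the root node sits at level 0 and that a node starts expanded when it has children and its level is below `initial_expand_depth`, so the default of 1 would open only the root and reveal its direct children. The helper `initially_expanded` and the `name`/`children` node layout are hypothetical and not part of wigglystuff.

```python
from typing import Any, Dict, List


def initially_expanded(node: Dict[str, Any], initial_expand_depth: int,
                       level: int = 0, prefix: str = "") -> List[str]:
    """Hypothetical: return dotted paths of nodes that start expanded.

    Assumes each node is a dict with a "name" and an optional "children" list,
    the root is at level 0, and only nodes with children can be expanded.
    """
    path = f"{prefix}.{node['name']}" if prefix else node["name"]
    children = node.get("children", [])
    # Leaves have nothing to expand; nodes at or past the depth start collapsed.
    if not children or level >= initial_expand_depth:
        return []
    expanded = [path]
    for child in children:
        expanded += initially_expanded(child, initial_expand_depth, level + 1, path)
    return expanded
```

Under these assumptions, `initial_expand_depth=2` (as in the example above) would additionally open every direct child of the root that has submodules of its own.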
Source code in `wigglystuff/module_tree.py`
Synced traitlets
| Traitlet | Type | Notes |
|---|---|---|
| `tree` | `dict` | JSON-serializable tree extracted from a PyTorch `nn.Module`. |
| `initial_expand_depth` | `int` | Number of tree levels to expand on first render (default: 1). |
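The schema of the `tree` traitlet is not documented in this section. As a rough illustration of what a JSON-serializable tree extracted from an `nn.Module` can look like, the sketch below walks a module with the standard `named_children()`, `named_parameters(recurse=False)`, and `buffers(recurse=False)` methods and rolls up trainable, frozen, and buffer element counts per node. The function `module_to_tree` and every key in its output are hypothetical; this is not wigglystuff's implementation.

```python
from typing import Any, Dict


def module_to_tree(module: Any, name: str = "model") -> Dict[str, Any]:
    """Hypothetical helper: build a JSON-serializable tree from a torch.nn.Module.

    Only standard nn.Module methods are used; the output keys are an
    illustrative schema, not the one wigglystuff produces.
    """
    # Parameters registered directly on this module (not on its children).
    own_params = list(module.named_parameters(recurse=False))
    shapes = {pname: list(p.shape) for pname, p in own_params}
    trainable = sum(p.numel() for _, p in own_params if p.requires_grad)
    frozen = sum(p.numel() for _, p in own_params if not p.requires_grad)
    buffers = sum(b.numel() for b in module.buffers(recurse=False))

    # Recurse into submodules, then roll their totals up into this node.
    # Tied (shared) parameters are counted under every module that owns them.
    children = [module_to_tree(child, child_name)
                for child_name, child in module.named_children()]
    for child in children:
        trainable += child["trainable"]
        frozen += child["frozen"]
        buffers += child["buffers"]

    # Plain str/int/list/dict values only, so json.dumps can serialize it.
    return {
        "name": name,
        "type": type(module).__name__,
        "shapes": shapes,
        "trainable": trainable,
        "frozen": frozen,
        "buffers": buffers,
        "children": children,
    }
```

For the `nn.Sequential` example above, this sketch would produce children named `"0"`, `"1"`, and `"2"` and a root `trainable` count of 203,530 (200,960 for `Linear(784, 256)` plus 2,570 for `Linear(256, 10)`).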