ToTf 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytorch/README.md +200 -0
- pytorch/__init__.py +28 -0
- pytorch/modelview.py +347 -0
- pytorch/smartsummary.py +1054 -0
- pytorch/trainingmonitor.py +87 -0
- pytorch/utils.py +431 -0
- tenf/README.md +161 -0
- tenf/__init__.py +28 -0
- tenf/computation_nodes.py +176 -0
- tenf/modelview.py +989 -0
- tenf/smartsummary.py +846 -0
- tenf/utils.py +451 -0
- totf-0.1.5.dist-info/METADATA +1438 -0
- totf-0.1.5.dist-info/RECORD +17 -0
- totf-0.1.5.dist-info/WHEEL +5 -0
- totf-0.1.5.dist-info/licenses/LICENSE +201 -0
- totf-0.1.5.dist-info/top_level.txt +2 -0
pytorch/README.md
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# PyTorch Module for ToTf
|
|
2
|
+
|
|
3
|
+
This module contains PyTorch-specific implementations.
|
|
4
|
+
|
|
5
|
+
## Modules
|
|
6
|
+
|
|
7
|
+
### TrainingMonitor
|
|
8
|
+
|
|
9
|
+
A comprehensive training monitor that integrates seamlessly with PyTorch training loops.
|
|
10
|
+
|
|
11
|
+
### SmartSummary
|
|
12
|
+
|
|
13
|
+
Advanced model analysis tool that goes beyond basic `model.summary()` with bottleneck detection and gradient tracking.
|
|
14
|
+
|
|
15
|
+
---
|
|
16
|
+
|
|
17
|
+
## TrainingMonitor
|
|
18
|
+
|
|
19
|
+
A comprehensive training monitor that integrates seamlessly with PyTorch training loops.
|
|
20
|
+
|
|
21
|
+
### Features
|
|
22
|
+
- Real-time progress bars with metric display
|
|
23
|
+
- Automatic CSV logging with timestamps
|
|
24
|
+
- Running averages for all metrics
|
|
25
|
+
- RAM and VRAM monitoring
|
|
26
|
+
- Crash-resistant (auto-flush)
|
|
27
|
+
|
|
28
|
+
### Usage Example
|
|
29
|
+
|
|
30
|
+
```python
|
|
31
|
+
from ToTf import TrainingMonitor
|
|
32
|
+
|
|
33
|
+
# Wrap your DataLoader
|
|
34
|
+
epochs = 5
|
|
35
|
+
|
|
36
|
+
for epoch in range(epochs):
|
|
37
|
+
monitor = TrainingMonitor(
|
|
38
|
+
train_loader,
|
|
39
|
+
desc=f"Epoch {epoch + 1}",
|
|
40
|
+
log_file=f"train_log_epoch_{epoch + 1}.csv"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# Iterate through batches
|
|
44
|
+
for batch in monitor:
|
|
45
|
+
# Your training logic here
|
|
46
|
+
loss = training_step(batch)
|
|
47
|
+
|
|
48
|
+
# Log metrics
|
|
49
|
+
monitor.log({
|
|
50
|
+
'loss': loss.item(),
|
|
51
|
+
'lr': optimizer.param_groups[0]['lr']
|
|
52
|
+
})
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
### CSV Output Format
|
|
56
|
+
|
|
57
|
+
The log file contains:
|
|
58
|
+
- `timestamp`: Time of logging
|
|
59
|
+
- `step`: Current step number
|
|
60
|
+
- `<metric_name>`: Your logged metrics (running average)
|
|
61
|
+
- `RAM_pct`: RAM usage percentage
|
|
62
|
+
- `VRAM_gb`: GPU memory usage in GB
|
|
63
|
+
|
|
64
|
+
### Notes
|
|
65
|
+
|
|
66
|
+
- The monitor automatically tracks running averages of all metrics
|
|
67
|
+
- Metrics are flushed to disk after each log call for crash safety
|
|
68
|
+
- VRAM is only logged when CUDA is available
|
|
69
|
+
- Compatible with any PyTorch DataLoader or iterable
|
|
70
|
+
|
|
71
|
+
---
|
|
72
|
+
|
|
73
|
+
## SmartSummary
|
|
74
|
+
|
|
75
|
+
Advanced model summary with bottleneck detection and gradient analysis.
|
|
76
|
+
|
|
77
|
+
### Features
|
|
78
|
+
- **Comprehensive Analysis**: Shows layer types, shapes, and parameters
|
|
79
|
+
- **Bottleneck Detection**: Identifies layers that may slow down training
|
|
80
|
+
- **Gradient Tracking**: Monitors gradient variance to find unstable layers
|
|
81
|
+
- **Memory Estimation**: Calculates memory usage per layer
|
|
82
|
+
- **Export Options**: Save to file or export as dictionary
|
|
83
|
+
- **Works with Complex Models**: Handles nested architectures and residual connections
|
|
84
|
+
|
|
85
|
+
### Usage Example
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
from ToTf import SmartSummary
|
|
89
|
+
import torch.nn as nn
|
|
90
|
+
|
|
91
|
+
# Create your model
|
|
92
|
+
model = YourModel()
|
|
93
|
+
|
|
94
|
+
# Basic analysis
|
|
95
|
+
summary = SmartSummary(model, input_size=(3, 224, 224))
|
|
96
|
+
summary.show()
|
|
97
|
+
|
|
98
|
+
# With gradient tracking
|
|
99
|
+
summary = SmartSummary(
|
|
100
|
+
model,
|
|
101
|
+
input_size=(3, 224, 224),
|
|
102
|
+
track_gradients=True # Requires backward pass
|
|
103
|
+
)
|
|
104
|
+
summary.show()
|
|
105
|
+
|
|
106
|
+
# Get bottlenecks programmatically
|
|
107
|
+
bottlenecks = summary.get_bottlenecks(top_n=5)
|
|
108
|
+
for bn in bottlenecks:
|
|
109
|
+
print(f"Bottleneck: {bn['layer']}")
|
|
110
|
+
print(f" Score: {bn['score']:.2f}")
|
|
111
|
+
print(f" Reasons: {', '.join(bn['reasons'])}")
|
|
112
|
+
print(f" Parameters: {bn['params']:,}")
|
|
113
|
+
|
|
114
|
+
# Export analysis
|
|
115
|
+
summary.save_to_file("model_analysis.txt")
|
|
116
|
+
data = summary.to_dict()
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
### Parameters
|
|
120
|
+
|
|
121
|
+
- `model` (nn.Module): PyTorch model to analyze
|
|
122
|
+
- `input_size` (Tuple, optional): Input tensor shape excluding batch dimension
|
|
123
|
+
- Example: `(3, 224, 224)` for RGB images
|
|
124
|
+
- If omitted, only parameter counting is performed
|
|
125
|
+
- `batch_size` (int): Batch size for shape inference (default: 1)
|
|
126
|
+
- `device` (str): Device for analysis - 'cpu' or 'cuda' (default: 'cpu')
|
|
127
|
+
- `track_gradients` (bool): Whether to track gradient statistics (default: False)
|
|
128
|
+
- Requires a forward and backward pass
|
|
129
|
+
- Useful for identifying training instabilities
|
|
130
|
+
|
|
131
|
+
### Methods
|
|
132
|
+
|
|
133
|
+
- `show(show_bottlenecks=True)`: Display formatted summary table
|
|
134
|
+
- `get_bottlenecks(top_n=5)`: Get list of bottleneck layers
|
|
135
|
+
- `to_dict()`: Export complete analysis as dictionary
|
|
136
|
+
- `save_to_file(filename)`: Save summary to text file
|
|
137
|
+
|
|
138
|
+
### Bottleneck Detection
|
|
139
|
+
|
|
140
|
+
SmartSummary identifies bottlenecks based on:
|
|
141
|
+
|
|
142
|
+
1. **Parameter Count**: Layers with >10% of total parameters
|
|
143
|
+
2. **Gradient Variance**: High variance indicates potential instability
|
|
144
|
+
3. **Output Size**: Large intermediate tensors (>10MB)
|
|
145
|
+
|
|
146
|
+
Each bottleneck is scored and ranked. Higher scores indicate more critical bottlenecks.
|
|
147
|
+
|
|
148
|
+
### Output Format
|
|
149
|
+
|
|
150
|
+
The summary table shows:
|
|
151
|
+
- **Layer (type)**: Layer class name
|
|
152
|
+
- **Output Shape**: Tensor dimensions after this layer
|
|
153
|
+
- **Param #**: Number of parameters in this layer
|
|
154
|
+
- **Trainable**: Whether layer has trainable parameters (✓/✗)
|
|
155
|
+
- **Gradient Stats** (if tracking enabled): Variance, mean, and max
|
|
156
|
+
|
|
157
|
+
### Comparison with Other Tools
|
|
158
|
+
|
|
159
|
+
| Feature | SmartSummary | torchsummary | torchinfo | model.summary() |
|
|
160
|
+
|---------|--------------|--------------|-----------|-----------------|
|
|
161
|
+
| Basic layer info | ✓ | ✓ | ✓ | ✓ |
|
|
162
|
+
| Bottleneck detection | ✓ | ✗ | ✗ | ✗ |
|
|
163
|
+
| Gradient tracking | ✓ | ✗ | ✗ | ✗ |
|
|
164
|
+
| Memory estimation | ✓ | ✓ | ✓ | ✗ |
|
|
165
|
+
| Export to file/dict | ✓ | ✗ | ✓ | ✗ |
|
|
166
|
+
| Complex models | ✓ | Limited | ✓ | Limited |
|
|
167
|
+
| PyTorch native | ✓ | ✓ | ✓ | N/A (TF) |
|
|
168
|
+
|
|
169
|
+
### Advanced Example
|
|
170
|
+
|
|
171
|
+
```python
|
|
172
|
+
# Compare different architectures
|
|
173
|
+
models = {
|
|
174
|
+
"ResNet18": resnet18(),
|
|
175
|
+
"VGG16": vgg16(),
|
|
176
|
+
"EfficientNet": efficientnet_b0()
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
for name, model in models.items():
|
|
180
|
+
print(f"\n{'='*50}")
|
|
181
|
+
print(f"Analyzing {name}")
|
|
182
|
+
print(f"{'='*50}")
|
|
183
|
+
|
|
184
|
+
summary = SmartSummary(model, input_size=(3, 224, 224))
|
|
185
|
+
bottlenecks = summary.get_bottlenecks(top_n=3)
|
|
186
|
+
|
|
187
|
+
print(f"Total parameters: {summary.total_params:,}")
|
|
188
|
+
print(f"Top bottleneck: {bottlenecks[0]['layer']}")
|
|
189
|
+
print(f" Score: {bottlenecks[0]['score']:.2f}")
|
|
190
|
+
```
|
|
191
|
+
|
|
192
|
+
### Notes
|
|
193
|
+
|
|
194
|
+
- The monitor automatically tracks running averages of all metrics
|
|
195
|
+
- Metrics are flushed to disk after each log call for crash safety
|
|
196
|
+
- VRAM is only logged when CUDA is available
|
|
197
|
+
- Compatible with any PyTorch DataLoader or iterable
|
|
198
|
+
- For very large models, omit `input_size` to skip forward pass
|
|
199
|
+
- Gradient tracking adds overhead - use during model design phase
|
|
200
|
+
- Works with models containing skip connections, attention mechanisms, etc.
|
pytorch/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PyTorch-specific modules for ToTf
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .trainingmonitor import TrainingMonitor
|
|
6
|
+
from .smartsummary import SmartSummary
|
|
7
|
+
from .modelview import ModelView, draw_graph
|
|
8
|
+
from .utils import (
|
|
9
|
+
lazy_flatten,
|
|
10
|
+
get_flatten_size,
|
|
11
|
+
loss_ncc,
|
|
12
|
+
ncc_score,
|
|
13
|
+
LRFinder,
|
|
14
|
+
find_lr
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"TrainingMonitor",
|
|
19
|
+
"SmartSummary",
|
|
20
|
+
"ModelView",
|
|
21
|
+
"draw_graph",
|
|
22
|
+
"lazy_flatten",
|
|
23
|
+
"get_flatten_size",
|
|
24
|
+
"loss_ncc",
|
|
25
|
+
"ncc_score",
|
|
26
|
+
"LRFinder",
|
|
27
|
+
"find_lr"
|
|
28
|
+
]
|
pytorch/modelview.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ModelView - PyTorch Model Architecture Visualization
|
|
3
|
+
|
|
4
|
+
A wrapper around torchview for generating publication-quality neural network
|
|
5
|
+
architecture diagrams for PyTorch models with a unified API similar to the
|
|
6
|
+
TensorFlow ModelView.
|
|
7
|
+
|
|
8
|
+
Features:
|
|
9
|
+
- High-quality architecture diagrams suitable for research papers
|
|
10
|
+
- Multiple output formats (PNG, PDF, SVG)
|
|
11
|
+
- Automatic layer shape and parameter annotation
|
|
12
|
+
- Support for complex architectures (residual, multi-input/output, branching)
|
|
13
|
+
- Customizable styling and layout
|
|
14
|
+
- Leverages torchview internally for comprehensive PyTorch support
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn as nn
|
|
19
|
+
from typing import Dict, List, Tuple, Optional, Union, Any
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
import json
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
from torchview import draw_graph as torchview_draw_graph
|
|
25
|
+
TORCHVIEW_AVAILABLE = True
|
|
26
|
+
except ImportError:
|
|
27
|
+
TORCHVIEW_AVAILABLE = False
|
|
28
|
+
torchview_draw_graph = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ModelView:
    """
    Generate publication-quality architecture diagrams for PyTorch models.

    This class wraps torchview to provide a consistent API similar to the
    TensorFlow ModelView, making it easy to switch between frameworks while
    maintaining the same visualization workflow.

    The computation graph is built once, in ``__init__``; the ``render*`` and
    ``export_*`` methods only re-style and write out that stored graph.

    Example:
        >>> model = YourModel()
        >>> view = ModelView(model, input_size=(3, 224, 224))
        >>> view.render('model_architecture.png')
        >>> # Or with custom styling
        >>> view.render('model.pdf', format='pdf', rankdir='TB')
    """

    def __init__(
        self,
        model: nn.Module,
        input_size: Optional[Union[Tuple[int, ...], List[Tuple[int, ...]]]] = None,
        input_data: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        batch_size: int = 1,
        device: str = "cpu",
        depth: int = 3,
        expand_nested: bool = False,
        hide_inner_tensors: bool = True,
        hide_module_functions: bool = True,
        roll: bool = False,
        show_shapes: bool = True,
        dtypes: Optional[List[torch.dtype]] = None,
        **kwargs
    ):
        """
        Initialize ModelView and immediately build the torchview graph.

        Args:
            model: PyTorch model (nn.Module) to visualize
            input_size: Input tensor size(s) excluding batch dimension,
                e.g., (3, 224, 224) for images. A list of sizes describes a
                multi-input model. Ignored when ``input_data`` is given.
            input_data: Optional actual input tensor(s) instead of input_size
            batch_size: Batch size prepended to each ``input_size`` (default: 1)
            device: Device for computation ('cpu' or 'cuda')
            depth: Maximum depth shown in the module hierarchy (default: 3)
            expand_nested: Whether to expand nested models
            hide_inner_tensors: Forwarded to torchview
            hide_module_functions: Forwarded to torchview
            roll: Forwarded to torchview (rolls recursive modules)
            show_shapes: Forwarded to torchview (tensor shapes in the diagram)
            dtypes: Optional list of dtypes for each input
            **kwargs: Additional arguments passed through to torchview

        Raises:
            ImportError: If torchview is not installed.
            ValueError: If neither ``input_size`` nor ``input_data`` is given.
            RuntimeError: If torchview fails to build the graph.
        """
        if not TORCHVIEW_AVAILABLE:
            raise ImportError(
                "torchview is required for PyTorch ModelView. "
                "Install with: pip install torchview"
            )

        self.model = model
        self.input_size = input_size
        self.input_data = input_data
        self.batch_size = batch_size
        self.device = device
        self.depth = depth
        self.expand_nested = expand_nested
        self.hide_inner_tensors = hide_inner_tensors
        self.hide_module_functions = hide_module_functions
        self.roll = roll
        self.show_shapes = show_shapes
        self.dtypes = dtypes
        self.kwargs = kwargs

        # Store the torchview graph object
        self._graph = None
        self._build_graph()

    def _batched_input_size(self):
        """Return ``self.input_size`` with ``self.batch_size`` prepended.

        A list whose elements are all lists/tuples is treated as one shape per
        model input; anything else (a tuple, or a flat list of ints) is treated
        as a single shape.
        """
        size = self.input_size
        is_multi_input = isinstance(size, list) and all(
            isinstance(item, (list, tuple)) for item in size
        )
        if is_multi_input:
            return [(self.batch_size, *item) for item in size]
        return (self.batch_size, *size)

    def _build_graph(self):
        """Build the computation graph using torchview.

        Raises:
            ValueError: If neither ``input_size`` nor ``input_data`` is set.
            RuntimeError: If torchview raises while tracing the model.
        """
        # input_data takes precedence; input_size is only used without it.
        if self.input_data is not None:
            input_arg = None
        elif self.input_size is not None:
            input_arg = self._batched_input_size()
        else:
            raise ValueError("Either input_size or input_data must be provided")

        try:
            self._graph = torchview_draw_graph(
                model=self.model,
                input_size=input_arg,
                input_data=self.input_data,
                device=self.device,
                depth=self.depth,
                expand_nested=self.expand_nested,
                hide_inner_tensors=self.hide_inner_tensors,
                hide_module_functions=self.hide_module_functions,
                roll=self.roll,
                show_shapes=self.show_shapes,
                dtypes=self.dtypes,
                **self.kwargs
            )
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to build computation graph: {e}") from e

    def render(
        self,
        filename: str,
        format: Optional[str] = None,
        rankdir: str = 'TB',
        show_shapes: bool = True,
        show_layer_names: bool = False,
        show_params: bool = True,
        dpi: int = 300,
        cleanup: bool = True,
        **kwargs
    ) -> str:
        """
        Render the model architecture diagram to a file.

        Args:
            filename: Output file path. Its final suffix is dropped and
                replaced by ``format``.
            format: Output format ('png', 'pdf', 'svg'). If None, inferred
                from the filename suffix, falling back to 'png'.
            rankdir: Graph direction ('TB'=top-to-bottom, 'LR'=left-to-right,
                     'BT'=bottom-to-top, 'RL'=right-to-left)
            show_shapes: Accepted for API compatibility; not used here
                (shapes are configured when the graph is built).
            show_layer_names: Accepted for API compatibility; not used here.
            show_params: Accepted for API compatibility; not used here.
            dpi: Value written to the graph's 'dpi' attribute
            cleanup: Passed to the graph's ``render`` call
            **kwargs: Accepted for API compatibility; not used here.

        Returns:
            Path to the rendered file

        Raises:
            RuntimeError: If the graph is missing or rendering fails.
        """
        if self._graph is None:
            raise RuntimeError("Graph not built. Call _build_graph() first.")

        # Infer format from filename if not provided
        if format is None:
            format = Path(filename).suffix[1:].lower() or 'png'

        # Configure graph visualization settings
        graph_viz = self._graph.visual_graph
        graph_viz.graph_attr['rankdir'] = rankdir
        graph_viz.graph_attr['dpi'] = str(dpi)
        graph_viz.format = format

        # The format is appended as the suffix of the path returned below.
        output_path = Path(filename).with_suffix('')
        try:
            graph_viz.render(str(output_path), cleanup=cleanup)
        except Exception as e:
            raise RuntimeError(f"Failed to render graph: {e}") from e
        return f"{output_path}.{format}"

    def render_advanced(
        self,
        filename: str,
        format: Optional[str] = None,
        rankdir: str = 'TB',
        show_shapes: bool = True,
        show_layer_names: bool = True,
        show_params: bool = True,
        dpi: int = 300,
        cleanup: bool = True,
        **kwargs
    ) -> str:
        """
        Render advanced computation graph visualization (alias for render).

        This method forwards every argument unchanged to ``render()``; it
        exists so the name matches the TensorFlow ModelView API.

        Args:
            filename: Output file path
            format: Output format ('png', 'pdf', 'svg')
            rankdir: Graph direction
            show_shapes: Forwarded to ``render()``
            show_layer_names: Forwarded to ``render()``
            show_params: Forwarded to ``render()``
            dpi: Resolution for raster formats
            cleanup: Whether to remove intermediate files
            **kwargs: Forwarded to ``render()``

        Returns:
            Path to the rendered file
        """
        return self.render(
            filename=filename,
            format=format,
            rankdir=rankdir,
            show_shapes=show_shapes,
            show_layer_names=show_layer_names,
            show_params=show_params,
            dpi=dpi,
            cleanup=cleanup,
            **kwargs
        )

    def get_summary_dict(self) -> Dict[str, Any]:
        """
        Get a dictionary summary of the model architecture.

        Returns:
            Dictionary with the model class name, total / trainable /
            non-trainable parameter counts, and the configured input size,
            device and depth.

        Raises:
            RuntimeError: If the graph has not been built.
        """
        if self._graph is None:
            raise RuntimeError("Graph not built.")

        # Count straight from the model: the graph's edge list is not a
        # reliable source of per-node parameter counts.
        total_params = 0
        trainable_params = 0
        for param in self.model.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count

        return {
            'model_name': self.model.__class__.__name__,
            'total_parameters': int(total_params),
            'trainable_parameters': int(trainable_params),
            'non_trainable_parameters': int(total_params - trainable_params),
            'input_size': self.input_size,
            'device': self.device,
            'depth': self.depth,
        }

    def save_summary_json(self, filename: str):
        """Save the dictionary from ``get_summary_dict()`` as a JSON file."""
        summary = self.get_summary_dict()
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

    def show(self, detailed: bool = False):
        """
        Print the stored torchview graph object.

        Args:
            detailed: Accepted for API compatibility; not used here.

        Raises:
            RuntimeError: If the graph has not been built.
        """
        if self._graph is None:
            raise RuntimeError("Graph not built.")

        print(self._graph)

    def export_svg(self, filename: str) -> str:
        """Export as SVG (vector format for publications)"""
        return self.render(filename, format='svg')

    def export_pdf(self, filename: str) -> str:
        """Export as PDF (vector format for papers)"""
        return self.render(filename, format='pdf')

    def export_png(self, filename: str, dpi: int = 300) -> str:
        """Export as high-resolution PNG"""
        return self.render(filename, format='png', dpi=dpi)

    @property
    def visual_graph(self):
        """Get the underlying torchview visual graph"""
        if self._graph is None:
            raise RuntimeError("Graph not built.")
        return self._graph.visual_graph
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def draw_graph(
    model: nn.Module,
    input_size: Optional[Union[Tuple[int, ...], List[Tuple[int, ...]]]] = None,
    input_data: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    save_path: Optional[str] = None,
    **kwargs
) -> Optional[str]:
    """
    One-call shortcut: build a ModelView and either render or print it.

    Args:
        model: PyTorch model (nn.Module) to visualize
        input_size: Input size(s) excluding batch dimension
        input_data: Optional actual input tensor(s) instead of input_size
        save_path: Where to write the diagram. When empty or None, the view
            is printed via ``ModelView.show()`` and nothing is written.
        **kwargs: Forwarded to ``ModelView.render()`` (only used when
            ``save_path`` is given)

    Returns:
        The path returned by ``ModelView.render()`` when ``save_path`` is
        given, otherwise None

    Example:
        >>> model = MyModel()
        >>> draw_graph(model, input_size=(3, 224, 224), save_path='model.png')
    """
    viewer = ModelView(model, input_size=input_size, input_data=input_data)

    # No destination: print only.
    if not save_path:
        viewer.show()
        return None

    return viewer.render(save_path, **kwargs)
|