torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ONNX Export Utilities for Torch-RecHub models.
|
|
3
|
+
|
|
4
|
+
This module provides non-invasive ONNX export functionality for recommendation models.
|
|
5
|
+
It uses reflection to extract feature information from models and wraps dict-input models
|
|
6
|
+
to be compatible with ONNX's positional argument requirements.
|
|
7
|
+
|
|
8
|
+
Date: 2024
|
|
9
|
+
References:
|
|
10
|
+
- PyTorch ONNX Export: https://pytorch.org/docs/stable/onnx.html
|
|
11
|
+
- ONNX Runtime: https://onnxruntime.ai/docs/
|
|
12
|
+
Authors: Torch-RecHub Contributors
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import warnings
|
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
import torch.nn as nn
|
|
21
|
+
|
|
22
|
+
from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ONNXWrapper(nn.Module):
    """Adapt a dict-input model to the positional calling convention ONNX needs.

    ``torch.onnx.export`` drives the traced module with a flat sequence of
    tensors, whereas Torch-RecHub models take one ``{feature_name: tensor}``
    dict. This wrapper pairs each positional tensor with the name at the same
    index of ``input_names`` and forwards the rebuilt dict to the wrapped model.

    Args:
        model: The original PyTorch model that accepts dict input.
        input_names: Ordered list of feature names corresponding to input positions.
        mode: Optional mode for dual-tower models ("user" or "item"). When given
            and the model exposes a ``mode`` attribute, that attribute is
            overwritten as soon as the wrapper is constructed; call
            :meth:`restore_mode` to put the previous value back.

    Example:
        >>> wrapper = ONNXWrapper(dssm_model, ["user_id", "movie_id", "hist_movie_id"])
        >>> # Now can call: wrapper(user_id_tensor, movie_id_tensor, hist_tensor)
    """

    def __init__(self, model: nn.Module, input_names: List[str], mode: Optional[str] = None):
        super().__init__()
        self.model = model
        self.input_names = input_names
        # Snapshot the pre-existing mode (None if the model has no such
        # attribute) so restore_mode() can undo the override below.
        self._original_mode = getattr(model, 'mode', None)

        override_requested = mode is not None
        if override_requested and hasattr(model, 'mode'):
            model.mode = mode

    def forward(self, *args) -> torch.Tensor:
        """Rebuild the feature dict from positional tensors and run the wrapped model.

        Raises:
            ValueError: If the number of positional tensors differs from
                ``len(input_names)``.
        """
        expected = len(self.input_names)
        received = len(args)
        if received != expected:
            raise ValueError(f"Expected {expected} inputs, got {received}. "
                             f"Expected names: {self.input_names}")
        return self.model(dict(zip(self.input_names, args)))

    def restore_mode(self):
        """Write the ``mode`` value captured at construction back onto the model.

        Does nothing when the wrapped model has no ``mode`` attribute.
        """
        if not hasattr(self.model, 'mode'):
            return
        self.model.mode = self._original_mode
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Re-export from model_utils for backward compatibility
|
|
66
|
+
# The actual implementations are now in model_utils.py
|
|
67
|
+
from .model_utils import extract_feature_info, generate_dummy_input, generate_dummy_input_dict, generate_dynamic_axes
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ONNXExporter:
    """Main class for exporting Torch-RecHub models to ONNX format.

    This exporter handles the complexity of converting dict-input models to ONNX
    by automatically extracting feature information and wrapping the model.

    Args:
        model: The PyTorch recommendation model to export.
        device: Device for export operations (default: 'cpu').

    Example:
        >>> exporter = ONNXExporter(deepfm_model)
        >>> exporter.export("model.onnx")

        >>> # For dual-tower models
        >>> exporter = ONNXExporter(dssm_model)
        >>> exporter.export("user_tower.onnx", mode="user")
        >>> exporter.export("item_tower.onnx", mode="item")
    """

    # Accepted values for the ``mode`` argument; None selects the full model.
    _VALID_MODES = (None, "user", "item")

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model
        self.device = device
        self.feature_info = extract_feature_info(model)

    def _select_features(self, mode: Optional[str]) -> List[Any]:
        """Return the feature list that ``mode`` selects from ``feature_info``.

        A missing or ``None`` entry in ``feature_info`` is treated as an empty
        list rather than surfacing as a ``KeyError``.

        Raises:
            ValueError: If ``mode`` is not None, "user" or "item".
        """
        if mode not in self._VALID_MODES:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {self._VALID_MODES}")
        key = 'features' if mode is None else f'{mode}_features'
        return self.feature_info.get(key) or []

    @staticmethod
    def _exporter_accepts_dynamo() -> bool:
        """Return whether this PyTorch's ``torch.onnx.export`` takes a ``dynamo`` keyword.

        The keyword only exists in newer PyTorch releases; passing it to an older
        exporter raises ``TypeError``, so it is only forwarded when supported.
        """
        import inspect
        try:
            return "dynamo" in inspect.signature(torch.onnx.export).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: do not risk an unknown keyword.
            return False

    def export(
        self,
        output_path: str,
        mode: Optional[str] = None,
        dummy_input: Optional[Dict[str,
                                   torch.Tensor]] = None,
        batch_size: int = 2,
        seq_length: int = 10,
        opset_version: int = 14,
        dynamic_batch: bool = True,
        verbose: bool = False
    ) -> bool:
        """Export the model to ONNX format.

        Args:
            output_path: Path to save the ONNX model. Missing parent
                directories are created.
            mode: For dual-tower models, specify "user" or "item" to export
                only that tower. None exports the full model.
            dummy_input: Optional dict of example inputs. If not provided,
                dummy inputs will be generated automatically. Must contain a
                tensor for every selected feature name.
            batch_size: Batch size for generated dummy input (default: 2).
            seq_length: Sequence length for SequenceFeature (default: 10).
            opset_version: ONNX opset version (default: 14).
            dynamic_batch: Whether to enable dynamic batch size (default: True).
            verbose: Whether to print export details (default: False).

        Returns:
            True when the export succeeded. Failures raise instead of
            returning False.

        Raises:
            ValueError: If ``mode`` is invalid, or the selected tower has no features.
            KeyError: If ``dummy_input`` lacks a required feature name.
            RuntimeError: If ONNX export fails.
        """
        features = self._select_features(mode)
        if mode is not None and not features:
            raise ValueError(f"No {mode} features found in model for mode='{mode}'")

        self.model.eval()
        self.model.to(self.device)

        input_names = [f.name for f in features]

        # Prepare every export input BEFORE the wrapper overrides model.mode.
        # Anything that fails here then leaves the model's mode untouched,
        # instead of stuck in single-tower mode.
        if dummy_input is not None:
            missing = [name for name in input_names if name not in dummy_input]
            if missing:
                raise KeyError(f"dummy_input is missing required inputs: {missing}")
            dummy_tuple = tuple(dummy_input[name].to(self.device) for name in input_names)
        else:
            dummy_tuple = generate_dummy_input(features, batch_size=batch_size, seq_length=seq_length, device=self.device)

        dynamic_axes = None
        if dynamic_batch:
            seq_feature_names = [f.name for f in features if isinstance(f, SequenceFeature)]
            dynamic_axes = generate_dynamic_axes(input_names=input_names, output_names=["output"], seq_features=seq_feature_names)

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        export_kwargs = dict(
            input_names=input_names,
            output_names=["output"],
            dynamic_axes=dynamic_axes,
            opset_version=opset_version,
            do_constant_folding=True,
            verbose=verbose,
        )
        # Force the legacy (TorchScript) exporter, which honours dynamic_axes,
        # on releases where the dynamo exporter is selectable.
        if self._exporter_accepts_dynamo():
            export_kwargs["dynamo"] = False

        # The wrapper applies ``mode`` to the model on construction; the finally
        # clause guarantees it is restored whatever happens during export.
        wrapper = ONNXWrapper(self.model, input_names, mode=mode)
        try:
            wrapper.eval()
            with torch.no_grad():
                torch.onnx.export(wrapper, dummy_tuple, output_path, **export_kwargs)
        except Exception as e:
            warnings.warn(f"ONNX export failed: {str(e)}")
            raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
        finally:
            wrapper.restore_mode()

        if verbose:
            print(f"Successfully exported ONNX model to: {output_path}")
            print(f"  Input names: {input_names}")
            print(f"  Opset version: {opset_version}")
            print(f"  Dynamic batch: {dynamic_batch}")

        return True

    def get_input_info(self, mode: Optional[str] = None) -> Dict[str, Any]:
        """Get information about model inputs.

        Args:
            mode: For dual-tower models, "user" or "item"; None for all features.

        Returns:
            Dict with keys ``'mode'``, ``'inputs'`` and ``'input_names'``.
            Each ``'inputs'`` entry has ``'name'``, ``'type'`` and
            ``'embed_dim'``, plus ``'vocab_size'`` / ``'pooling'`` when the
            feature defines them.

        Raises:
            ValueError: If ``mode`` is not None, "user" or "item".
        """
        features = self._select_features(mode)

        info = []
        for f in features:
            feat_info = {'name': f.name, 'type': type(f).__name__, 'embed_dim': f.embed_dim}
            if hasattr(f, 'vocab_size'):
                feat_info['vocab_size'] = f.vocab_size
            if hasattr(f, 'pooling'):
                feat_info['pooling'] = f.pooling
            info.append(feat_info)

        return {'mode': mode, 'inputs': info, 'input_names': [f.name for f in features]}
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Visualization Utilities for Torch-RecHub.
|
|
3
|
+
|
|
4
|
+
This module provides model structure visualization using torchview library.
|
|
5
|
+
Requires optional dependencies: pip install torch-rechub[visualization]
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
>>> from torch_rechub.utils.visualization import visualize_model, display_graph
|
|
9
|
+
>>> graph = visualize_model(model, depth=4)
|
|
10
|
+
>>> display_graph(graph) # Display in Jupyter Notebook
|
|
11
|
+
|
|
12
|
+
>>> # Save to file
|
|
13
|
+
>>> visualize_model(model, save_path="model_arch.pdf")
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from typing import Any, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
|
|
21
|
+
# Probe the optional torchview dependency once, at import time. The flag and
# the human-readable reason are consulted by display_graph() and
# visualize_model() before they touch torchview.
TORCHVIEW_AVAILABLE = False
TORCHVIEW_SKIP_REASON = "torchview not installed"

try:
    from torchview import draw_graph
except ImportError as import_error:
    TORCHVIEW_SKIP_REASON = f"torchview not available: {import_error}"
else:
    TORCHVIEW_AVAILABLE = True
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _is_jupyter_environment() -> bool:
|
|
33
|
+
"""Check if running in Jupyter/IPython environment."""
|
|
34
|
+
try:
|
|
35
|
+
from IPython import get_ipython
|
|
36
|
+
shell = get_ipython()
|
|
37
|
+
if shell is None:
|
|
38
|
+
return False
|
|
39
|
+
# Check for Jupyter notebook or qtconsole
|
|
40
|
+
shell_class = shell.__class__.__name__
|
|
41
|
+
return shell_class in ('ZMQInteractiveShell', 'TerminalInteractiveShell')
|
|
42
|
+
except (ImportError, NameError):
|
|
43
|
+
return False
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def display_graph(graph: Any, format: str = 'png') -> Any:
    """Display a torchview ComputationGraph in Jupyter Notebook.

    This function provides a reliable way to display visualization graphs
    in Jupyter environments, especially VSCode Jupyter.

    Parameters
    ----------
    graph : ComputationGraph
        A torchview ComputationGraph object returned by visualize_model().
    format : str, default='png'
        Output format, 'png' recommended for VSCode. It is applied through
        ``graphviz.set_jupyter_format`` when the installed graphviz provides it;
        older graphviz releases without that function simply keep their default.

    Returns
    -------
    graphviz.Digraph
        The graph's ``visual_graph`` object. It is returned whether or not
        IPython is available for display. Because it is returned, using this
        call as the last expression of a notebook cell renders it a second
        time; assign the result or end the line with ``;`` to avoid that.

    Raises
    ------
    ImportError
        If torchview is not installed.

    Examples
    --------
    >>> graph = visualize_model(model, depth=4)
    >>> display_graph(graph)  # Works in VSCode Jupyter
    """
    if not TORCHVIEW_AVAILABLE:
        raise ImportError(f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
                          "Install with: pip install torch-rechub[visualization]")

    try:
        import graphviz
    except ImportError:
        graphviz = None

    # Set format for VSCode compatibility. set_jupyter_format() is missing in
    # older graphviz releases; skip the switch there instead of failing with
    # AttributeError.
    set_jupyter_format = getattr(graphviz, 'set_jupyter_format', None)
    if set_jupyter_format is not None:
        set_jupyter_format(format)

    # Get the visual_graph (graphviz.Digraph object)
    visual = graph.visual_graph

    try:
        from IPython.display import display
    except ImportError:
        # Not in an IPython/Jupyter environment, return the graph
        return visual

    display(visual)
    return visual
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def visualize_model(
    model: nn.Module,
    input_data: Optional[Dict[str,
                              torch.Tensor]] = None,
    batch_size: int = 2,
    seq_length: int = 10,
    depth: int = 3,
    show_shapes: bool = True,
    expand_nested: bool = True,
    save_path: Optional[str] = None,
    graph_name: str = "model",
    device: str = "cpu",
    dpi: int = 300,
    **kwargs
) -> Any:
    """Visualize a Torch-RecHub model's computation graph.

    This function generates a visual representation of the model architecture,
    showing layer connections, tensor shapes, and nested module structures.
    It automatically extracts feature information from the model to generate
    appropriate dummy inputs.

    Note that the model is switched to eval mode and moved to ``device`` as a
    side effect.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to visualize. Should be a Torch-RecHub model
        with feature attributes (e.g., DeepFM, DSSM, MMOE).
    input_data : dict, optional
        Dict of example inputs {feature_name: tensor}.
        If None, inputs are auto-generated based on model features.
    batch_size : int, default=2
        Batch size for auto-generated inputs.
    seq_length : int, default=10
        Sequence length for SequenceFeature inputs.
    depth : int, default=3
        Visualization depth - higher values show more detail.
        Set to -1 to show all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes on edges.
    expand_nested : bool, default=True
        Whether to expand nested nn.Module with dashed borders.
    save_path : str, optional
        Path to save the graph image. Supports .pdf, .svg, .png formats.
        A path without an extension is rendered as PDF.
        If None or empty, displays in Jupyter or opens system viewer.
    graph_name : str, default="model"
        Name for the computation graph.
    device : str, default="cpu"
        Device for model execution during tracing.
    dpi : int, default=300
        Resolution in dots per inch for output image.
        Higher values produce sharper images suitable for papers.
    **kwargs : dict
        Additional arguments passed to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.
        - Use `.visual_graph` property to get the graphviz.Digraph
        - Use `.resize_graph(scale=1.5)` to adjust graph size

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.
    ValueError
        If model has no recognizable feature attributes.

    Notes
    -----
    Default Display Behavior:
        When `save_path` is None or empty (default):
        - In a Jupyter kernel: automatically displays the graph inline
        - In Python script: opens the graph with system default viewer

    Requires graphviz system package: apt/brew/choco install graphviz.
    For Jupyter display issues, try: graphviz.set_jupyter_format('png').

    Examples
    --------
    >>> from torch_rechub.models.ranking import DeepFM
    >>> from torch_rechub.utils.visualization import visualize_model
    >>>
    >>> # Auto-display in Jupyter or open in viewer
    >>> visualize_model(model, depth=4)  # No save_path needed
    >>>
    >>> # Save to high-DPI PNG for paper
    >>> visualize_model(model, save_path="model.png", dpi=300)
    """
    if not TORCHVIEW_AVAILABLE:
        raise ImportError(
            f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
            "Install with: pip install torch-rechub[visualization]\n"
            "Also ensure graphviz is installed on your system:\n"
            "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
            "  - macOS: brew install graphviz\n"
            "  - Windows: choco install graphviz"
        )

    # Import feature extraction utilities from model_utils
    from .model_utils import extract_feature_info, generate_dummy_input_dict

    model.eval()
    model.to(device)

    # Auto-generate input data if not provided
    if input_data is None:
        feature_info = extract_feature_info(model)
        features = feature_info['features']

        if not features:
            raise ValueError("Could not extract feature information from model. "
                             "Please provide input_data parameter manually.")

        input_data = generate_dummy_input_dict(features, batch_size=batch_size, seq_length=seq_length, device=device)
    else:
        # Ensure input tensors are on correct device
        input_data = {k: v.to(device) for k, v in input_data.items()}

    # torchview's forward_prop calls model(*x) for list/tuple inputs but
    # model(**x) for Mappings, which would unpack our dict into kwargs.
    # Torch-RecHub models expect forward(self, x) with x being the whole dict,
    # so wrap it in a 1-tuple to get model(input_dict).
    input_data_wrapped = (input_data,)

    # Build the graph without saving; it is rendered manually below so the
    # DPI attribute is applied first.
    graph = draw_graph(
        model,
        input_data=input_data_wrapped,
        graph_name=graph_name,
        depth=depth,
        device=device,
        expand_nested=expand_nested,
        show_shapes=show_shapes,
        save_graph=False,
        **kwargs
    )

    # Set DPI for high-quality output (must be set BEFORE rendering/saving)
    graph.visual_graph.graph_attr['dpi'] = str(dpi)

    # One branch per outcome. A falsy save_path (None or "") falls through to
    # display instead of silently doing nothing.
    if save_path:
        import os
        directory = os.path.dirname(save_path) or "."
        filename = os.path.splitext(os.path.basename(save_path))[0]
        ext = os.path.splitext(save_path)[1].lstrip('.')
        # Default to pdf if no extension
        output_format = ext if ext else 'pdf'
        os.makedirs(directory, exist_ok=True)
        # Render and save with DPI applied; cleanup removes the intermediate source file
        graph.visual_graph.render(
            filename=filename,
            directory=directory,
            format=output_format,
            cleanup=True
        )
    elif _is_jupyter_environment():
        # In Jupyter: display inline
        display_graph(graph)
    else:
        # In script: open with system viewer
        graph.visual_graph.view(cleanup=True)

    return graph
|