torch-rechub 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +120 -0
- torch_rechub/trainers/ctr_trainer.py +137 -1
- torch_rechub/trainers/match_trainer.py +136 -1
- torch_rechub/trainers/mtl_trainer.py +146 -1
- torch_rechub/trainers/seq_trainer.py +193 -2
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/onnx_export.py +3 -136
- torch_rechub/utils/visualization.py +271 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/METADATA +68 -49
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/RECORD +15 -9
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
"""Common model utility functions for Torch-RecHub.
|
|
2
|
+
|
|
3
|
+
This module provides shared utilities for model introspection and input generation,
|
|
4
|
+
used by both ONNX export and visualization features.
|
|
5
|
+
|
|
6
|
+
Examples
|
|
7
|
+
--------
|
|
8
|
+
>>> from torch_rechub.utils.model_utils import extract_feature_info, generate_dummy_input
|
|
9
|
+
>>> feature_info = extract_feature_info(model)
|
|
10
|
+
>>> dummy_input = generate_dummy_input(feature_info['features'], batch_size=2)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
# Import feature types for type checking
|
|
19
|
+
try:
|
|
20
|
+
from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
21
|
+
except ImportError:
|
|
22
|
+
# Fallback when the package-relative import fails: the feature class names are bound to None
|
|
23
|
+
SparseFeature = None
|
|
24
|
+
DenseFeature = None
|
|
25
|
+
SequenceFeature = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
    """Extract feature information from a torch-rechub model using reflection.

    The model is probed for a fixed set of well-known attribute names that
    hold feature collections. Every non-empty list or tuple found is merged,
    then de-duplicated by feature ``name`` (first occurrence wins, order kept).
    The model itself is never modified.

    Parameters
    ----------
    model : nn.Module
        The recommendation model to inspect.

    Returns
    -------
    dict
        Dictionary containing:
        - 'features': List of unique Feature objects
        - 'input_names': List of feature names in order
        - 'input_types': Dict mapping feature name to the feature's class name
        - 'user_features': Unique features found under attributes whose name
          contains 'user' or 'history'
        - 'item_features': Unique features found under the remaining
          attributes whose name contains 'item'

    Examples
    --------
    >>> from torch_rechub.models.ranking import DeepFM
    >>> model = DeepFM(deep_features, fm_features, mlp_params)
    >>> info = extract_feature_info(model)
    >>> print(info['input_names'])  # ['user_id', 'item_id', ...]
    """
    # Common feature attribute names across different model types.
    # The probing order below defines the order of the returned features.
    feature_attrs = (
        'features',  # MMOE, DCN, etc.
        'deep_features',  # DeepFM, WideDeep
        'fm_features',  # DeepFM
        'wide_features',  # WideDeep
        'linear_features',  # DeepFFM
        'cross_features',  # DeepFFM
        'user_features',  # DSSM, YoutubeDNN, MIND
        'item_features',  # DSSM, YoutubeDNN, MIND
        'history_features',  # DIN, MIND
        'target_features',  # DIN
        'neg_item_feature',  # YoutubeDNN, MIND
    )

    all_features = []
    user_features = []
    item_features = []

    for attr in feature_attrs:
        feat_list = getattr(model, attr, None)
        # Only non-empty list/tuple collections are treated as feature lists;
        # any other attribute value (None, str, module, ...) is ignored.
        if not isinstance(feat_list, (list, tuple)) or not feat_list:
            continue
        all_features.extend(feat_list)
        # Track user/item features for dual-tower models. The side is decided
        # purely from the attribute name, so 'neg_item_feature' counts as item-side.
        if 'user' in attr or 'history' in attr:
            user_features.extend(feat_list)
        elif 'item' in attr:
            item_features.extend(feat_list)

    unique_features = _unique_by_name(all_features)

    return {
        'features': unique_features,
        'input_names': [f.name for f in unique_features],
        'input_types': {f.name: type(f).__name__ for f in unique_features},
        'user_features': _unique_by_name(user_features),
        'item_features': _unique_by_name(item_features),
    }


def _unique_by_name(candidates: List[Any]) -> List[Any]:
    """Return *candidates* without repeated names, keeping the first of each in order.

    Objects that expose no ``name`` attribute are dropped.
    """
    seen_names = set()
    unique = []
    for feat in candidates:
        if not hasattr(feat, 'name') or feat.name in seen_names:
            continue
        seen_names.add(feat.name)
        unique.append(feat)
    return unique
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
    """Create one random placeholder tensor per feature definition.

    Parameters
    ----------
    features : list
        List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Size of the leading dimension of every generated tensor.
    seq_length : int, default=10
        Size of the second dimension for SequenceFeature tensors.
    device : str, default='cpu'
        Device on which the tensors are created.

    Returns
    -------
    tuple of Tensor
        Tensors in the same order as ``features``.

    Raises
    ------
    TypeError
        If an element is none of the three supported feature classes.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
    >>> dummy = generate_dummy_input(features, batch_size=4)
    >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
    """
    # Resolved at call time rather than taken from module scope.
    from ..basic.features import DenseFeature, SequenceFeature, SparseFeature

    def _placeholder(feature: Any) -> torch.Tensor:
        # SequenceFeature is tested first, then SparseFeature, then DenseFeature.
        if isinstance(feature, SequenceFeature):
            # Random ids in [0, vocab_size), shape [batch_size, seq_length]
            return torch.randint(0, feature.vocab_size, (batch_size, seq_length), device=device)
        if isinstance(feature, SparseFeature):
            # Random ids in [0, vocab_size), shape [batch_size]
            return torch.randint(0, feature.vocab_size, (batch_size,), device=device)
        if isinstance(feature, DenseFeature):
            # Standard-normal values, shape [batch_size, embed_dim]
            return torch.randn(batch_size, feature.embed_dim, device=device)
        raise TypeError(f"Unsupported feature type: {type(feature)}")

    return tuple(_placeholder(feature) for feature in features)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def generate_dummy_input_dict(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Dict[str, torch.Tensor]:
    """Generate dummy input dict based on feature definitions.

    Similar to generate_dummy_input but returns a dict mapping feature names
    to tensors. This is the expected input format for torch-rechub models.

    Parameters
    ----------
    features : list
        List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Batch size for dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    device : str, default='cpu'
        Device to create tensors on.

    Returns
    -------
    dict
        Dict mapping feature names to tensors. If two features share a name,
        the tensor of the later one is kept.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000)]
    >>> dummy = generate_dummy_input_dict(features, batch_size=4)
    >>> # Returns {"user_id": tensor[4]}
    """
    # Materialise once: the features are walked twice (tensor generation and
    # naming), which would silently yield an empty dict for a one-shot iterable.
    features = list(features)
    tensors = generate_dummy_input(features, batch_size, seq_length, device)
    # Pair every tensor with the feature it was generated from, so a name can
    # never drift onto a neighbouring tensor.
    return {feat.name: tensor for feat, tensor in zip(features, tensors)}
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def generate_dynamic_axes(input_names: List[str], output_names: Optional[List[str]] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: Optional[List[str]] = None) -> Dict[str, Dict[int, str]]:
    """Generate dynamic axes configuration for ONNX export.

    Parameters
    ----------
    input_names : list of str
        List of input tensor names.
    output_names : list of str, optional
        List of output tensor names. Default is ["output"].
    batch_dim : int, default=0
        Dimension index for batch size.
    include_seq_dim : bool, default=True
        Whether to include sequence dimension as dynamic.
    seq_features : list of str, optional
        List of feature names that are sequences. When omitted or empty,
        no sequence axis is added to any input.

    Returns
    -------
    dict
        Dynamic axes dict for torch.onnx.export. Every input and output name
        maps ``batch_dim`` to "batch_size"; sequence inputs additionally map
        their sequence axis to "seq_length". If a name appears in both
        ``input_names`` and ``output_names``, the output entry wins.

    Examples
    --------
    >>> axes = generate_dynamic_axes(["user_id", "hist"], seq_features=["hist"])
    >>> # Returns {"user_id": {0: "batch_size"},
    >>> #          "hist": {0: "batch_size", 1: "seq_length"},
    >>> #          "output": {0: "batch_size"}}
    """
    if output_names is None:
        output_names = ["output"]

    # Hoisted out of the loop: one set gives O(1) membership tests.
    sequence_names = set(seq_features) if (include_seq_dim and seq_features) else set()

    # The sequence axis is normally 1. When the batch already occupies axis 1,
    # writing "seq_length" there would erase the batch entry, so axis 0 is used
    # instead (assumes sequence inputs are 2-D -- TODO confirm for other layouts).
    seq_dim = 0 if batch_dim == 1 else 1

    dynamic_axes = {}

    # Input axes
    for name in input_names:
        axes = {batch_dim: "batch_size"}
        if name in sequence_names:
            axes[seq_dim] = "seq_length"
        dynamic_axes[name] = axes

    # Output axes
    for name in output_names:
        dynamic_axes[name] = {batch_dim: "batch_size"}

    return dynamic_axes
|
|
@@ -62,142 +62,9 @@ class ONNXWrapper(nn.Module):
|
|
|
62
62
|
self.model.mode = self._original_mode
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
This function inspects model attributes to find feature lists without
|
|
69
|
-
modifying the model code. Supports various model architectures.
|
|
70
|
-
|
|
71
|
-
Args:
|
|
72
|
-
model: The recommendation model to inspect.
|
|
73
|
-
|
|
74
|
-
Returns:
|
|
75
|
-
Dict containing:
|
|
76
|
-
- 'features': List of unique Feature objects
|
|
77
|
-
- 'input_names': List of feature names in order
|
|
78
|
-
- 'input_types': Dict mapping feature name to feature type
|
|
79
|
-
- 'user_features': List of user-side features (for dual-tower models)
|
|
80
|
-
- 'item_features': List of item-side features (for dual-tower models)
|
|
81
|
-
"""
|
|
82
|
-
# Common feature attribute names across different model types
|
|
83
|
-
feature_attrs = [
|
|
84
|
-
'features', # MMOE, DCN, etc.
|
|
85
|
-
'deep_features', # DeepFM, WideDeep
|
|
86
|
-
'fm_features', # DeepFM
|
|
87
|
-
'wide_features', # WideDeep
|
|
88
|
-
'linear_features', # DeepFFM
|
|
89
|
-
'cross_features', # DeepFFM
|
|
90
|
-
'user_features', # DSSM, YoutubeDNN, MIND
|
|
91
|
-
'item_features', # DSSM, YoutubeDNN, MIND
|
|
92
|
-
'history_features', # DIN, MIND
|
|
93
|
-
'target_features', # DIN
|
|
94
|
-
'neg_item_feature', # YoutubeDNN, MIND
|
|
95
|
-
]
|
|
96
|
-
|
|
97
|
-
all_features = []
|
|
98
|
-
user_features = []
|
|
99
|
-
item_features = []
|
|
100
|
-
|
|
101
|
-
for attr in feature_attrs:
|
|
102
|
-
if hasattr(model, attr):
|
|
103
|
-
feat_list = getattr(model, attr)
|
|
104
|
-
if isinstance(feat_list, list) and len(feat_list) > 0:
|
|
105
|
-
all_features.extend(feat_list)
|
|
106
|
-
# Track user/item features for dual-tower models
|
|
107
|
-
if 'user' in attr or 'history' in attr:
|
|
108
|
-
user_features.extend(feat_list)
|
|
109
|
-
elif 'item' in attr:
|
|
110
|
-
item_features.extend(feat_list)
|
|
111
|
-
|
|
112
|
-
# Deduplicate features by name while preserving order
|
|
113
|
-
seen = set()
|
|
114
|
-
unique_features = []
|
|
115
|
-
for f in all_features:
|
|
116
|
-
if hasattr(f, 'name') and f.name not in seen:
|
|
117
|
-
seen.add(f.name)
|
|
118
|
-
unique_features.append(f)
|
|
119
|
-
|
|
120
|
-
# Deduplicate user/item features
|
|
121
|
-
seen_user = set()
|
|
122
|
-
unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
|
|
123
|
-
seen_item = set()
|
|
124
|
-
unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
|
|
125
|
-
|
|
126
|
-
return {
|
|
127
|
-
'features': unique_features,
|
|
128
|
-
'input_names': [f.name for f in unique_features],
|
|
129
|
-
'input_types': {
|
|
130
|
-
f.name: type(f).__name__ for f in unique_features
|
|
131
|
-
},
|
|
132
|
-
'user_features': unique_user,
|
|
133
|
-
'item_features': unique_item,
|
|
134
|
-
}
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
    """Create one random placeholder tensor per feature definition for ONNX export.

    Args:
        features: List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
        batch_size: Size of the leading dimension of every tensor (default: 2).
        seq_length: Size of the second dimension for SequenceFeature tensors (default: 10).
        device: Device on which the tensors are created (default: 'cpu').

    Returns:
        Tuple of tensors in the same order as ``features``.

    Raises:
        TypeError: If an element is none of the three supported feature classes.

    Example:
        >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
        >>> dummy = generate_dummy_input(features, batch_size=4)
        >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
    """

    def _placeholder(feature: Any) -> torch.Tensor:
        # SequenceFeature is tested first, then SparseFeature, then DenseFeature.
        if isinstance(feature, SequenceFeature):
            # Random ids in [0, vocab_size), shape [batch_size, seq_length]
            return torch.randint(0, feature.vocab_size, (batch_size, seq_length), device=device)
        if isinstance(feature, SparseFeature):
            # Random ids in [0, vocab_size), shape [batch_size]
            return torch.randint(0, feature.vocab_size, (batch_size,), device=device)
        if isinstance(feature, DenseFeature):
            # Standard-normal values, shape [batch_size, embed_dim]
            return torch.randn(batch_size, feature.embed_dim, device=device)
        raise TypeError(f"Unsupported feature type: {type(feature)}")

    return tuple(_placeholder(feature) for feature in features)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
def generate_dynamic_axes(input_names: List[str], output_names: List[str] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: List[str] = None) -> Dict[str, Dict[int, str]]:
    """Generate dynamic axes configuration for ONNX export.

    Args:
        input_names: List of input tensor names.
        output_names: List of output tensor names (default: ["output"]).
        batch_dim: Dimension index for batch size (default: 0).
        include_seq_dim: Whether to include sequence dimension as dynamic (default: True).
        seq_features: List of feature names that are sequences. When omitted or
            empty, no sequence axis is added (nothing is auto-detected here).

    Returns:
        Dynamic axes dict for torch.onnx.export. Every input and output name maps
        ``batch_dim`` to "batch_size"; sequence inputs additionally map their
        sequence axis to "seq_length".
    """
    if output_names is None:
        output_names = ["output"]

    # Hoisted out of the loop: one set gives O(1) membership tests.
    sequence_names = set(seq_features) if (include_seq_dim and seq_features) else set()

    # The sequence axis is normally 1. When the batch already occupies axis 1,
    # writing "seq_length" there would erase the batch entry, so axis 0 is used
    # instead (assumes sequence inputs are 2-D -- TODO confirm for other layouts).
    seq_dim = 0 if batch_dim == 1 else 1

    dynamic_axes = {}

    # Input axes
    for name in input_names:
        axes = {batch_dim: "batch_size"}
        if name in sequence_names:
            axes[seq_dim] = "seq_length"
        dynamic_axes[name] = axes

    # Output axes
    for name in output_names:
        dynamic_axes[name] = {batch_dim: "batch_size"}

    return dynamic_axes
|
|
65
|
+
# Re-export from model_utils for backward compatibility
|
|
66
|
+
# The actual implementations are now in model_utils.py
|
|
67
|
+
from .model_utils import extract_feature_info, generate_dummy_input, generate_dummy_input_dict, generate_dynamic_axes
|
|
201
68
|
|
|
202
69
|
|
|
203
70
|
class ONNXExporter:
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Visualization Utilities for Torch-RecHub.
|
|
3
|
+
|
|
4
|
+
This module provides model structure visualization using torchview library.
|
|
5
|
+
Requires optional dependencies: pip install torch-rechub[visualization]
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
>>> from torch_rechub.utils.visualization import visualize_model, display_graph
|
|
9
|
+
>>> graph = visualize_model(model, depth=4)
|
|
10
|
+
>>> display_graph(graph) # Display in Jupyter Notebook
|
|
11
|
+
|
|
12
|
+
>>> # Save to file
|
|
13
|
+
>>> visualize_model(model, save_path="model_arch.pdf")
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from typing import Any, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
|
|
21
|
+
# Check for optional dependencies
|
|
22
|
+
TORCHVIEW_AVAILABLE = False
|
|
23
|
+
TORCHVIEW_SKIP_REASON = "torchview not installed"
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
from torchview import draw_graph
|
|
27
|
+
TORCHVIEW_AVAILABLE = True
|
|
28
|
+
except ImportError as e:
|
|
29
|
+
TORCHVIEW_SKIP_REASON = f"torchview not available: {e}"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _is_jupyter_environment() -> bool:
|
|
33
|
+
"""Check if running in Jupyter/IPython environment."""
|
|
34
|
+
try:
|
|
35
|
+
from IPython import get_ipython
|
|
36
|
+
shell = get_ipython()
|
|
37
|
+
if shell is None:
|
|
38
|
+
return False
|
|
39
|
+
# Check for Jupyter notebook or qtconsole
|
|
40
|
+
shell_class = shell.__class__.__name__
|
|
41
|
+
return shell_class in ('ZMQInteractiveShell', 'TerminalInteractiveShell')
|
|
42
|
+
except (ImportError, NameError):
|
|
43
|
+
return False
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def display_graph(graph: Any, format: str = 'png') -> Any:
    """Display a torchview ComputationGraph in Jupyter Notebook.

    This function provides a reliable way to display visualization graphs
    in Jupyter environments, especially VSCode Jupyter.

    Parameters
    ----------
    graph : ComputationGraph
        A torchview ComputationGraph object returned by visualize_model().
    format : str, default='png'
        Output format, 'png' recommended for VSCode.

    Returns
    -------
    graphviz.Digraph
        The ``visual_graph`` of ``graph``. It is returned whether or not
        IPython was available to display it.

    Raises
    ------
    ImportError
        If torchview is not installed.

    Examples
    --------
    >>> graph = visualize_model(model, depth=4)
    >>> display_graph(graph)  # Works in VSCode Jupyter
    """
    if not TORCHVIEW_AVAILABLE:
        raise ImportError(f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
                          "Install with: pip install torch-rechub[visualization]")

    # Best effort: choose the inline rendering format. A missing graphviz
    # package, or a graphviz release that lacks set_jupyter_format, must not
    # prevent the graph from being shown with the default format.
    try:
        import graphviz
        graphviz.set_jupyter_format(format)
    except (ImportError, AttributeError):
        pass

    # Get the visual_graph (graphviz.Digraph object)
    visual = graph.visual_graph

    # Use IPython display for explicit rendering when it is available
    try:
        from IPython.display import display
    except ImportError:
        # Not in IPython/Jupyter environment, just hand back the graph
        return visual

    display(visual)
    return visual
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def visualize_model(
    model: nn.Module,
    input_data: Optional[Dict[str, torch.Tensor]] = None,
    batch_size: int = 2,
    seq_length: int = 10,
    depth: int = 3,
    show_shapes: bool = True,
    expand_nested: bool = True,
    save_path: Optional[str] = None,
    graph_name: str = "model",
    device: str = "cpu",
    dpi: int = 300,
    **kwargs
) -> Any:
    """Visualize a Torch-RecHub model's computation graph.

    This function generates a visual representation of the model architecture,
    showing layer connections, tensor shapes, and nested module structures.
    It automatically extracts feature information from the model to generate
    appropriate dummy inputs.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to visualize. Should be a Torch-RecHub model
        with feature attributes (e.g., DeepFM, DSSM, MMOE).
    input_data : dict, optional
        Dict of example inputs {feature_name: tensor}.
        If None, inputs are auto-generated based on model features.
    batch_size : int, default=2
        Batch size for auto-generated inputs.
    seq_length : int, default=10
        Sequence length for SequenceFeature inputs.
    depth : int, default=3
        Visualization depth - higher values show more detail.
        Set to -1 to show all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes on edges.
    expand_nested : bool, default=True
        Whether to expand nested nn.Module with dashed borders.
    save_path : str, optional
        Path to save the graph image. Supports .pdf, .svg, .png formats;
        a path without extension is rendered as pdf.
        If None, displays in Jupyter or opens system viewer.
        An empty string neither saves nor displays.
    graph_name : str, default="model"
        Name for the computation graph.
    device : str, default="cpu"
        Device for model execution during tracing.
    dpi : int, default=300
        Resolution in dots per inch for output image.
        Higher values produce sharper images suitable for papers.
    **kwargs : dict
        Additional arguments passed to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.
        - Use `.visual_graph` property to get the graphviz.Digraph
        - Use `.resize_graph(scale=1.5)` to adjust graph size

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.
    ValueError
        If model has no recognizable feature attributes.

    Notes
    -----
    Default Display Behavior:
        When `save_path` is None (default):
        - In Jupyter/IPython: automatically displays the graph inline
        - In Python script: opens the graph with system default viewer

    Side effects on ``model``: it is moved to ``device`` and stays there. It is
    switched to eval mode only while the graph is traced; the training/eval
    mode it had on entry is restored afterwards, even if tracing fails.

    Requires graphviz system package: apt/brew/choco install graphviz.
    For Jupyter display issues, try: graphviz.set_jupyter_format('png').

    Examples
    --------
    >>> from torch_rechub.models.ranking import DeepFM
    >>> from torch_rechub.utils.visualization import visualize_model
    >>>
    >>> # Auto-display in Jupyter or open in viewer
    >>> visualize_model(model, depth=4)  # No save_path needed
    >>>
    >>> # Save to high-DPI PNG for paper
    >>> visualize_model(model, save_path="model.png", dpi=300)
    """
    if not TORCHVIEW_AVAILABLE:
        raise ImportError(
            f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
            "Install with: pip install torch-rechub[visualization]\n"
            "Also ensure graphviz is installed on your system:\n"
            " - Ubuntu/Debian: sudo apt-get install graphviz\n"
            " - macOS: brew install graphviz\n"
            " - Windows: choco install graphviz"
        )

    # Import feature extraction utilities from model_utils
    from .model_utils import extract_feature_info, generate_dummy_input_dict

    # Remember the caller's mode so a visualization in the middle of training
    # does not leave the model stuck in eval mode.
    was_training = model.training
    model.eval()
    model.to(device)

    try:
        if input_data is None:
            # Auto-generate input data from the model's feature definitions
            features = extract_feature_info(model)['features']
            if not features:
                raise ValueError("Could not extract feature information from model. "
                                 "Please provide input_data parameter manually.")
            input_data = generate_dummy_input_dict(features, batch_size=batch_size, seq_length=seq_length, device=device)
        else:
            # Ensure input tensors are on correct device
            input_data = {k: v.to(device) for k, v in input_data.items()}

        # IMPORTANT: Wrap input_data dict in a tuple to work around torchview's behavior
        #
        # torchview's forward_prop function checks the type of input_data:
        # - If isinstance(x, (list, tuple)): model(*x)
        # - If isinstance(x, Mapping): model(**x)  <- This unpacks dict as kwargs!
        # - Else: model(x)
        #
        # torch-rechub models expect forward(self, x) where x is a complete dict.
        # By wrapping the dict in a tuple, torchview will call:
        #   model(*(input_dict,)) = model(input_dict)
        # which is exactly what our models expect.
        input_data_wrapped = (input_data,)

        # Call torchview.draw_graph without saving (we'll save manually with DPI)
        graph = draw_graph(
            model,
            input_data=input_data_wrapped,
            graph_name=graph_name,
            depth=depth,
            device=device,
            expand_nested=expand_nested,
            show_shapes=show_shapes,
            save_graph=False,  # Don't save here, we'll save manually with DPI
            **kwargs
        )
    finally:
        model.train(was_training)

    # Set DPI for high-quality output (must be set BEFORE rendering/saving)
    graph.visual_graph.graph_attr['dpi'] = str(dpi)

    if save_path:
        # Manually save with DPI applied
        import os
        directory = os.path.dirname(save_path) or "."
        stem, ext = os.path.splitext(os.path.basename(save_path))
        # Default to pdf if no extension
        output_format = ext.lstrip('.') or 'pdf'
        # Create the target directory if it doesn't exist yet
        os.makedirs(directory, exist_ok=True)
        graph.visual_graph.render(
            filename=stem,
            directory=directory,
            format=output_format,
            cleanup=True  # Remove intermediate .gv file
        )
    elif save_path is None:
        # Default display behavior when no save_path was given
        if _is_jupyter_environment():
            # In Jupyter: display inline
            display_graph(graph)
        else:
            # In script: open with system viewer
            graph.visual_graph.view(cleanup=True)

    return graph
|