torch-rechub 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -189,3 +189,100 @@ class CTRTrainer(object):
189
189
 
190
190
  exporter = ONNXExporter(model, device=export_device)
191
191
  return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
192
+
193
+ def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
194
+ """Visualize the model's computation graph.
195
+
196
+ This method generates a visual representation of the model architecture,
197
+ showing layer connections, tensor shapes, and nested module structures.
198
+ It automatically extracts feature information from the model.
199
+
200
+ Parameters
201
+ ----------
202
+ input_data : dict, optional
203
+ Example input dict {feature_name: tensor}.
204
+ If not provided, dummy inputs will be generated automatically.
205
+ batch_size : int, default=2
206
+ Batch size for auto-generated dummy input.
207
+ seq_length : int, default=10
208
+ Sequence length for SequenceFeature.
209
+ depth : int, default=3
210
+ Visualization depth, higher values show more detail.
211
+ Set to -1 to show all layers.
212
+ show_shapes : bool, default=True
213
+ Whether to display tensor shapes.
214
+ expand_nested : bool, default=True
215
+ Whether to expand nested modules.
216
+ save_path : str, optional
217
+ Path to save the graph image (.pdf, .svg, .png).
218
+ If None, displays in Jupyter or opens system viewer.
219
+ graph_name : str, default="model"
220
+ Name for the graph.
221
+ device : str, optional
222
+ Device for model execution. If None, defaults to 'cpu'.
223
+ dpi : int, default=300
224
+ Resolution in dots per inch for output image.
225
+ Higher values produce sharper images suitable for papers.
226
+ **kwargs : dict
227
+ Additional arguments passed to torchview.draw_graph().
228
+
229
+ Returns
230
+ -------
231
+ ComputationGraph
232
+ A torchview ComputationGraph object.
233
+
234
+ Raises
235
+ ------
236
+ ImportError
237
+ If torchview or graphviz is not installed.
238
+
239
+ Notes
240
+ -----
241
+ Default Display Behavior:
242
+ When `save_path` is None (default):
243
+ - In Jupyter/IPython: automatically displays the graph inline
244
+ - In Python script: opens the graph with system default viewer
245
+
246
+ Examples
247
+ --------
248
+ >>> trainer = CTRTrainer(model, ...)
249
+ >>> trainer.fit(train_dl, val_dl)
250
+ >>>
251
+ >>> # Auto-display in Jupyter (no save_path needed)
252
+ >>> trainer.visualization(depth=4)
253
+ >>>
254
+ >>> # Save to high-DPI PNG for papers
255
+ >>> trainer.visualization(save_path="model.png", dpi=300)
256
+ """
257
+ from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
258
+
259
+ if not TORCHVIEW_AVAILABLE:
260
+ raise ImportError(
261
+ "Visualization requires torchview. "
262
+ "Install with: pip install torch-rechub[visualization]\n"
263
+ "Also ensure graphviz is installed on your system:\n"
264
+ " - Ubuntu/Debian: sudo apt-get install graphviz\n"
265
+ " - macOS: brew install graphviz\n"
266
+ " - Windows: choco install graphviz"
267
+ )
268
+
269
+ # Handle DataParallel wrapped model
270
+ model = self.model.module if hasattr(self.model, 'module') else self.model
271
+
272
+ # Use provided device or default to 'cpu'
273
+ viz_device = device if device is not None else 'cpu'
274
+
275
+ return visualize_model(
276
+ model,
277
+ input_data=input_data,
278
+ batch_size=batch_size,
279
+ seq_length=seq_length,
280
+ depth=depth,
281
+ show_shapes=show_shapes,
282
+ expand_nested=expand_nested,
283
+ save_path=save_path,
284
+ graph_name=graph_name,
285
+ device=viz_device,
286
+ dpi=dpi,
287
+ **kwargs
288
+ )
@@ -237,3 +237,100 @@ class MatchTrainer(object):
237
237
  # Restore original mode
238
238
  if hasattr(model, 'mode'):
239
239
  model.mode = original_mode
240
+
241
+ def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
242
+ """Visualize the model's computation graph.
243
+
244
+ This method generates a visual representation of the model architecture,
245
+ showing layer connections, tensor shapes, and nested module structures.
246
+ It automatically extracts feature information from the model.
247
+
248
+ Parameters
249
+ ----------
250
+ input_data : dict, optional
251
+ Example input dict {feature_name: tensor}.
252
+ If not provided, dummy inputs will be generated automatically.
253
+ batch_size : int, default=2
254
+ Batch size for auto-generated dummy input.
255
+ seq_length : int, default=10
256
+ Sequence length for SequenceFeature.
257
+ depth : int, default=3
258
+ Visualization depth, higher values show more detail.
259
+ Set to -1 to show all layers.
260
+ show_shapes : bool, default=True
261
+ Whether to display tensor shapes.
262
+ expand_nested : bool, default=True
263
+ Whether to expand nested modules.
264
+ save_path : str, optional
265
+ Path to save the graph image (.pdf, .svg, .png).
266
+ If None, displays in Jupyter or opens system viewer.
267
+ graph_name : str, default="model"
268
+ Name for the graph.
269
+ device : str, optional
270
+ Device for model execution. If None, defaults to 'cpu'.
271
+ dpi : int, default=300
272
+ Resolution in dots per inch for output image.
273
+ Higher values produce sharper images suitable for papers.
274
+ **kwargs : dict
275
+ Additional arguments passed to torchview.draw_graph().
276
+
277
+ Returns
278
+ -------
279
+ ComputationGraph
280
+ A torchview ComputationGraph object.
281
+
282
+ Raises
283
+ ------
284
+ ImportError
285
+ If torchview or graphviz is not installed.
286
+
287
+ Notes
288
+ -----
289
+ Default Display Behavior:
290
+ When `save_path` is None (default):
291
+ - In Jupyter/IPython: automatically displays the graph inline
292
+ - In Python script: opens the graph with system default viewer
293
+
294
+ Examples
295
+ --------
296
+ >>> trainer = MatchTrainer(model, ...)
297
+ >>> trainer.fit(train_dl)
298
+ >>>
299
+ >>> # Auto-display in Jupyter (no save_path needed)
300
+ >>> trainer.visualization(depth=4)
301
+ >>>
302
+ >>> # Save to high-DPI PNG for papers
303
+ >>> trainer.visualization(save_path="model.png", dpi=300)
304
+ """
305
+ from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
306
+
307
+ if not TORCHVIEW_AVAILABLE:
308
+ raise ImportError(
309
+ "Visualization requires torchview. "
310
+ "Install with: pip install torch-rechub[visualization]\n"
311
+ "Also ensure graphviz is installed on your system:\n"
312
+ " - Ubuntu/Debian: sudo apt-get install graphviz\n"
313
+ " - macOS: brew install graphviz\n"
314
+ " - Windows: choco install graphviz"
315
+ )
316
+
317
+ # Handle DataParallel wrapped model
318
+ model = self.model.module if hasattr(self.model, 'module') else self.model
319
+
320
+ # Use provided device or default to 'cpu'
321
+ viz_device = device if device is not None else 'cpu'
322
+
323
+ return visualize_model(
324
+ model,
325
+ input_data=input_data,
326
+ batch_size=batch_size,
327
+ seq_length=seq_length,
328
+ depth=depth,
329
+ show_shapes=show_shapes,
330
+ expand_nested=expand_nested,
331
+ save_path=save_path,
332
+ graph_name=graph_name,
333
+ device=viz_device,
334
+ dpi=dpi,
335
+ **kwargs
336
+ )
@@ -257,3 +257,100 @@ class MTLTrainer(object):
257
257
 
258
258
  exporter = ONNXExporter(model, device=export_device)
259
259
  return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
260
+
261
+ def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
262
+ """Visualize the model's computation graph.
263
+
264
+ This method generates a visual representation of the model architecture,
265
+ showing layer connections, tensor shapes, and nested module structures.
266
+ It automatically extracts feature information from the model.
267
+
268
+ Parameters
269
+ ----------
270
+ input_data : dict, optional
271
+ Example input dict {feature_name: tensor}.
272
+ If not provided, dummy inputs will be generated automatically.
273
+ batch_size : int, default=2
274
+ Batch size for auto-generated dummy input.
275
+ seq_length : int, default=10
276
+ Sequence length for SequenceFeature.
277
+ depth : int, default=3
278
+ Visualization depth, higher values show more detail.
279
+ Set to -1 to show all layers.
280
+ show_shapes : bool, default=True
281
+ Whether to display tensor shapes.
282
+ expand_nested : bool, default=True
283
+ Whether to expand nested modules.
284
+ save_path : str, optional
285
+ Path to save the graph image (.pdf, .svg, .png).
286
+ If None, displays in Jupyter or opens system viewer.
287
+ graph_name : str, default="model"
288
+ Name for the graph.
289
+ device : str, optional
290
+ Device for model execution. If None, defaults to 'cpu'.
291
+ dpi : int, default=300
292
+ Resolution in dots per inch for output image.
293
+ Higher values produce sharper images suitable for papers.
294
+ **kwargs : dict
295
+ Additional arguments passed to torchview.draw_graph().
296
+
297
+ Returns
298
+ -------
299
+ ComputationGraph
300
+ A torchview ComputationGraph object.
301
+
302
+ Raises
303
+ ------
304
+ ImportError
305
+ If torchview or graphviz is not installed.
306
+
307
+ Notes
308
+ -----
309
+ Default Display Behavior:
310
+ When `save_path` is None (default):
311
+ - In Jupyter/IPython: automatically displays the graph inline
312
+ - In Python script: opens the graph with system default viewer
313
+
314
+ Examples
315
+ --------
316
+ >>> trainer = MTLTrainer(model, task_types=["classification", "classification"])
317
+ >>> trainer.fit(train_dl, val_dl)
318
+ >>>
319
+ >>> # Auto-display in Jupyter (no save_path needed)
320
+ >>> trainer.visualization(depth=4)
321
+ >>>
322
+ >>> # Save to high-DPI PNG for papers
323
+ >>> trainer.visualization(save_path="model.png", dpi=300)
324
+ """
325
+ from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
326
+
327
+ if not TORCHVIEW_AVAILABLE:
328
+ raise ImportError(
329
+ "Visualization requires torchview. "
330
+ "Install with: pip install torch-rechub[visualization]\n"
331
+ "Also ensure graphviz is installed on your system:\n"
332
+ " - Ubuntu/Debian: sudo apt-get install graphviz\n"
333
+ " - macOS: brew install graphviz\n"
334
+ " - Windows: choco install graphviz"
335
+ )
336
+
337
+ # Handle DataParallel wrapped model
338
+ model = self.model.module if hasattr(self.model, 'module') else self.model
339
+
340
+ # Use provided device or default to 'cpu'
341
+ viz_device = device if device is not None else 'cpu'
342
+
343
+ return visualize_model(
344
+ model,
345
+ input_data=input_data,
346
+ batch_size=batch_size,
347
+ seq_length=seq_length,
348
+ depth=depth,
349
+ show_shapes=show_shapes,
350
+ expand_nested=expand_nested,
351
+ save_path=save_path,
352
+ graph_name=graph_name,
353
+ device=viz_device,
354
+ dpi=dpi,
355
+ **kwargs
356
+ )
@@ -291,3 +291,137 @@ class SeqTrainer(object):
291
291
  except Exception as e:
292
292
  warnings.warn(f"ONNX export failed: {str(e)}")
293
293
  raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
294
+
295
+ def visualization(self, seq_length=50, vocab_size=None, batch_size=2, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
296
+ """Visualize the model's computation graph.
297
+
298
+ This method generates a visual representation of the sequence model
299
+ architecture, showing layer connections, tensor shapes, and nested
300
+ module structures.
301
+
302
+ Parameters
303
+ ----------
304
+ seq_length : int, default=50
305
+ Sequence length for dummy input.
306
+ vocab_size : int, optional
307
+ Vocabulary size for generating dummy tokens.
308
+ If None, will try to get from model.vocab_size or model.item_num.
309
+ batch_size : int, default=2
310
+ Batch size for dummy input.
311
+ depth : int, default=3
312
+ Visualization depth, higher values show more detail.
313
+ Set to -1 to show all layers.
314
+ show_shapes : bool, default=True
315
+ Whether to display tensor shapes.
316
+ expand_nested : bool, default=True
317
+ Whether to expand nested modules.
318
+ save_path : str, optional
319
+ Path to save the graph image (.pdf, .svg, .png).
320
+ If None, displays in Jupyter or opens system viewer.
321
+ graph_name : str, default="model"
322
+ Name for the graph.
323
+ device : str, optional
324
+ Device for model execution. If None, defaults to 'cpu'.
325
+ dpi : int, default=300
326
+ Resolution in dots per inch for output image.
327
+ Higher values produce sharper images suitable for papers.
328
+ **kwargs : dict
329
+ Additional arguments passed to torchview.draw_graph().
330
+
331
+ Returns
332
+ -------
333
+ ComputationGraph
334
+ A torchview ComputationGraph object.
335
+
336
+ Raises
337
+ ------
338
+ ImportError
339
+ If torchview or graphviz is not installed.
340
+ ValueError
341
+ If vocab_size is not provided and cannot be inferred from model.
342
+
343
+ Notes
344
+ -----
345
+ Default Display Behavior:
346
+ When `save_path` is None (default):
347
+ - In Jupyter/IPython: automatically displays the graph inline
348
+ - In Python script: opens the graph with system default viewer
349
+
350
+ Examples
351
+ --------
352
+ >>> trainer = SeqTrainer(hstu_model, ...)
353
+ >>> trainer.fit(train_dl, val_dl)
354
+ >>>
355
+ >>> # Auto-display in Jupyter (no save_path needed)
356
+ >>> trainer.visualization(depth=4, vocab_size=10000)
357
+ >>>
358
+ >>> # Save to high-DPI PNG for papers
359
+ >>> trainer.visualization(save_path="model.png", dpi=300)
360
+ """
361
+ try:
362
+ from torchview import draw_graph
363
+ TORCHVIEW_AVAILABLE = True
364
+ except ImportError:
365
+ TORCHVIEW_AVAILABLE = False
366
+
367
+ if not TORCHVIEW_AVAILABLE:
368
+ raise ImportError(
369
+ "Visualization requires torchview. "
370
+ "Install with: pip install torch-rechub[visualization]\n"
371
+ "Also ensure graphviz is installed on your system:\n"
372
+ " - Ubuntu/Debian: sudo apt-get install graphviz\n"
373
+ " - macOS: brew install graphviz\n"
374
+ " - Windows: choco install graphviz"
375
+ )
376
+
377
+ from ..utils.visualization import _is_jupyter_environment, display_graph
378
+
379
+ # Handle DataParallel wrapped model
380
+ model = self.model.module if hasattr(self.model, 'module') else self.model
381
+
382
+ # Use provided device or default to 'cpu'
383
+ viz_device = device if device is not None else 'cpu'
384
+
385
+ # Get vocab_size from model if not provided
386
+ if vocab_size is None:
387
+ if hasattr(model, 'vocab_size'):
388
+ vocab_size = model.vocab_size
389
+ elif hasattr(model, 'item_num'):
390
+ vocab_size = model.item_num
391
+ else:
392
+ raise ValueError("vocab_size must be provided or model must have "
393
+ "'vocab_size' or 'item_num' attribute")
394
+
395
+ # Generate dummy inputs for sequence model
396
+ dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=viz_device)
397
+ dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=viz_device)
398
+
399
+ # Move model to device
400
+ model = model.to(viz_device)
401
+ model.eval()
402
+
403
+ # Call torchview.draw_graph
404
+ graph = draw_graph(model, input_data=(dummy_seq_tokens, dummy_seq_time_diffs), graph_name=graph_name, depth=depth, device=viz_device, expand_nested=expand_nested, show_shapes=show_shapes, save_graph=False, **kwargs)
405
+
406
+ # Set DPI for high-quality output
407
+ graph.visual_graph.graph_attr['dpi'] = str(dpi)
408
+
409
+ # Handle save_path: manually save with DPI applied
410
+ if save_path:
411
+ import os
412
+ directory = os.path.dirname(save_path) or "."
413
+ filename = os.path.splitext(os.path.basename(save_path))[0]
414
+ ext = os.path.splitext(save_path)[1].lstrip('.')
415
+ output_format = ext if ext else 'pdf'
416
+ if directory != "." and not os.path.exists(directory):
417
+ os.makedirs(directory, exist_ok=True)
418
+ graph.visual_graph.render(filename=filename, directory=directory, format=output_format, cleanup=True)
419
+
420
+ # Handle default display behavior when save_path is None
421
+ if save_path is None:
422
+ if _is_jupyter_environment():
423
+ display_graph(graph)
424
+ else:
425
+ graph.visual_graph.view(cleanup=True)
426
+
427
+ return graph
@@ -0,0 +1,233 @@
1
+ """Common model utility functions for Torch-RecHub.
2
+
3
+ This module provides shared utilities for model introspection and input generation,
4
+ used by both ONNX export and visualization features.
5
+
6
+ Examples
7
+ --------
8
+ >>> from torch_rechub.utils.model_utils import extract_feature_info, generate_dummy_input
9
+ >>> feature_info = extract_feature_info(model)
10
+ >>> dummy_input = generate_dummy_input(feature_info['features'], batch_size=2)
11
+ """
12
+
13
+ from typing import Any, Dict, List, Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ # Import feature types for type checking
19
+ try:
20
+ from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
21
+ except ImportError:
22
+ # Fallback for standalone usage
23
+ SparseFeature = None
24
+ DenseFeature = None
25
+ SequenceFeature = None
26
+
27
+
28
+ def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
29
+ """Extract feature information from a torch-rechub model using reflection.
30
+
31
+ This function inspects model attributes to find feature lists without
32
+ modifying the model code. Supports various model architectures.
33
+
34
+ Parameters
35
+ ----------
36
+ model : nn.Module
37
+ The recommendation model to inspect.
38
+
39
+ Returns
40
+ -------
41
+ dict
42
+ Dictionary containing:
43
+ - 'features': List of unique Feature objects
44
+ - 'input_names': List of feature names in order
45
+ - 'input_types': Dict mapping feature name to feature type
46
+ - 'user_features': List of user-side features (for dual-tower models)
47
+ - 'item_features': List of item-side features (for dual-tower models)
48
+
49
+ Examples
50
+ --------
51
+ >>> from torch_rechub.models.ranking import DeepFM
52
+ >>> model = DeepFM(deep_features, fm_features, mlp_params)
53
+ >>> info = extract_feature_info(model)
54
+ >>> print(info['input_names']) # ['user_id', 'item_id', ...]
55
+ """
56
+ # Common feature attribute names across different model types
57
+ feature_attrs = [
58
+ 'features', # MMOE, DCN, etc.
59
+ 'deep_features', # DeepFM, WideDeep
60
+ 'fm_features', # DeepFM
61
+ 'wide_features', # WideDeep
62
+ 'linear_features', # DeepFFM
63
+ 'cross_features', # DeepFFM
64
+ 'user_features', # DSSM, YoutubeDNN, MIND
65
+ 'item_features', # DSSM, YoutubeDNN, MIND
66
+ 'history_features', # DIN, MIND
67
+ 'target_features', # DIN
68
+ 'neg_item_feature', # YoutubeDNN, MIND
69
+ ]
70
+
71
+ all_features = []
72
+ user_features = []
73
+ item_features = []
74
+
75
+ for attr in feature_attrs:
76
+ if hasattr(model, attr):
77
+ feat_list = getattr(model, attr)
78
+ if isinstance(feat_list, list) and len(feat_list) > 0:
79
+ all_features.extend(feat_list)
80
+ # Track user/item features for dual-tower models
81
+ if 'user' in attr or 'history' in attr:
82
+ user_features.extend(feat_list)
83
+ elif 'item' in attr:
84
+ item_features.extend(feat_list)
85
+
86
+ # Deduplicate features by name while preserving order
87
+ seen = set()
88
+ unique_features = []
89
+ for f in all_features:
90
+ if hasattr(f, 'name') and f.name not in seen:
91
+ seen.add(f.name)
92
+ unique_features.append(f)
93
+
94
+ # Deduplicate user/item features
95
+ seen_user = set()
96
+ unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
97
+ seen_item = set()
98
+ unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
99
+
100
+ # Build input names and types
101
+ input_names = [f.name for f in unique_features if hasattr(f, 'name')]
102
+ input_types = {f.name: type(f).__name__ for f in unique_features if hasattr(f, 'name')}
103
+
104
+ return {
105
+ 'features': unique_features,
106
+ 'input_names': input_names,
107
+ 'input_types': input_types,
108
+ 'user_features': unique_user,
109
+ 'item_features': unique_item,
110
+ }
111
+
112
+
113
+ def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
114
+ """Generate dummy input tensors based on feature definitions.
115
+
116
+ Parameters
117
+ ----------
118
+ features : list
119
+ List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
120
+ batch_size : int, default=2
121
+ Batch size for dummy input.
122
+ seq_length : int, default=10
123
+ Sequence length for SequenceFeature.
124
+ device : str, default='cpu'
125
+ Device to create tensors on.
126
+
127
+ Returns
128
+ -------
129
+ tuple of Tensor
130
+ Tuple of tensors in the order of input features.
131
+
132
+ Examples
133
+ --------
134
+ >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
135
+ >>> dummy = generate_dummy_input(features, batch_size=4)
136
+ >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
137
+ """
138
+ # Dynamic import to handle feature types
139
+ from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
140
+
141
+ inputs = []
142
+ for feat in features:
143
+ if isinstance(feat, SequenceFeature):
144
+ # Sequence features have shape [batch_size, seq_length]
145
+ tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
146
+ elif isinstance(feat, SparseFeature):
147
+ # Sparse features have shape [batch_size]
148
+ tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
149
+ elif isinstance(feat, DenseFeature):
150
+ # Dense features always have shape [batch_size, embed_dim]
151
+ tensor = torch.randn(batch_size, feat.embed_dim, device=device)
152
+ else:
153
+ raise TypeError(f"Unsupported feature type: {type(feat)}")
154
+ inputs.append(tensor)
155
+ return tuple(inputs)
156
+
157
+
158
+ def generate_dummy_input_dict(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Dict[str, torch.Tensor]:
159
+ """Generate dummy input dict based on feature definitions.
160
+
161
+ Similar to generate_dummy_input but returns a dict mapping feature names
162
+ to tensors. This is the expected input format for torch-rechub models.
163
+
164
+ Parameters
165
+ ----------
166
+ features : list
167
+ List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
168
+ batch_size : int, default=2
169
+ Batch size for dummy input.
170
+ seq_length : int, default=10
171
+ Sequence length for SequenceFeature.
172
+ device : str, default='cpu'
173
+ Device to create tensors on.
174
+
175
+ Returns
176
+ -------
177
+ dict
178
+ Dict mapping feature names to tensors.
179
+
180
+ Examples
181
+ --------
182
+ >>> features = [SparseFeature("user_id", 1000)]
183
+ >>> dummy = generate_dummy_input_dict(features, batch_size=4)
184
+ >>> # Returns {"user_id": tensor[4]}
185
+ """
186
+ dummy_tuple = generate_dummy_input(features, batch_size, seq_length, device)
187
+ input_names = [f.name for f in features if hasattr(f, 'name')]
188
+ return {name: tensor for name, tensor in zip(input_names, dummy_tuple)}
189
+
190
+
191
+ def generate_dynamic_axes(input_names: List[str], output_names: Optional[List[str]] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: Optional[List[str]] = None) -> Dict[str, Dict[int, str]]:
192
+ """Generate dynamic axes configuration for ONNX export.
193
+
194
+ Parameters
195
+ ----------
196
+ input_names : list of str
197
+ List of input tensor names.
198
+ output_names : list of str, optional
199
+ List of output tensor names. Default is ["output"].
200
+ batch_dim : int, default=0
201
+ Dimension index for batch size.
202
+ include_seq_dim : bool, default=True
203
+ Whether to include sequence dimension as dynamic.
204
+ seq_features : list of str, optional
205
+ List of feature names that are sequences.
206
+
207
+ Returns
208
+ -------
209
+ dict
210
+ Dynamic axes dict for torch.onnx.export.
211
+
212
+ Examples
213
+ --------
214
+ >>> axes = generate_dynamic_axes(["user_id", "item_id"], seq_features=["hist"])
215
+ >>> # Returns {"user_id": {0: "batch_size"}, "item_id": {0: "batch_size"}, ...}
216
+ """
217
+ if output_names is None:
218
+ output_names = ["output"]
219
+
220
+ dynamic_axes = {}
221
+
222
+ # Input axes
223
+ for name in input_names:
224
+ dynamic_axes[name] = {batch_dim: "batch_size"}
225
+ # Add sequence dimension for sequence features
226
+ if include_seq_dim and seq_features and name in seq_features:
227
+ dynamic_axes[name][1] = "seq_length"
228
+
229
+ # Output axes
230
+ for name in output_names:
231
+ dynamic_axes[name] = {batch_dim: "batch_size"}
232
+
233
+ return dynamic_axes
@@ -62,142 +62,9 @@ class ONNXWrapper(nn.Module):
62
62
  self.model.mode = self._original_mode
63
63
 
64
64
 
65
- def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
66
- """Extract feature information from a model using reflection.
67
-
68
- This function inspects model attributes to find feature lists without
69
- modifying the model code. Supports various model architectures.
70
-
71
- Args:
72
- model: The recommendation model to inspect.
73
-
74
- Returns:
75
- Dict containing:
76
- - 'features': List of unique Feature objects
77
- - 'input_names': List of feature names in order
78
- - 'input_types': Dict mapping feature name to feature type
79
- - 'user_features': List of user-side features (for dual-tower models)
80
- - 'item_features': List of item-side features (for dual-tower models)
81
- """
82
- # Common feature attribute names across different model types
83
- feature_attrs = [
84
- 'features', # MMOE, DCN, etc.
85
- 'deep_features', # DeepFM, WideDeep
86
- 'fm_features', # DeepFM
87
- 'wide_features', # WideDeep
88
- 'linear_features', # DeepFFM
89
- 'cross_features', # DeepFFM
90
- 'user_features', # DSSM, YoutubeDNN, MIND
91
- 'item_features', # DSSM, YoutubeDNN, MIND
92
- 'history_features', # DIN, MIND
93
- 'target_features', # DIN
94
- 'neg_item_feature', # YoutubeDNN, MIND
95
- ]
96
-
97
- all_features = []
98
- user_features = []
99
- item_features = []
100
-
101
- for attr in feature_attrs:
102
- if hasattr(model, attr):
103
- feat_list = getattr(model, attr)
104
- if isinstance(feat_list, list) and len(feat_list) > 0:
105
- all_features.extend(feat_list)
106
- # Track user/item features for dual-tower models
107
- if 'user' in attr or 'history' in attr:
108
- user_features.extend(feat_list)
109
- elif 'item' in attr:
110
- item_features.extend(feat_list)
111
-
112
- # Deduplicate features by name while preserving order
113
- seen = set()
114
- unique_features = []
115
- for f in all_features:
116
- if hasattr(f, 'name') and f.name not in seen:
117
- seen.add(f.name)
118
- unique_features.append(f)
119
-
120
- # Deduplicate user/item features
121
- seen_user = set()
122
- unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
123
- seen_item = set()
124
- unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
125
-
126
- return {
127
- 'features': unique_features,
128
- 'input_names': [f.name for f in unique_features],
129
- 'input_types': {
130
- f.name: type(f).__name__ for f in unique_features
131
- },
132
- 'user_features': unique_user,
133
- 'item_features': unique_item,
134
- }
135
-
136
-
137
- def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
138
- """Generate dummy input tensors for ONNX export based on feature definitions.
139
-
140
- Args:
141
- features: List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
142
- batch_size: Batch size for dummy input (default: 2).
143
- seq_length: Sequence length for SequenceFeature (default: 10).
144
- device: Device to create tensors on (default: 'cpu').
145
-
146
- Returns:
147
- Tuple of tensors in the order of input features.
148
-
149
- Example:
150
- >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
151
- >>> dummy = generate_dummy_input(features, batch_size=4)
152
- >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
153
- """
154
- inputs = []
155
- for feat in features:
156
- if isinstance(feat, SequenceFeature):
157
- # Sequence features have shape [batch_size, seq_length]
158
- tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
159
- elif isinstance(feat, SparseFeature):
160
- # Sparse features have shape [batch_size]
161
- tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
162
- elif isinstance(feat, DenseFeature):
163
- # Dense features have shape [batch_size, embed_dim]
164
- tensor = torch.randn(batch_size, feat.embed_dim, device=device)
165
- else:
166
- raise TypeError(f"Unsupported feature type: {type(feat)}")
167
- inputs.append(tensor)
168
- return tuple(inputs)
169
-
170
-
171
- def generate_dynamic_axes(input_names: List[str], output_names: List[str] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: List[str] = None) -> Dict[str, Dict[int, str]]:
172
- """Generate dynamic axes configuration for ONNX export.
173
-
174
- Args:
175
- input_names: List of input tensor names.
176
- output_names: List of output tensor names (default: ["output"]).
177
- batch_dim: Dimension index for batch size (default: 0).
178
- include_seq_dim: Whether to include sequence dimension as dynamic (default: True).
179
- seq_features: List of feature names that are sequences (default: auto-detect).
180
-
181
- Returns:
182
- Dynamic axes dict for torch.onnx.export.
183
- """
184
- if output_names is None:
185
- output_names = ["output"]
186
-
187
- dynamic_axes = {}
188
-
189
- # Input axes
190
- for name in input_names:
191
- dynamic_axes[name] = {batch_dim: "batch_size"}
192
- # Add sequence dimension for sequence features
193
- if include_seq_dim and seq_features and name in seq_features:
194
- dynamic_axes[name][1] = "seq_length"
195
-
196
- # Output axes
197
- for name in output_names:
198
- dynamic_axes[name] = {batch_dim: "batch_size"}
199
-
200
- return dynamic_axes
65
+ # Re-export from model_utils for backward compatibility
66
+ # The actual implementations are now in model_utils.py
67
+ from .model_utils import extract_feature_info, generate_dummy_input, generate_dummy_input_dict, generate_dynamic_axes
201
68
 
202
69
 
203
70
  class ONNXExporter:
@@ -0,0 +1,271 @@
1
+ """
2
+ Model Visualization Utilities for Torch-RecHub.
3
+
4
+ This module provides model structure visualization using torchview library.
5
+ Requires optional dependencies: pip install torch-rechub[visualization]
6
+
7
+ Example:
8
+ >>> from torch_rechub.utils.visualization import visualize_model, display_graph
9
+ >>> graph = visualize_model(model, depth=4)
10
+ >>> display_graph(graph) # Display in Jupyter Notebook
11
+
12
+ >>> # Save to file
13
+ >>> visualize_model(model, save_path="model_arch.pdf")
14
+ """
15
+
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ # Check for optional dependencies
22
+ TORCHVIEW_AVAILABLE = False
23
+ TORCHVIEW_SKIP_REASON = "torchview not installed"
24
+
25
+ try:
26
+ from torchview import draw_graph
27
+ TORCHVIEW_AVAILABLE = True
28
+ except ImportError as e:
29
+ TORCHVIEW_SKIP_REASON = f"torchview not available: {e}"
30
+
31
+
32
+ def _is_jupyter_environment() -> bool:
33
+ """Check if running in Jupyter/IPython environment."""
34
+ try:
35
+ from IPython import get_ipython
36
+ shell = get_ipython()
37
+ if shell is None:
38
+ return False
39
+ # NOTE: matches Jupyter notebook/qtconsole, but also plain terminal IPython (TerminalInteractiveShell)
40
+ shell_class = shell.__class__.__name__
41
+ return shell_class in ('ZMQInteractiveShell', 'TerminalInteractiveShell')
42
+ except (ImportError, NameError):
43
+ return False
44
+
45
+
46
+ def display_graph(graph: Any, format: str = 'png') -> Any:
47
+ """Display a torchview ComputationGraph in Jupyter Notebook.
48
+
49
+ This function provides a reliable way to display visualization graphs
50
+ in Jupyter environments, especially VSCode Jupyter.
51
+
52
+ Parameters
53
+ ----------
54
+ graph : ComputationGraph
55
+ A torchview ComputationGraph object returned by visualize_model().
56
+ format : str, default='png'
57
+ Output format, 'png' recommended for VSCode.
58
+
59
+ Returns
60
+ -------
61
+ graphviz.Digraph
62
+ The underlying graphviz graph object; returned even when inline display is unavailable.
63
+
64
+ Examples
65
+ --------
66
+ >>> graph = visualize_model(model, depth=4)
67
+ >>> display_graph(graph) # Works in VSCode Jupyter
68
+ """
69
+ if not TORCHVIEW_AVAILABLE:
70
+ raise ImportError(f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
71
+ "Install with: pip install torch-rechub[visualization]")
72
+
73
+ try:
74
+ import graphviz
75
+
76
+ # Set format for VSCode compatibility
77
+ graphviz.set_jupyter_format(format)
78
+ except ImportError:
79
+ pass
80
+
81
+ # Get the visual_graph (graphviz.Digraph object)
82
+ visual = graph.visual_graph
83
+
84
+ # Try to use IPython display for explicit rendering
85
+ try:
86
+ from IPython.display import display
87
+ display(visual)
88
+ return visual
89
+ except ImportError:
90
+ # Not in IPython/Jupyter environment, return the graph
91
+ return visual
92
+
93
+
94
+ def visualize_model(
95
+ model: nn.Module,
96
+ input_data: Optional[Dict[str,
97
+ torch.Tensor]] = None,
98
+ batch_size: int = 2,
99
+ seq_length: int = 10,
100
+ depth: int = 3,
101
+ show_shapes: bool = True,
102
+ expand_nested: bool = True,
103
+ save_path: Optional[str] = None,
104
+ graph_name: str = "model",
105
+ device: str = "cpu",
106
+ dpi: int = 300,
107
+ **kwargs
108
+ ) -> Any:
109
+ """Visualize a Torch-RecHub model's computation graph.
110
+
111
+ This function generates a visual representation of the model architecture,
112
+ showing layer connections, tensor shapes, and nested module structures.
113
+ It automatically extracts feature information from the model to generate
114
+ appropriate dummy inputs.
115
+
116
+ Parameters
117
+ ----------
118
+ model : nn.Module
119
+ PyTorch model to visualize. Should be a Torch-RecHub model
120
+ with feature attributes (e.g., DeepFM, DSSM, MMOE).
121
+ input_data : dict, optional
122
+ Dict of example inputs {feature_name: tensor}.
123
+ If None, inputs are auto-generated based on model features.
124
+ batch_size : int, default=2
125
+ Batch size for auto-generated inputs.
126
+ seq_length : int, default=10
127
+ Sequence length for SequenceFeature inputs.
128
+ depth : int, default=3
129
+ Visualization depth - higher values show more detail.
130
+ Set to -1 to show all layers.
131
+ show_shapes : bool, default=True
132
+ Whether to display tensor shapes on edges.
133
+ expand_nested : bool, default=True
134
+ Whether to expand nested nn.Module with dashed borders.
135
+ save_path : str, optional
136
+ Path to save the graph image. Supports .pdf, .svg, .png formats.
137
+ If None, displays in Jupyter or opens system viewer.
138
+ graph_name : str, default="model"
139
+ Name for the computation graph.
140
+ device : str, default="cpu"
141
+ Device for model execution during tracing.
142
+ dpi : int, default=300
143
+ Resolution in dots per inch for output image.
144
+ Higher values produce sharper images suitable for papers.
145
+ **kwargs : dict
146
+ Additional arguments passed to torchview.draw_graph().
147
+
148
+ Returns
149
+ -------
150
+ ComputationGraph
151
+ A torchview ComputationGraph object.
152
+ - Use `.visual_graph` property to get the graphviz.Digraph
153
+ - Use `.resize_graph(scale=1.5)` to adjust graph size
154
+
155
+ Raises
156
+ ------
157
+ ImportError
158
+ If torchview or graphviz is not installed.
159
+ ValueError
160
+ If model has no recognizable feature attributes.
161
+
162
+ Notes
163
+ -----
164
+ Default Display Behavior:
165
+ When `save_path` is None (default):
166
+ - In Jupyter/IPython: automatically displays the graph inline
167
+ - In Python script: opens the graph with system default viewer
168
+
169
+ Requires graphviz system package: apt/brew/choco install graphviz.
170
+ For Jupyter display issues, try: graphviz.set_jupyter_format('png').
171
+
172
+ Examples
173
+ --------
174
+ >>> from torch_rechub.models.ranking import DeepFM
175
+ >>> from torch_rechub.utils.visualization import visualize_model
176
+ >>>
177
+ >>> # Auto-display in Jupyter or open in viewer
178
+ >>> visualize_model(model, depth=4) # No save_path needed
179
+ >>>
180
+ >>> # Save to high-DPI PNG for paper
181
+ >>> visualize_model(model, save_path="model.png", dpi=300)
182
+ """
183
+ if not TORCHVIEW_AVAILABLE:
184
+ raise ImportError(
185
+ f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
186
+ "Install with: pip install torch-rechub[visualization]\n"
187
+ "Also ensure graphviz is installed on your system:\n"
188
+ " - Ubuntu/Debian: sudo apt-get install graphviz\n"
189
+ " - macOS: brew install graphviz\n"
190
+ " - Windows: choco install graphviz"
191
+ )
192
+
193
+ # Import feature extraction utilities from model_utils
194
+ from .model_utils import extract_feature_info, generate_dummy_input_dict
195
+
196
+ model.eval()
197
+ model.to(device)
198
+
199
+ # Auto-generate input data if not provided
200
+ if input_data is None:
201
+ feature_info = extract_feature_info(model)
202
+ features = feature_info['features']
203
+
204
+ if not features:
205
+ raise ValueError("Could not extract feature information from model. "
206
+ "Please provide input_data parameter manually.")
207
+
208
+ # Generate dummy input dict
209
+ input_data = generate_dummy_input_dict(features, batch_size=batch_size, seq_length=seq_length, device=device)
210
+ else:
211
+ # Ensure input tensors are on correct device
212
+ input_data = {k: v.to(device) for k, v in input_data.items()}
213
+
214
+ # IMPORTANT: Wrap input_data dict in a tuple to work around torchview's behavior
215
+ #
216
+ # torchview's forward_prop function checks the type of input_data:
217
+ # - If isinstance(x, (list, tuple)): model(*x)
218
+ # - If isinstance(x, Mapping): model(**x) <- This unpacks dict as kwargs!
219
+ # - Else: model(x)
220
+ #
221
+ # torch-rechub models expect forward(self, x) where x is a complete dict.
222
+ # By wrapping the dict in a tuple, torchview will call:
223
+ # model(*(input_dict,)) = model(input_dict)
224
+ # which is exactly what our models expect.
225
+ input_data_wrapped = (input_data,)
226
+
227
+ # Call torchview.draw_graph without saving (we'll save manually with DPI)
228
+ graph = draw_graph(
229
+ model,
230
+ input_data=input_data_wrapped,
231
+ graph_name=graph_name,
232
+ depth=depth,
233
+ device=device,
234
+ expand_nested=expand_nested,
235
+ show_shapes=show_shapes,
236
+ save_graph=False, # Don't save here, we'll save manually with DPI
237
+ **kwargs
238
+ )
239
+
240
+ # Set DPI for high-quality output (must be set BEFORE rendering/saving)
241
+ graph.visual_graph.graph_attr['dpi'] = str(dpi)
242
+
243
+ # Handle save_path: manually save with DPI applied
244
+ if save_path:
245
+ import os
246
+ directory = os.path.dirname(save_path) or "."
247
+ filename = os.path.splitext(os.path.basename(save_path))[0]
248
+ ext = os.path.splitext(save_path)[1].lstrip('.')
249
+ # Default to pdf if no extension
250
+ output_format = ext if ext else 'pdf'
251
+ # Create directory if it doesn't exist
252
+ if directory != "." and not os.path.exists(directory):
253
+ os.makedirs(directory, exist_ok=True)
254
+ # Render and save with DPI applied
255
+ graph.visual_graph.render(
256
+ filename=filename,
257
+ directory=directory,
258
+ format=output_format,
259
+ cleanup=True # Remove intermediate .gv file
260
+ )
261
+
262
+ # Handle default display behavior when save_path is None
263
+ if save_path is None:
264
+ if _is_jupyter_environment():
265
+ # In Jupyter: display inline
266
+ display_graph(graph)
267
+ else:
268
+ # In script: open with system viewer
269
+ graph.visual_graph.view(cleanup=True)
270
+
271
+ return graph
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-rechub
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
5
5
  Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
6
6
  Project-URL: Documentation, https://www.torch-rechub.com
@@ -41,6 +41,9 @@ Requires-Dist: yapf==0.43.0; extra == 'dev'
41
41
  Provides-Extra: onnx
42
42
  Requires-Dist: onnx>=1.12.0; extra == 'onnx'
43
43
  Requires-Dist: onnxruntime>=1.12.0; extra == 'onnx'
44
+ Provides-Extra: visualization
45
+ Requires-Dist: graphviz>=0.20; extra == 'visualization'
46
+ Requires-Dist: torchview>=0.2.6; extra == 'visualization'
44
47
  Description-Content-Type: text/markdown
45
48
 
46
49
  # 🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架
@@ -69,13 +72,13 @@ Description-Content-Type: text/markdown
69
72
 
70
73
  ## 🎯 为什么选择 Torch-RecHub?
71
74
 
72
- | 特性 | Torch-RecHub | 其他框架 |
73
- |------|-------------|---------|
74
- | 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
75
- | 模型覆盖 | **30+** 主流模型 | 有限 |
76
- | 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
77
- | ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
78
- | 学习曲线 | 极低 | 陡峭 |
75
+ | 特性 | Torch-RecHub | 其他框架 |
76
+ | ------------- | --------------------------- | ---------- |
77
+ | 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
78
+ | 模型覆盖 | **30+** 主流模型 | 有限 |
79
+ | 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
80
+ | ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
81
+ | 学习曲线 | 极低 | 陡峭 |
79
82
 
80
83
  ## ✨ 特性
81
84
 
@@ -205,52 +208,52 @@ torch-rechub/ # 根目录
205
208
 
206
209
  ### 排序模型 (Ranking Models) - 13个
207
210
 
208
- | 模型 | 论文 | 简介 |
209
- |------|------|------|
210
- | **DeepFM** | [IJCAI 2017](https://arxiv.org/abs/1703.04247) | FM + Deep 联合训练 |
211
- | **Wide&Deep** | [DLRS 2016](https://arxiv.org/abs/1606.07792) | 记忆 + 泛化能力结合 |
212
- | **DCN** | [KDD 2017](https://arxiv.org/abs/1708.05123) | 显式特征交叉网络 |
213
- | **DCN-v2** | [WWW 2021](https://arxiv.org/abs/2008.13535) | 增强版交叉网络 |
214
- | **DIN** | [KDD 2018](https://arxiv.org/abs/1706.06978) | 注意力机制捕捉用户兴趣 |
215
- | **DIEN** | [AAAI 2019](https://arxiv.org/abs/1809.03672) | 兴趣演化建模 |
216
- | **BST** | [DLP-KDD 2019](https://arxiv.org/abs/1905.06874) | Transformer 序列建模 |
217
- | **AFM** | [IJCAI 2017](https://arxiv.org/abs/1708.04617) | 注意力因子分解机 |
218
- | **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
219
- | **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
220
- | **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
221
- | **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
211
+ | 模型 | 论文 | 简介 |
212
+ | ------------- | ------------------------------------------------ | ----------------------- |
213
+ | **DeepFM** | [IJCAI 2017](https://arxiv.org/abs/1703.04247) | FM + Deep 联合训练 |
214
+ | **Wide&Deep** | [DLRS 2016](https://arxiv.org/abs/1606.07792) | 记忆 + 泛化能力结合 |
215
+ | **DCN** | [KDD 2017](https://arxiv.org/abs/1708.05123) | 显式特征交叉网络 |
216
+ | **DCN-v2** | [WWW 2021](https://arxiv.org/abs/2008.13535) | 增强版交叉网络 |
217
+ | **DIN** | [KDD 2018](https://arxiv.org/abs/1706.06978) | 注意力机制捕捉用户兴趣 |
218
+ | **DIEN** | [AAAI 2019](https://arxiv.org/abs/1809.03672) | 兴趣演化建模 |
219
+ | **BST** | [DLP-KDD 2019](https://arxiv.org/abs/1905.06874) | Transformer 序列建模 |
220
+ | **AFM** | [IJCAI 2017](https://arxiv.org/abs/1708.04617) | 注意力因子分解机 |
221
+ | **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
222
+ | **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
223
+ | **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
224
+ | **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
222
225
 
223
226
  ### 召回模型 (Matching Models) - 12个
224
227
 
225
- | 模型 | 论文 | 简介 |
226
- |------|------|------|
227
- | **DSSM** | [CIKM 2013](https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf) | 经典双塔召回模型 |
228
- | **YoutubeDNN** | [RecSys 2016](https://dl.acm.org/doi/10.1145/2959100.2959190) | YouTube 深度召回 |
229
- | **YoutubeSBC** | [RecSys 2019](https://dl.acm.org/doi/10.1145/3298689.3346997) | 采样偏差校正版本 |
230
- | **MIND** | [CIKM 2019](https://arxiv.org/abs/1904.08030) | 多兴趣动态路由 |
231
- | **SINE** | [WSDM 2021](https://arxiv.org/abs/2103.06920) | 稀疏兴趣网络 |
232
- | **GRU4Rec** | [ICLR 2016](https://arxiv.org/abs/1511.06939) | GRU 序列推荐 |
233
- | **SASRec** | [ICDM 2018](https://arxiv.org/abs/1808.09781) | 自注意力序列推荐 |
234
- | **NARM** | [CIKM 2017](https://arxiv.org/abs/1711.04725) | 神经注意力会话推荐 |
235
- | **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
236
- | **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
228
+ | 模型 | 论文 | 简介 |
229
+ | -------------- | ------------------------------------------------------------------------------ | ------------------ |
230
+ | **DSSM** | [CIKM 2013](https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf) | 经典双塔召回模型 |
231
+ | **YoutubeDNN** | [RecSys 2016](https://dl.acm.org/doi/10.1145/2959100.2959190) | YouTube 深度召回 |
232
+ | **YoutubeSBC** | [RecSys 2019](https://dl.acm.org/doi/10.1145/3298689.3346997) | 采样偏差校正版本 |
233
+ | **MIND** | [CIKM 2019](https://arxiv.org/abs/1904.08030) | 多兴趣动态路由 |
234
+ | **SINE** | [WSDM 2021](https://arxiv.org/abs/2103.06920) | 稀疏兴趣网络 |
235
+ | **GRU4Rec** | [ICLR 2016](https://arxiv.org/abs/1511.06939) | GRU 序列推荐 |
236
+ | **SASRec** | [ICDM 2018](https://arxiv.org/abs/1808.09781) | 自注意力序列推荐 |
237
+ | **NARM** | [CIKM 2017](https://arxiv.org/abs/1711.04725) | 神经注意力会话推荐 |
238
+ | **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
239
+ | **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
237
240
 
238
241
  ### 多任务模型 (Multi-Task Models) - 5个
239
242
 
240
- | 模型 | 论文 | 简介 |
241
- |------|------|------|
242
- | **ESMM** | [SIGIR 2018](https://arxiv.org/abs/1804.07931) | 全空间多任务建模 |
243
- | **MMoE** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3220007) | 多门控专家混合 |
244
- | **PLE** | [RecSys 2020](https://dl.acm.org/doi/10.1145/3383313.3412236) | 渐进式分层提取 |
245
- | **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
246
- | **SharedBottom** | - | 经典多任务共享底层 |
243
+ | 模型 | 论文 | 简介 |
244
+ | ---------------- | ------------------------------------------------------------- | ------------------ |
245
+ | **ESMM** | [SIGIR 2018](https://arxiv.org/abs/1804.07931) | 全空间多任务建模 |
246
+ | **MMoE** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3220007) | 多门控专家混合 |
247
+ | **PLE** | [RecSys 2020](https://dl.acm.org/doi/10.1145/3383313.3412236) | 渐进式分层提取 |
248
+ | **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
249
+ | **SharedBottom** | - | 经典多任务共享底层 |
247
250
 
248
251
  ### 生成式推荐 (Generative Recommendation) - 2个
249
252
 
250
- | 模型 | 论文 | 简介 |
251
- |------|------|------|
253
+ | 模型 | 论文 | 简介 |
254
+ | -------- | --------------------------------------------- | -------------------------------------------- |
252
255
  | **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | 层级序列转换单元,支撑 Meta 万亿参数推荐系统 |
253
- | **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐,融合 LLM 语义理解能力 |
256
+ | **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐,融合 LLM 语义理解能力 |
254
257
 
255
258
  ## 📊 支持的数据集
256
259
 
@@ -338,11 +341,19 @@ model = DSSM(user_features, item_features, temperature=0.02,
338
341
  match_trainer = MatchTrainer(model)
339
342
  match_trainer.fit(train_dl)
340
343
  match_trainer.export_onnx("dssm.onnx")
341
- # 双塔模型可分别导出用户塔和物品塔:
344
+ # 双塔模型可分别导出用户塔和物品塔:
342
345
  # match_trainer.export_onnx("user_tower.onnx", mode="user")
343
346
  # match_trainer.export_onnx("dssm_item.onnx", mode="item")
344
347
  ```
345
348
 
349
+ ### 模型可视化
350
+
351
+ ```python
352
+ # 可视化模型架构(需要安装: pip install torch-rechub[visualization])
353
+ graph = ctr_trainer.visualization(depth=4) # 生成计算图
354
+ ctr_trainer.visualization(save_path="model.pdf", dpi=300) # 保存为高清 PDF
355
+ ```
356
+
346
357
  ## 👨‍💻‍ 贡献者
347
358
 
348
359
  感谢所有的贡献者!
@@ -45,18 +45,20 @@ torch_rechub/models/ranking/edcn.py,sha256=6f_S8I6Ir16kCIU54R4EfumWfUFOND5KDKUPH
45
45
  torch_rechub/models/ranking/fibinet.py,sha256=fmEJ9WkO8Mn0RtK_8aRHlnQFh_jMBPO0zODoHZPWmDA,2234
46
46
  torch_rechub/models/ranking/widedeep.py,sha256=eciRvWRBHLlctabLLS5NB7k3MnqrWXCBdpflOU6jMB0,1636
47
47
  torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
48
- torch_rechub/trainers/ctr_trainer.py,sha256=RDUXkn7GwLzs3f0kWZwGDNCpqiMeGXo7R6ezFeZdPg8,9075
49
- torch_rechub/trainers/match_trainer.py,sha256=xox5eaPKjSgErJQpbSr29sbyGs1p2sFaKEjxACE6uMI,11276
48
+ torch_rechub/trainers/ctr_trainer.py,sha256=ECXaK0x2_6jZVxtEazgN3hkBpSAMPeGeNtunqI_OECo,12860
49
+ torch_rechub/trainers/match_trainer.py,sha256=QHZb32Rf7yp-NvEzdeiG1HQghQ76_vuu59K1IsdK60k,15055
50
50
  torch_rechub/trainers/matching.md,sha256=vIBQ3UMmVpUpyk38rrkelFwm_wXVXqMOuqzYZ4M8bzw,30
51
- torch_rechub/trainers/mtl_trainer.py,sha256=tC4c2KIc-H8Wvj4qCzcW6TyfMLRPJyfQvTaN0dDePFg,12598
52
- torch_rechub/trainers/seq_trainer.py,sha256=lXKRx7XbZ3iJuqp_f05vw_jkn8X5j8HmH6Nr-typiIU,12043
51
+ torch_rechub/trainers/mtl_trainer.py,sha256=MjasE_QOPfGxiUW1JpYYQ2iuBSSk-lissAGp4Sw1CWk,16427
52
+ torch_rechub/trainers/seq_trainer.py,sha256=uAo9XymwQupCqvm5otKW81tz1nxd3crJ2ul2r7lrEAE,17633
53
53
  torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
54
54
  torch_rechub/utils/data.py,sha256=vzLAAVt6dujg_vbGhQewiJc0l6JzwzdcM_9EjoOz898,19882
55
55
  torch_rechub/utils/hstu_utils.py,sha256=qLON_pJDC-kDyQn1PoN_HaHi5xTNCwZPgJeV51Z61Lc,6207
56
56
  torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
57
+ torch_rechub/utils/model_utils.py,sha256=VLhSbTpupxrFyyY3NzMQ32PPmo5YHm1T96u9KDlwiWE,8450
57
58
  torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
58
- torch_rechub/utils/onnx_export.py,sha256=uRcAD4uZ3eIQbM-DPhdc0bkaPaslNsOYny6BOeLVBfU,13660
59
- torch_rechub-0.0.4.dist-info/METADATA,sha256=SNm71v_YOfculnc13p266bD_8yLo0U_16F_aJQPDvYo,16149
60
- torch_rechub-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
61
- torch_rechub-0.0.4.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
62
- torch_rechub-0.0.4.dist-info/RECORD,,
59
+ torch_rechub/utils/onnx_export.py,sha256=LRHyZaR9zZJyg6xtuqQHWmusWq-yEvw9EhlmoEwcqsg,8364
60
+ torch_rechub/utils/visualization.py,sha256=Djv8W5SkCk3P2dol5VXf0_eanIhxDwRd7fzNOQY4uiU,9506
61
+ torch_rechub-0.0.5.dist-info/METADATA,sha256=7k9N1xGB4JeWzri7iA7kJbPnAJ-KhXF7vBV-_b8Ghrg,17998
62
+ torch_rechub-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
+ torch_rechub-0.0.5.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
64
+ torch_rechub-0.0.5.dist-info/RECORD,,