torch-rechub 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/trainers/ctr_trainer.py +97 -0
- torch_rechub/trainers/match_trainer.py +97 -0
- torch_rechub/trainers/mtl_trainer.py +97 -0
- torch_rechub/trainers/seq_trainer.py +134 -0
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/onnx_export.py +3 -136
- torch_rechub/utils/visualization.py +271 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.5.dist-info}/METADATA +56 -45
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.5.dist-info}/RECORD +11 -9
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -189,3 +189,100 @@ class CTRTrainer(object):
|
|
|
189
189
|
|
|
190
190
|
exporter = ONNXExporter(model, device=export_device)
|
|
191
191
|
return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
|
|
192
|
+
|
|
193
|
+
def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
    """Render this trainer's model as a computation graph.

    Builds a torchview graph of the underlying model, showing layer
    connections, tensor shapes and nested module structure. Feature
    information is extracted from the model automatically.

    Parameters
    ----------
    input_data : dict, optional
        Example input dict {feature_name: tensor}. Auto-generated dummy
        inputs are used when omitted.
    batch_size : int, default=2
        Batch size for auto-generated dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    depth : int, default=3
        Visualization depth; higher values show more detail, -1 shows
        all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes.
    expand_nested : bool, default=True
        Whether to expand nested modules.
    save_path : str, optional
        Path to save the graph image (.pdf, .svg, .png). When None the
        graph is displayed inline in Jupyter or opened with the system
        viewer.
    graph_name : str, default="model"
        Name for the graph.
    device : str, optional
        Device for model execution. Defaults to 'cpu' when None.
    dpi : int, default=300
        Resolution in dots per inch for the output image.
    **kwargs : dict
        Additional arguments forwarded to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.

    Examples
    --------
    >>> trainer = CTRTrainer(model, ...)
    >>> trainer.fit(train_dl, val_dl)
    >>> trainer.visualization(depth=4)  # auto-display in Jupyter
    >>> trainer.visualization(save_path="model.png", dpi=300)  # save to file
    """
    from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model

    if not TORCHVIEW_AVAILABLE:
        raise ImportError("Visualization requires torchview. "
                          "Install with: pip install torch-rechub[visualization]\n"
                          "Also ensure graphviz is installed on your system:\n"
                          " - Ubuntu/Debian: sudo apt-get install graphviz\n"
                          " - macOS: brew install graphviz\n"
                          " - Windows: choco install graphviz")

    # Unwrap nn.DataParallel so torchview sees the real module.
    target_model = getattr(self.model, 'module', self.model)

    # Default to CPU for visualization; the model is tiny-batch traced only.
    target_device = 'cpu' if device is None else device

    return visualize_model(
        target_model,
        input_data=input_data,
        batch_size=batch_size,
        seq_length=seq_length,
        depth=depth,
        show_shapes=show_shapes,
        expand_nested=expand_nested,
        save_path=save_path,
        graph_name=graph_name,
        device=target_device,
        dpi=dpi,
        **kwargs)
|
|
@@ -237,3 +237,100 @@ class MatchTrainer(object):
|
|
|
237
237
|
# Restore original mode
|
|
238
238
|
if hasattr(model, 'mode'):
|
|
239
239
|
model.mode = original_mode
|
|
240
|
+
|
|
241
|
+
def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
    """Render this trainer's model as a computation graph.

    Builds a torchview graph of the underlying model, showing layer
    connections, tensor shapes and nested module structure. Feature
    information is extracted from the model automatically.

    Parameters
    ----------
    input_data : dict, optional
        Example input dict {feature_name: tensor}. Auto-generated dummy
        inputs are used when omitted.
    batch_size : int, default=2
        Batch size for auto-generated dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    depth : int, default=3
        Visualization depth; higher values show more detail, -1 shows
        all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes.
    expand_nested : bool, default=True
        Whether to expand nested modules.
    save_path : str, optional
        Path to save the graph image (.pdf, .svg, .png). When None the
        graph is displayed inline in Jupyter or opened with the system
        viewer.
    graph_name : str, default="model"
        Name for the graph.
    device : str, optional
        Device for model execution. Defaults to 'cpu' when None.
    dpi : int, default=300
        Resolution in dots per inch for the output image.
    **kwargs : dict
        Additional arguments forwarded to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.

    Examples
    --------
    >>> trainer = MatchTrainer(model, ...)
    >>> trainer.fit(train_dl)
    >>> trainer.visualization(depth=4)  # auto-display in Jupyter
    >>> trainer.visualization(save_path="model.png", dpi=300)  # save to file
    """
    from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model

    if not TORCHVIEW_AVAILABLE:
        raise ImportError("Visualization requires torchview. "
                          "Install with: pip install torch-rechub[visualization]\n"
                          "Also ensure graphviz is installed on your system:\n"
                          " - Ubuntu/Debian: sudo apt-get install graphviz\n"
                          " - macOS: brew install graphviz\n"
                          " - Windows: choco install graphviz")

    # Unwrap nn.DataParallel so torchview sees the real module.
    target_model = getattr(self.model, 'module', self.model)

    # Default to CPU for visualization; the model is tiny-batch traced only.
    target_device = 'cpu' if device is None else device

    return visualize_model(
        target_model,
        input_data=input_data,
        batch_size=batch_size,
        seq_length=seq_length,
        depth=depth,
        show_shapes=show_shapes,
        expand_nested=expand_nested,
        save_path=save_path,
        graph_name=graph_name,
        device=target_device,
        dpi=dpi,
        **kwargs)
|
|
@@ -257,3 +257,100 @@ class MTLTrainer(object):
|
|
|
257
257
|
|
|
258
258
|
exporter = ONNXExporter(model, device=export_device)
|
|
259
259
|
return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
|
|
260
|
+
|
|
261
|
+
def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
    """Render this trainer's model as a computation graph.

    Builds a torchview graph of the underlying model, showing layer
    connections, tensor shapes and nested module structure. Feature
    information is extracted from the model automatically.

    Parameters
    ----------
    input_data : dict, optional
        Example input dict {feature_name: tensor}. Auto-generated dummy
        inputs are used when omitted.
    batch_size : int, default=2
        Batch size for auto-generated dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    depth : int, default=3
        Visualization depth; higher values show more detail, -1 shows
        all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes.
    expand_nested : bool, default=True
        Whether to expand nested modules.
    save_path : str, optional
        Path to save the graph image (.pdf, .svg, .png). When None the
        graph is displayed inline in Jupyter or opened with the system
        viewer.
    graph_name : str, default="model"
        Name for the graph.
    device : str, optional
        Device for model execution. Defaults to 'cpu' when None.
    dpi : int, default=300
        Resolution in dots per inch for the output image.
    **kwargs : dict
        Additional arguments forwarded to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.

    Examples
    --------
    >>> trainer = MTLTrainer(model, task_types=["classification", "classification"])
    >>> trainer.fit(train_dl, val_dl)
    >>> trainer.visualization(depth=4)  # auto-display in Jupyter
    >>> trainer.visualization(save_path="model.png", dpi=300)  # save to file
    """
    from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model

    if not TORCHVIEW_AVAILABLE:
        raise ImportError("Visualization requires torchview. "
                          "Install with: pip install torch-rechub[visualization]\n"
                          "Also ensure graphviz is installed on your system:\n"
                          " - Ubuntu/Debian: sudo apt-get install graphviz\n"
                          " - macOS: brew install graphviz\n"
                          " - Windows: choco install graphviz")

    # Unwrap nn.DataParallel so torchview sees the real module.
    target_model = getattr(self.model, 'module', self.model)

    # Default to CPU for visualization; the model is tiny-batch traced only.
    target_device = 'cpu' if device is None else device

    return visualize_model(
        target_model,
        input_data=input_data,
        batch_size=batch_size,
        seq_length=seq_length,
        depth=depth,
        show_shapes=show_shapes,
        expand_nested=expand_nested,
        save_path=save_path,
        graph_name=graph_name,
        device=target_device,
        dpi=dpi,
        **kwargs)
|
|
@@ -291,3 +291,137 @@ class SeqTrainer(object):
|
|
|
291
291
|
except Exception as e:
|
|
292
292
|
warnings.warn(f"ONNX export failed: {str(e)}")
|
|
293
293
|
raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
|
|
294
|
+
|
|
295
|
+
def visualization(self, seq_length=50, vocab_size=None, batch_size=2, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
    """Visualize the sequence model's computation graph.

    Generates a visual representation of the sequence model architecture,
    showing layer connections, tensor shapes, and nested module structures.

    Parameters
    ----------
    seq_length : int, default=50
        Sequence length for dummy input.
    vocab_size : int, optional
        Vocabulary size for generating dummy tokens. If None, it is read
        from ``model.vocab_size`` or ``model.item_num``.
    batch_size : int, default=2
        Batch size for dummy input.
    depth : int, default=3
        Visualization depth; higher values show more detail, -1 shows
        all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes.
    expand_nested : bool, default=True
        Whether to expand nested modules.
    save_path : str, optional
        Path to save the graph image (.pdf, .svg, .png). When None the
        graph is displayed inline in Jupyter or opened with the system
        viewer.
    graph_name : str, default="model"
        Name for the graph.
    device : str, optional
        Device for model execution. Defaults to 'cpu' when None.
    dpi : int, default=300
        Resolution in dots per inch for the output image.
    **kwargs : dict
        Additional arguments passed to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.
    ValueError
        If vocab_size is not provided and cannot be inferred from model.

    Examples
    --------
    >>> trainer = SeqTrainer(hstu_model, ...)
    >>> trainer.fit(train_dl, val_dl)
    >>> trainer.visualization(depth=4, vocab_size=10000)  # inline in Jupyter
    >>> trainer.visualization(save_path="model.png", dpi=300)  # save to file
    """
    # Import torchview lazily so the dependency stays optional; raise directly
    # from the handler (instead of a local availability flag) and chain the
    # original error so the root cause is preserved.
    try:
        from torchview import draw_graph
    except ImportError as err:
        raise ImportError(
            "Visualization requires torchview. "
            "Install with: pip install torch-rechub[visualization]\n"
            "Also ensure graphviz is installed on your system:\n"
            " - Ubuntu/Debian: sudo apt-get install graphviz\n"
            " - macOS: brew install graphviz\n"
            " - Windows: choco install graphviz") from err

    from ..utils.visualization import _is_jupyter_environment, display_graph

    # Handle DataParallel wrapped model.
    model = self.model.module if hasattr(self.model, 'module') else self.model

    # Use provided device or default to 'cpu'.
    viz_device = device if device is not None else 'cpu'

    # Infer vocab_size from the model when not given explicitly.
    if vocab_size is None:
        if hasattr(model, 'vocab_size'):
            vocab_size = model.vocab_size
        elif hasattr(model, 'item_num'):
            vocab_size = model.item_num
        else:
            raise ValueError("vocab_size must be provided or model must have "
                             "'vocab_size' or 'item_num' attribute")

    # Dummy inputs: random token ids plus zero time-diffs.
    # NOTE(review): assumes the model's forward signature is
    # (seq_tokens, seq_time_diffs) — confirm for non-HSTU sequence models.
    dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=viz_device)
    dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=viz_device)

    # Trace in eval mode on the visualization device.
    model = model.to(viz_device)
    model.eval()

    graph = draw_graph(model, input_data=(dummy_seq_tokens, dummy_seq_time_diffs), graph_name=graph_name, depth=depth, device=viz_device, expand_nested=expand_nested, show_shapes=show_shapes, save_graph=False, **kwargs)

    # Apply DPI before rendering so saved images honour the requested quality.
    graph.visual_graph.graph_attr['dpi'] = str(dpi)

    if save_path:
        # Save manually (rather than via save_graph=True) so the DPI set
        # above is applied to the rendered output.
        import os
        directory = os.path.dirname(save_path) or "."
        filename = os.path.splitext(os.path.basename(save_path))[0]
        ext = os.path.splitext(save_path)[1].lstrip('.')
        output_format = ext if ext else 'pdf'
        # exist_ok=True makes the pre-existence check redundant and avoids
        # the check-then-create race of the previous version.
        os.makedirs(directory, exist_ok=True)
        graph.visual_graph.render(filename=filename, directory=directory, format=output_format, cleanup=True)
    elif save_path is None:
        # Default display behaviour: inline in notebooks, system viewer
        # otherwise. (A falsy non-None save_path such as "" does neither,
        # matching the original behaviour.)
        if _is_jupyter_environment():
            display_graph(graph)
        else:
            graph.visual_graph.view(cleanup=True)

    return graph
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
"""Common model utility functions for Torch-RecHub.
|
|
2
|
+
|
|
3
|
+
This module provides shared utilities for model introspection and input generation,
|
|
4
|
+
used by both ONNX export and visualization features.
|
|
5
|
+
|
|
6
|
+
Examples
|
|
7
|
+
--------
|
|
8
|
+
>>> from torch_rechub.utils.model_utils import extract_feature_info, generate_dummy_input
|
|
9
|
+
>>> feature_info = extract_feature_info(model)
|
|
10
|
+
>>> dummy_input = generate_dummy_input(feature_info['features'], batch_size=2)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
# Import feature types for type checking
|
|
19
|
+
try:
|
|
20
|
+
from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
21
|
+
except ImportError:
|
|
22
|
+
# Fallback for standalone usage
|
|
23
|
+
SparseFeature = None
|
|
24
|
+
DenseFeature = None
|
|
25
|
+
SequenceFeature = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _dedup_by_name(features: List[Any]) -> List[Any]:
    """Return *features* deduplicated by their ``name`` attribute.

    First occurrence wins and order is preserved; items without a ``name``
    attribute are dropped.
    """
    seen = set()
    unique = []
    for feat in features:
        if hasattr(feat, 'name') and feat.name not in seen:
            seen.add(feat.name)
            unique.append(feat)
    return unique


def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
    """Extract feature information from a torch-rechub model using reflection.

    This function inspects model attributes to find feature lists without
    modifying the model code. Supports various model architectures.

    Parameters
    ----------
    model : nn.Module
        The recommendation model to inspect.

    Returns
    -------
    dict
        Dictionary containing:
        - 'features': List of unique Feature objects
        - 'input_names': List of feature names in order
        - 'input_types': Dict mapping feature name to feature type
        - 'user_features': List of user-side features (for dual-tower models)
        - 'item_features': List of item-side features (for dual-tower models)

    Examples
    --------
    >>> from torch_rechub.models.ranking import DeepFM
    >>> model = DeepFM(deep_features, fm_features, mlp_params)
    >>> info = extract_feature_info(model)
    >>> print(info['input_names'])  # ['user_id', 'item_id', ...]
    """
    # Common feature attribute names across different model types
    feature_attrs = [
        'features',  # MMOE, DCN, etc.
        'deep_features',  # DeepFM, WideDeep
        'fm_features',  # DeepFM
        'wide_features',  # WideDeep
        'linear_features',  # DeepFFM
        'cross_features',  # DeepFFM
        'user_features',  # DSSM, YoutubeDNN, MIND
        'item_features',  # DSSM, YoutubeDNN, MIND
        'history_features',  # DIN, MIND
        'target_features',  # DIN
        'neg_item_feature',  # YoutubeDNN, MIND
    ]

    all_features = []
    user_features = []
    item_features = []

    for attr in feature_attrs:
        feat_list = getattr(model, attr, None)
        if isinstance(feat_list, list) and len(feat_list) > 0:
            all_features.extend(feat_list)
            # Track user/item features for dual-tower models
            # (history features count as user-side context).
            if 'user' in attr or 'history' in attr:
                user_features.extend(feat_list)
            elif 'item' in attr:
                item_features.extend(feat_list)

    # A feature may be referenced from several attribute lists (e.g. both
    # deep_features and fm_features), so deduplicate by name while
    # preserving order.
    unique_features = _dedup_by_name(all_features)

    return {
        'features': unique_features,
        'input_names': [f.name for f in unique_features],
        'input_types': {f.name: type(f).__name__ for f in unique_features},
        'user_features': _dedup_by_name(user_features),
        'item_features': _dedup_by_name(item_features),
    }
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
    """Create one dummy tensor per feature definition.

    Parameters
    ----------
    features : list
        List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Batch size for dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    device : str, default='cpu'
        Device to create tensors on.

    Returns
    -------
    tuple of Tensor
        One tensor per feature, in the same order as *features*.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
    >>> dummy = generate_dummy_input(features, batch_size=4)
    >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
    """
    # Imported locally to avoid a hard dependency at module import time.
    from ..basic.features import DenseFeature, SequenceFeature, SparseFeature

    def _make_tensor(feat):
        # One tensor per feature type:
        #   SequenceFeature -> int ids of shape [batch_size, seq_length]
        #   SparseFeature   -> int ids of shape [batch_size]
        #   DenseFeature    -> floats of shape [batch_size, embed_dim]
        if isinstance(feat, SequenceFeature):
            return torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
        if isinstance(feat, SparseFeature):
            return torch.randint(0, feat.vocab_size, (batch_size,), device=device)
        if isinstance(feat, DenseFeature):
            return torch.randn(batch_size, feat.embed_dim, device=device)
        raise TypeError(f"Unsupported feature type: {type(feat)}")

    return tuple(_make_tensor(feat) for feat in features)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def generate_dummy_input_dict(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Dict[str, torch.Tensor]:
    """Create dummy inputs keyed by feature name.

    Like :func:`generate_dummy_input`, but returns a dict mapping each
    feature's name to its tensor — the input format torch-rechub models
    expect.

    Parameters
    ----------
    features : list
        List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Batch size for dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    device : str, default='cpu'
        Device to create tensors on.

    Returns
    -------
    dict
        Dict mapping feature names to tensors.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000)]
    >>> dummy = generate_dummy_input_dict(features, batch_size=4)
    >>> # Returns {"user_id": tensor[4]}
    """
    tensors = generate_dummy_input(features, batch_size, seq_length, device)
    names = [feat.name for feat in features if hasattr(feat, 'name')]
    return dict(zip(names, tensors))
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def generate_dynamic_axes(input_names: List[str], output_names: Optional[List[str]] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: Optional[List[str]] = None) -> Dict[str, Dict[int, str]]:
    """Build the ``dynamic_axes`` mapping for ``torch.onnx.export``.

    Parameters
    ----------
    input_names : list of str
        Names of the input tensors.
    output_names : list of str, optional
        Names of the output tensors. Defaults to ``["output"]`` when None.
    batch_dim : int, default=0
        Index of the batch dimension.
    include_seq_dim : bool, default=True
        Whether sequence inputs also get a dynamic sequence dimension.
    seq_features : list of str, optional
        Names of inputs that are sequences (their dim 1 becomes dynamic).

    Returns
    -------
    dict
        Mapping suitable for the ``dynamic_axes`` argument of
        ``torch.onnx.export``.

    Examples
    --------
    >>> axes = generate_dynamic_axes(["user_id", "item_id"], seq_features=["hist"])
    >>> # {"user_id": {0: "batch_size"}, "item_id": {0: "batch_size"}, ...}
    """
    # Every input tensor gets a dynamic batch dimension.
    axes = {name: {batch_dim: "batch_size"} for name in input_names}

    # Sequence inputs additionally expose dim 1 as a dynamic length.
    if include_seq_dim and seq_features:
        for seq_name in seq_features:
            if seq_name in axes:
                axes[seq_name][1] = "seq_length"

    # Output tensors get a dynamic batch dimension as well. Only a None
    # output_names falls back to the default; an empty list stays empty.
    outputs = ["output"] if output_names is None else output_names
    for out_name in outputs:
        axes[out_name] = {batch_dim: "batch_size"}

    return axes
|
|
@@ -62,142 +62,9 @@ class ONNXWrapper(nn.Module):
|
|
|
62
62
|
self.model.mode = self._original_mode
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
This function inspects model attributes to find feature lists without
|
|
69
|
-
modifying the model code. Supports various model architectures.
|
|
70
|
-
|
|
71
|
-
Args:
|
|
72
|
-
model: The recommendation model to inspect.
|
|
73
|
-
|
|
74
|
-
Returns:
|
|
75
|
-
Dict containing:
|
|
76
|
-
- 'features': List of unique Feature objects
|
|
77
|
-
- 'input_names': List of feature names in order
|
|
78
|
-
- 'input_types': Dict mapping feature name to feature type
|
|
79
|
-
- 'user_features': List of user-side features (for dual-tower models)
|
|
80
|
-
- 'item_features': List of item-side features (for dual-tower models)
|
|
81
|
-
"""
|
|
82
|
-
# Common feature attribute names across different model types
|
|
83
|
-
feature_attrs = [
|
|
84
|
-
'features', # MMOE, DCN, etc.
|
|
85
|
-
'deep_features', # DeepFM, WideDeep
|
|
86
|
-
'fm_features', # DeepFM
|
|
87
|
-
'wide_features', # WideDeep
|
|
88
|
-
'linear_features', # DeepFFM
|
|
89
|
-
'cross_features', # DeepFFM
|
|
90
|
-
'user_features', # DSSM, YoutubeDNN, MIND
|
|
91
|
-
'item_features', # DSSM, YoutubeDNN, MIND
|
|
92
|
-
'history_features', # DIN, MIND
|
|
93
|
-
'target_features', # DIN
|
|
94
|
-
'neg_item_feature', # YoutubeDNN, MIND
|
|
95
|
-
]
|
|
96
|
-
|
|
97
|
-
all_features = []
|
|
98
|
-
user_features = []
|
|
99
|
-
item_features = []
|
|
100
|
-
|
|
101
|
-
for attr in feature_attrs:
|
|
102
|
-
if hasattr(model, attr):
|
|
103
|
-
feat_list = getattr(model, attr)
|
|
104
|
-
if isinstance(feat_list, list) and len(feat_list) > 0:
|
|
105
|
-
all_features.extend(feat_list)
|
|
106
|
-
# Track user/item features for dual-tower models
|
|
107
|
-
if 'user' in attr or 'history' in attr:
|
|
108
|
-
user_features.extend(feat_list)
|
|
109
|
-
elif 'item' in attr:
|
|
110
|
-
item_features.extend(feat_list)
|
|
111
|
-
|
|
112
|
-
# Deduplicate features by name while preserving order
|
|
113
|
-
seen = set()
|
|
114
|
-
unique_features = []
|
|
115
|
-
for f in all_features:
|
|
116
|
-
if hasattr(f, 'name') and f.name not in seen:
|
|
117
|
-
seen.add(f.name)
|
|
118
|
-
unique_features.append(f)
|
|
119
|
-
|
|
120
|
-
# Deduplicate user/item features
|
|
121
|
-
seen_user = set()
|
|
122
|
-
unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
|
|
123
|
-
seen_item = set()
|
|
124
|
-
unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
|
|
125
|
-
|
|
126
|
-
return {
|
|
127
|
-
'features': unique_features,
|
|
128
|
-
'input_names': [f.name for f in unique_features],
|
|
129
|
-
'input_types': {
|
|
130
|
-
f.name: type(f).__name__ for f in unique_features
|
|
131
|
-
},
|
|
132
|
-
'user_features': unique_user,
|
|
133
|
-
'item_features': unique_item,
|
|
134
|
-
}
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
|
|
138
|
-
"""Generate dummy input tensors for ONNX export based on feature definitions.
|
|
139
|
-
|
|
140
|
-
Args:
|
|
141
|
-
features: List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
|
|
142
|
-
batch_size: Batch size for dummy input (default: 2).
|
|
143
|
-
seq_length: Sequence length for SequenceFeature (default: 10).
|
|
144
|
-
device: Device to create tensors on (default: 'cpu').
|
|
145
|
-
|
|
146
|
-
Returns:
|
|
147
|
-
Tuple of tensors in the order of input features.
|
|
148
|
-
|
|
149
|
-
Example:
|
|
150
|
-
>>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
|
|
151
|
-
>>> dummy = generate_dummy_input(features, batch_size=4)
|
|
152
|
-
>>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
|
|
153
|
-
"""
|
|
154
|
-
inputs = []
|
|
155
|
-
for feat in features:
|
|
156
|
-
if isinstance(feat, SequenceFeature):
|
|
157
|
-
# Sequence features have shape [batch_size, seq_length]
|
|
158
|
-
tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
|
|
159
|
-
elif isinstance(feat, SparseFeature):
|
|
160
|
-
# Sparse features have shape [batch_size]
|
|
161
|
-
tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
|
|
162
|
-
elif isinstance(feat, DenseFeature):
|
|
163
|
-
# Dense features have shape [batch_size, embed_dim]
|
|
164
|
-
tensor = torch.randn(batch_size, feat.embed_dim, device=device)
|
|
165
|
-
else:
|
|
166
|
-
raise TypeError(f"Unsupported feature type: {type(feat)}")
|
|
167
|
-
inputs.append(tensor)
|
|
168
|
-
return tuple(inputs)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
def generate_dynamic_axes(input_names: List[str], output_names: List[str] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: List[str] = None) -> Dict[str, Dict[int, str]]:
|
|
172
|
-
"""Generate dynamic axes configuration for ONNX export.
|
|
173
|
-
|
|
174
|
-
Args:
|
|
175
|
-
input_names: List of input tensor names.
|
|
176
|
-
output_names: List of output tensor names (default: ["output"]).
|
|
177
|
-
batch_dim: Dimension index for batch size (default: 0).
|
|
178
|
-
include_seq_dim: Whether to include sequence dimension as dynamic (default: True).
|
|
179
|
-
seq_features: List of feature names that are sequences (default: auto-detect).
|
|
180
|
-
|
|
181
|
-
Returns:
|
|
182
|
-
Dynamic axes dict for torch.onnx.export.
|
|
183
|
-
"""
|
|
184
|
-
if output_names is None:
|
|
185
|
-
output_names = ["output"]
|
|
186
|
-
|
|
187
|
-
dynamic_axes = {}
|
|
188
|
-
|
|
189
|
-
# Input axes
|
|
190
|
-
for name in input_names:
|
|
191
|
-
dynamic_axes[name] = {batch_dim: "batch_size"}
|
|
192
|
-
# Add sequence dimension for sequence features
|
|
193
|
-
if include_seq_dim and seq_features and name in seq_features:
|
|
194
|
-
dynamic_axes[name][1] = "seq_length"
|
|
195
|
-
|
|
196
|
-
# Output axes
|
|
197
|
-
for name in output_names:
|
|
198
|
-
dynamic_axes[name] = {batch_dim: "batch_size"}
|
|
199
|
-
|
|
200
|
-
return dynamic_axes
|
|
65
|
+
# Re-export from model_utils for backward compatibility
|
|
66
|
+
# The actual implementations are now in model_utils.py
|
|
67
|
+
from .model_utils import extract_feature_info, generate_dummy_input, generate_dummy_input_dict, generate_dynamic_axes
|
|
201
68
|
|
|
202
69
|
|
|
203
70
|
class ONNXExporter:
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Visualization Utilities for Torch-RecHub.
|
|
3
|
+
|
|
4
|
+
This module provides model structure visualization using torchview library.
|
|
5
|
+
Requires optional dependencies: pip install torch-rechub[visualization]
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
>>> from torch_rechub.utils.visualization import visualize_model, display_graph
|
|
9
|
+
>>> graph = visualize_model(model, depth=4)
|
|
10
|
+
>>> display_graph(graph) # Display in Jupyter Notebook
|
|
11
|
+
|
|
12
|
+
>>> # Save to file
|
|
13
|
+
>>> visualize_model(model, save_path="model_arch.pdf")
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from typing import Any, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
|
|
21
|
+
# Check for optional dependencies
|
|
22
|
+
TORCHVIEW_AVAILABLE = False
|
|
23
|
+
TORCHVIEW_SKIP_REASON = "torchview not installed"
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
from torchview import draw_graph
|
|
27
|
+
TORCHVIEW_AVAILABLE = True
|
|
28
|
+
except ImportError as e:
|
|
29
|
+
TORCHVIEW_SKIP_REASON = f"torchview not available: {e}"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _is_jupyter_environment() -> bool:
|
|
33
|
+
"""Check if running in Jupyter/IPython environment."""
|
|
34
|
+
try:
|
|
35
|
+
from IPython import get_ipython
|
|
36
|
+
shell = get_ipython()
|
|
37
|
+
if shell is None:
|
|
38
|
+
return False
|
|
39
|
+
# Check for Jupyter notebook or qtconsole
|
|
40
|
+
shell_class = shell.__class__.__name__
|
|
41
|
+
return shell_class in ('ZMQInteractiveShell', 'TerminalInteractiveShell')
|
|
42
|
+
except (ImportError, NameError):
|
|
43
|
+
return False
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def display_graph(graph: Any, format: str = 'png') -> Any:
|
|
47
|
+
"""Display a torchview ComputationGraph in Jupyter Notebook.
|
|
48
|
+
|
|
49
|
+
This function provides a reliable way to display visualization graphs
|
|
50
|
+
in Jupyter environments, especially VSCode Jupyter.
|
|
51
|
+
|
|
52
|
+
Parameters
|
|
53
|
+
----------
|
|
54
|
+
graph : ComputationGraph
|
|
55
|
+
A torchview ComputationGraph object returned by visualize_model().
|
|
56
|
+
format : str, default='png'
|
|
57
|
+
Output format, 'png' recommended for VSCode.
|
|
58
|
+
|
|
59
|
+
Returns
|
|
60
|
+
-------
|
|
61
|
+
graphviz.Digraph or None
|
|
62
|
+
The displayed graph object, or None if display fails.
|
|
63
|
+
|
|
64
|
+
Examples
|
|
65
|
+
--------
|
|
66
|
+
>>> graph = visualize_model(model, depth=4)
|
|
67
|
+
>>> display_graph(graph) # Works in VSCode Jupyter
|
|
68
|
+
"""
|
|
69
|
+
if not TORCHVIEW_AVAILABLE:
|
|
70
|
+
raise ImportError(f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
|
|
71
|
+
"Install with: pip install torch-rechub[visualization]")
|
|
72
|
+
|
|
73
|
+
try:
|
|
74
|
+
import graphviz
|
|
75
|
+
|
|
76
|
+
# Set format for VSCode compatibility
|
|
77
|
+
graphviz.set_jupyter_format(format)
|
|
78
|
+
except ImportError:
|
|
79
|
+
pass
|
|
80
|
+
|
|
81
|
+
# Get the visual_graph (graphviz.Digraph object)
|
|
82
|
+
visual = graph.visual_graph
|
|
83
|
+
|
|
84
|
+
# Try to use IPython display for explicit rendering
|
|
85
|
+
try:
|
|
86
|
+
from IPython.display import display
|
|
87
|
+
display(visual)
|
|
88
|
+
return visual
|
|
89
|
+
except ImportError:
|
|
90
|
+
# Not in IPython/Jupyter environment, return the graph
|
|
91
|
+
return visual
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def visualize_model(
|
|
95
|
+
model: nn.Module,
|
|
96
|
+
input_data: Optional[Dict[str,
|
|
97
|
+
torch.Tensor]] = None,
|
|
98
|
+
batch_size: int = 2,
|
|
99
|
+
seq_length: int = 10,
|
|
100
|
+
depth: int = 3,
|
|
101
|
+
show_shapes: bool = True,
|
|
102
|
+
expand_nested: bool = True,
|
|
103
|
+
save_path: Optional[str] = None,
|
|
104
|
+
graph_name: str = "model",
|
|
105
|
+
device: str = "cpu",
|
|
106
|
+
dpi: int = 300,
|
|
107
|
+
**kwargs
|
|
108
|
+
) -> Any:
|
|
109
|
+
"""Visualize a Torch-RecHub model's computation graph.
|
|
110
|
+
|
|
111
|
+
This function generates a visual representation of the model architecture,
|
|
112
|
+
showing layer connections, tensor shapes, and nested module structures.
|
|
113
|
+
It automatically extracts feature information from the model to generate
|
|
114
|
+
appropriate dummy inputs.
|
|
115
|
+
|
|
116
|
+
Parameters
|
|
117
|
+
----------
|
|
118
|
+
model : nn.Module
|
|
119
|
+
PyTorch model to visualize. Should be a Torch-RecHub model
|
|
120
|
+
with feature attributes (e.g., DeepFM, DSSM, MMOE).
|
|
121
|
+
input_data : dict, optional
|
|
122
|
+
Dict of example inputs {feature_name: tensor}.
|
|
123
|
+
If None, inputs are auto-generated based on model features.
|
|
124
|
+
batch_size : int, default=2
|
|
125
|
+
Batch size for auto-generated inputs.
|
|
126
|
+
seq_length : int, default=10
|
|
127
|
+
Sequence length for SequenceFeature inputs.
|
|
128
|
+
depth : int, default=3
|
|
129
|
+
Visualization depth - higher values show more detail.
|
|
130
|
+
Set to -1 to show all layers.
|
|
131
|
+
show_shapes : bool, default=True
|
|
132
|
+
Whether to display tensor shapes on edges.
|
|
133
|
+
expand_nested : bool, default=True
|
|
134
|
+
Whether to expand nested nn.Module with dashed borders.
|
|
135
|
+
save_path : str, optional
|
|
136
|
+
Path to save the graph image. Supports .pdf, .svg, .png formats.
|
|
137
|
+
If None, displays in Jupyter or opens system viewer.
|
|
138
|
+
graph_name : str, default="model"
|
|
139
|
+
Name for the computation graph.
|
|
140
|
+
device : str, default="cpu"
|
|
141
|
+
Device for model execution during tracing.
|
|
142
|
+
dpi : int, default=300
|
|
143
|
+
Resolution in dots per inch for output image.
|
|
144
|
+
Higher values produce sharper images suitable for papers.
|
|
145
|
+
**kwargs : dict
|
|
146
|
+
Additional arguments passed to torchview.draw_graph().
|
|
147
|
+
|
|
148
|
+
Returns
|
|
149
|
+
-------
|
|
150
|
+
ComputationGraph
|
|
151
|
+
A torchview ComputationGraph object.
|
|
152
|
+
- Use `.visual_graph` property to get the graphviz.Digraph
|
|
153
|
+
- Use `.resize_graph(scale=1.5)` to adjust graph size
|
|
154
|
+
|
|
155
|
+
Raises
|
|
156
|
+
------
|
|
157
|
+
ImportError
|
|
158
|
+
If torchview or graphviz is not installed.
|
|
159
|
+
ValueError
|
|
160
|
+
If model has no recognizable feature attributes.
|
|
161
|
+
|
|
162
|
+
Notes
|
|
163
|
+
-----
|
|
164
|
+
Default Display Behavior:
|
|
165
|
+
When `save_path` is None (default):
|
|
166
|
+
- In Jupyter/IPython: automatically displays the graph inline
|
|
167
|
+
- In Python script: opens the graph with system default viewer
|
|
168
|
+
|
|
169
|
+
Requires graphviz system package: apt/brew/choco install graphviz.
|
|
170
|
+
For Jupyter display issues, try: graphviz.set_jupyter_format('png').
|
|
171
|
+
|
|
172
|
+
Examples
|
|
173
|
+
--------
|
|
174
|
+
>>> from torch_rechub.models.ranking import DeepFM
|
|
175
|
+
>>> from torch_rechub.utils.visualization import visualize_model
|
|
176
|
+
>>>
|
|
177
|
+
>>> # Auto-display in Jupyter or open in viewer
|
|
178
|
+
>>> visualize_model(model, depth=4) # No save_path needed
|
|
179
|
+
>>>
|
|
180
|
+
>>> # Save to high-DPI PNG for paper
|
|
181
|
+
>>> visualize_model(model, save_path="model.png", dpi=300)
|
|
182
|
+
"""
|
|
183
|
+
if not TORCHVIEW_AVAILABLE:
|
|
184
|
+
raise ImportError(
|
|
185
|
+
f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
|
|
186
|
+
"Install with: pip install torch-rechub[visualization]\n"
|
|
187
|
+
"Also ensure graphviz is installed on your system:\n"
|
|
188
|
+
" - Ubuntu/Debian: sudo apt-get install graphviz\n"
|
|
189
|
+
" - macOS: brew install graphviz\n"
|
|
190
|
+
" - Windows: choco install graphviz"
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
# Import feature extraction utilities from model_utils
|
|
194
|
+
from .model_utils import extract_feature_info, generate_dummy_input_dict
|
|
195
|
+
|
|
196
|
+
model.eval()
|
|
197
|
+
model.to(device)
|
|
198
|
+
|
|
199
|
+
# Auto-generate input data if not provided
|
|
200
|
+
if input_data is None:
|
|
201
|
+
feature_info = extract_feature_info(model)
|
|
202
|
+
features = feature_info['features']
|
|
203
|
+
|
|
204
|
+
if not features:
|
|
205
|
+
raise ValueError("Could not extract feature information from model. "
|
|
206
|
+
"Please provide input_data parameter manually.")
|
|
207
|
+
|
|
208
|
+
# Generate dummy input dict
|
|
209
|
+
input_data = generate_dummy_input_dict(features, batch_size=batch_size, seq_length=seq_length, device=device)
|
|
210
|
+
else:
|
|
211
|
+
# Ensure input tensors are on correct device
|
|
212
|
+
input_data = {k: v.to(device) for k, v in input_data.items()}
|
|
213
|
+
|
|
214
|
+
# IMPORTANT: Wrap input_data dict in a tuple to work around torchview's behavior
|
|
215
|
+
#
|
|
216
|
+
# torchview's forward_prop function checks the type of input_data:
|
|
217
|
+
# - If isinstance(x, (list, tuple)): model(*x)
|
|
218
|
+
# - If isinstance(x, Mapping): model(**x) <- This unpacks dict as kwargs!
|
|
219
|
+
# - Else: model(x)
|
|
220
|
+
#
|
|
221
|
+
# torch-rechub models expect forward(self, x) where x is a complete dict.
|
|
222
|
+
# By wrapping the dict in a tuple, torchview will call:
|
|
223
|
+
# model(*(input_dict,)) = model(input_dict)
|
|
224
|
+
# which is exactly what our models expect.
|
|
225
|
+
input_data_wrapped = (input_data,)
|
|
226
|
+
|
|
227
|
+
# Call torchview.draw_graph without saving (we'll save manually with DPI)
|
|
228
|
+
graph = draw_graph(
|
|
229
|
+
model,
|
|
230
|
+
input_data=input_data_wrapped,
|
|
231
|
+
graph_name=graph_name,
|
|
232
|
+
depth=depth,
|
|
233
|
+
device=device,
|
|
234
|
+
expand_nested=expand_nested,
|
|
235
|
+
show_shapes=show_shapes,
|
|
236
|
+
save_graph=False, # Don't save here, we'll save manually with DPI
|
|
237
|
+
**kwargs
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
# Set DPI for high-quality output (must be set BEFORE rendering/saving)
|
|
241
|
+
graph.visual_graph.graph_attr['dpi'] = str(dpi)
|
|
242
|
+
|
|
243
|
+
# Handle save_path: manually save with DPI applied
|
|
244
|
+
if save_path:
|
|
245
|
+
import os
|
|
246
|
+
directory = os.path.dirname(save_path) or "."
|
|
247
|
+
filename = os.path.splitext(os.path.basename(save_path))[0]
|
|
248
|
+
ext = os.path.splitext(save_path)[1].lstrip('.')
|
|
249
|
+
# Default to pdf if no extension
|
|
250
|
+
output_format = ext if ext else 'pdf'
|
|
251
|
+
# Create directory if it doesn't exist
|
|
252
|
+
if directory != "." and not os.path.exists(directory):
|
|
253
|
+
os.makedirs(directory, exist_ok=True)
|
|
254
|
+
# Render and save with DPI applied
|
|
255
|
+
graph.visual_graph.render(
|
|
256
|
+
filename=filename,
|
|
257
|
+
directory=directory,
|
|
258
|
+
format=output_format,
|
|
259
|
+
cleanup=True # Remove intermediate .gv file
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
# Handle default display behavior when save_path is None
|
|
263
|
+
if save_path is None:
|
|
264
|
+
if _is_jupyter_environment():
|
|
265
|
+
# In Jupyter: display inline
|
|
266
|
+
display_graph(graph)
|
|
267
|
+
else:
|
|
268
|
+
# In script: open with system viewer
|
|
269
|
+
graph.visual_graph.view(cleanup=True)
|
|
270
|
+
|
|
271
|
+
return graph
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-rechub
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.5
|
|
4
4
|
Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
|
|
5
5
|
Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
|
|
6
6
|
Project-URL: Documentation, https://www.torch-rechub.com
|
|
@@ -41,6 +41,9 @@ Requires-Dist: yapf==0.43.0; extra == 'dev'
|
|
|
41
41
|
Provides-Extra: onnx
|
|
42
42
|
Requires-Dist: onnx>=1.12.0; extra == 'onnx'
|
|
43
43
|
Requires-Dist: onnxruntime>=1.12.0; extra == 'onnx'
|
|
44
|
+
Provides-Extra: visualization
|
|
45
|
+
Requires-Dist: graphviz>=0.20; extra == 'visualization'
|
|
46
|
+
Requires-Dist: torchview>=0.2.6; extra == 'visualization'
|
|
44
47
|
Description-Content-Type: text/markdown
|
|
45
48
|
|
|
46
49
|
# 🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架
|
|
@@ -69,13 +72,13 @@ Description-Content-Type: text/markdown
|
|
|
69
72
|
|
|
70
73
|
## 🎯 为什么选择 Torch-RecHub?
|
|
71
74
|
|
|
72
|
-
| 特性
|
|
73
|
-
|
|
74
|
-
| 代码行数
|
|
75
|
-
| 模型覆盖
|
|
76
|
-
| 生成式推荐
|
|
77
|
-
| ONNX 一键导出 | ✅ 内置支持
|
|
78
|
-
| 学习曲线
|
|
75
|
+
| 特性 | Torch-RecHub | 其他框架 |
|
|
76
|
+
| ------------- | --------------------------- | ---------- |
|
|
77
|
+
| 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
|
|
78
|
+
| 模型覆盖 | **30+** 主流模型 | 有限 |
|
|
79
|
+
| 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
|
|
80
|
+
| ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
|
|
81
|
+
| 学习曲线 | 极低 | 陡峭 |
|
|
79
82
|
|
|
80
83
|
## ✨ 特性
|
|
81
84
|
|
|
@@ -205,52 +208,52 @@ torch-rechub/ # 根目录
|
|
|
205
208
|
|
|
206
209
|
### 排序模型 (Ranking Models) - 13个
|
|
207
210
|
|
|
208
|
-
| 模型
|
|
209
|
-
|
|
210
|
-
| **DeepFM**
|
|
211
|
-
| **Wide&Deep** | [DLRS 2016](https://arxiv.org/abs/1606.07792)
|
|
212
|
-
| **DCN**
|
|
213
|
-
| **DCN-v2**
|
|
214
|
-
| **DIN**
|
|
215
|
-
| **DIEN**
|
|
216
|
-
| **BST**
|
|
217
|
-
| **AFM**
|
|
218
|
-
| **AutoInt**
|
|
219
|
-
| **FiBiNET**
|
|
220
|
-
| **DeepFFM**
|
|
221
|
-
| **EDCN**
|
|
211
|
+
| 模型 | 论文 | 简介 |
|
|
212
|
+
| ------------- | ------------------------------------------------ | ----------------------- |
|
|
213
|
+
| **DeepFM** | [IJCAI 2017](https://arxiv.org/abs/1703.04247) | FM + Deep 联合训练 |
|
|
214
|
+
| **Wide&Deep** | [DLRS 2016](https://arxiv.org/abs/1606.07792) | 记忆 + 泛化能力结合 |
|
|
215
|
+
| **DCN** | [KDD 2017](https://arxiv.org/abs/1708.05123) | 显式特征交叉网络 |
|
|
216
|
+
| **DCN-v2** | [WWW 2021](https://arxiv.org/abs/2008.13535) | 增强版交叉网络 |
|
|
217
|
+
| **DIN** | [KDD 2018](https://arxiv.org/abs/1706.06978) | 注意力机制捕捉用户兴趣 |
|
|
218
|
+
| **DIEN** | [AAAI 2019](https://arxiv.org/abs/1809.03672) | 兴趣演化建模 |
|
|
219
|
+
| **BST** | [DLP-KDD 2019](https://arxiv.org/abs/1905.06874) | Transformer 序列建模 |
|
|
220
|
+
| **AFM** | [IJCAI 2017](https://arxiv.org/abs/1708.04617) | 注意力因子分解机 |
|
|
221
|
+
| **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
|
|
222
|
+
| **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
|
|
223
|
+
| **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
|
|
224
|
+
| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
|
|
222
225
|
|
|
223
226
|
### 召回模型 (Matching Models) - 12个
|
|
224
227
|
|
|
225
|
-
| 模型
|
|
226
|
-
|
|
227
|
-
| **DSSM**
|
|
228
|
-
| **YoutubeDNN** | [RecSys 2016](https://dl.acm.org/doi/10.1145/2959100.2959190)
|
|
229
|
-
| **YoutubeSBC** | [RecSys 2019](https://dl.acm.org/doi/10.1145/3298689.3346997)
|
|
230
|
-
| **MIND**
|
|
231
|
-
| **SINE**
|
|
232
|
-
| **GRU4Rec**
|
|
233
|
-
| **SASRec**
|
|
234
|
-
| **NARM**
|
|
235
|
-
| **STAMP**
|
|
236
|
-
| **ComiRec**
|
|
228
|
+
| 模型 | 论文 | 简介 |
|
|
229
|
+
| -------------- | ------------------------------------------------------------------------------ | ------------------ |
|
|
230
|
+
| **DSSM** | [CIKM 2013](https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf) | 经典双塔召回模型 |
|
|
231
|
+
| **YoutubeDNN** | [RecSys 2016](https://dl.acm.org/doi/10.1145/2959100.2959190) | YouTube 深度召回 |
|
|
232
|
+
| **YoutubeSBC** | [RecSys 2019](https://dl.acm.org/doi/10.1145/3298689.3346997) | 采样偏差校正版本 |
|
|
233
|
+
| **MIND** | [CIKM 2019](https://arxiv.org/abs/1904.08030) | 多兴趣动态路由 |
|
|
234
|
+
| **SINE** | [WSDM 2021](https://arxiv.org/abs/2103.06920) | 稀疏兴趣网络 |
|
|
235
|
+
| **GRU4Rec** | [ICLR 2016](https://arxiv.org/abs/1511.06939) | GRU 序列推荐 |
|
|
236
|
+
| **SASRec** | [ICDM 2018](https://arxiv.org/abs/1808.09781) | 自注意力序列推荐 |
|
|
237
|
+
| **NARM** | [CIKM 2017](https://arxiv.org/abs/1711.04725) | 神经注意力会话推荐 |
|
|
238
|
+
| **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
|
|
239
|
+
| **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
|
|
237
240
|
|
|
238
241
|
### 多任务模型 (Multi-Task Models) - 5个
|
|
239
242
|
|
|
240
|
-
| 模型
|
|
241
|
-
|
|
242
|
-
| **ESMM**
|
|
243
|
-
| **MMoE**
|
|
244
|
-
| **PLE**
|
|
245
|
-
| **AITM**
|
|
246
|
-
| **SharedBottom** | -
|
|
243
|
+
| 模型 | 论文 | 简介 |
|
|
244
|
+
| ---------------- | ------------------------------------------------------------- | ------------------ |
|
|
245
|
+
| **ESMM** | [SIGIR 2018](https://arxiv.org/abs/1804.07931) | 全空间多任务建模 |
|
|
246
|
+
| **MMoE** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3220007) | 多门控专家混合 |
|
|
247
|
+
| **PLE** | [RecSys 2020](https://dl.acm.org/doi/10.1145/3383313.3412236) | 渐进式分层提取 |
|
|
248
|
+
| **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
|
|
249
|
+
| **SharedBottom** | - | 经典多任务共享底层 |
|
|
247
250
|
|
|
248
251
|
### 生成式推荐 (Generative Recommendation) - 2个
|
|
249
252
|
|
|
250
|
-
| 模型
|
|
251
|
-
|
|
253
|
+
| 模型 | 论文 | 简介 |
|
|
254
|
+
| -------- | --------------------------------------------- | -------------------------------------------- |
|
|
252
255
|
| **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | 层级序列转换单元,支撑 Meta 万亿参数推荐系统 |
|
|
253
|
-
| **HLLM** | [2024](https://arxiv.org/abs/2409.12740)
|
|
256
|
+
| **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐,融合 LLM 语义理解能力 |
|
|
254
257
|
|
|
255
258
|
## 📊 支持的数据集
|
|
256
259
|
|
|
@@ -338,11 +341,19 @@ model = DSSM(user_features, item_features, temperature=0.02,
|
|
|
338
341
|
match_trainer = MatchTrainer(model)
|
|
339
342
|
match_trainer.fit(train_dl)
|
|
340
343
|
match_trainer.export_onnx("dssm.onnx")
|
|
341
|
-
# 双塔模型可分别导出用户塔和物品塔:
|
|
344
|
+
# 双塔模型可分别导出用户塔和物品塔:
|
|
342
345
|
# match_trainer.export_onnx("user_tower.onnx", mode="user")
|
|
343
346
|
# match_trainer.export_onnx("dssm_item.onnx", tower="item")
|
|
344
347
|
```
|
|
345
348
|
|
|
349
|
+
### 模型可视化
|
|
350
|
+
|
|
351
|
+
```python
|
|
352
|
+
# 可视化模型架构(需要安装: pip install torch-rechub[visualization])
|
|
353
|
+
graph = ctr_trainer.visualization(depth=4) # 生成计算图
|
|
354
|
+
ctr_trainer.visualization(save_path="model.pdf", dpi=300) # 保存为高清 PDF
|
|
355
|
+
```
|
|
356
|
+
|
|
346
357
|
## 👨💻 贡献者
|
|
347
358
|
|
|
348
359
|
感谢所有的贡献者!
|
|
@@ -45,18 +45,20 @@ torch_rechub/models/ranking/edcn.py,sha256=6f_S8I6Ir16kCIU54R4EfumWfUFOND5KDKUPH
|
|
|
45
45
|
torch_rechub/models/ranking/fibinet.py,sha256=fmEJ9WkO8Mn0RtK_8aRHlnQFh_jMBPO0zODoHZPWmDA,2234
|
|
46
46
|
torch_rechub/models/ranking/widedeep.py,sha256=eciRvWRBHLlctabLLS5NB7k3MnqrWXCBdpflOU6jMB0,1636
|
|
47
47
|
torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
|
|
48
|
-
torch_rechub/trainers/ctr_trainer.py,sha256=
|
|
49
|
-
torch_rechub/trainers/match_trainer.py,sha256=
|
|
48
|
+
torch_rechub/trainers/ctr_trainer.py,sha256=ECXaK0x2_6jZVxtEazgN3hkBpSAMPeGeNtunqI_OECo,12860
|
|
49
|
+
torch_rechub/trainers/match_trainer.py,sha256=QHZb32Rf7yp-NvEzdeiG1HQghQ76_vuu59K1IsdK60k,15055
|
|
50
50
|
torch_rechub/trainers/matching.md,sha256=vIBQ3UMmVpUpyk38rrkelFwm_wXVXqMOuqzYZ4M8bzw,30
|
|
51
|
-
torch_rechub/trainers/mtl_trainer.py,sha256=
|
|
52
|
-
torch_rechub/trainers/seq_trainer.py,sha256=
|
|
51
|
+
torch_rechub/trainers/mtl_trainer.py,sha256=MjasE_QOPfGxiUW1JpYYQ2iuBSSk-lissAGp4Sw1CWk,16427
|
|
52
|
+
torch_rechub/trainers/seq_trainer.py,sha256=uAo9XymwQupCqvm5otKW81tz1nxd3crJ2ul2r7lrEAE,17633
|
|
53
53
|
torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
54
54
|
torch_rechub/utils/data.py,sha256=vzLAAVt6dujg_vbGhQewiJc0l6JzwzdcM_9EjoOz898,19882
|
|
55
55
|
torch_rechub/utils/hstu_utils.py,sha256=qLON_pJDC-kDyQn1PoN_HaHi5xTNCwZPgJeV51Z61Lc,6207
|
|
56
56
|
torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
|
|
57
|
+
torch_rechub/utils/model_utils.py,sha256=VLhSbTpupxrFyyY3NzMQ32PPmo5YHm1T96u9KDlwiWE,8450
|
|
57
58
|
torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
|
|
58
|
-
torch_rechub/utils/onnx_export.py,sha256=
|
|
59
|
-
torch_rechub
|
|
60
|
-
torch_rechub-0.0.
|
|
61
|
-
torch_rechub-0.0.
|
|
62
|
-
torch_rechub-0.0.
|
|
59
|
+
torch_rechub/utils/onnx_export.py,sha256=LRHyZaR9zZJyg6xtuqQHWmusWq-yEvw9EhlmoEwcqsg,8364
|
|
60
|
+
torch_rechub/utils/visualization.py,sha256=Djv8W5SkCk3P2dol5VXf0_eanIhxDwRd7fzNOQY4uiU,9506
|
|
61
|
+
torch_rechub-0.0.5.dist-info/METADATA,sha256=7k9N1xGB4JeWzri7iA7kJbPnAJ-KhXF7vBV-_b8Ghrg,17998
|
|
62
|
+
torch_rechub-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
63
|
+
torch_rechub-0.0.5.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
|
|
64
|
+
torch_rechub-0.0.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|