torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
@@ -0,0 +1,220 @@
1
+ """
2
+ ONNX Export Utilities for Torch-RecHub models.
3
+
4
+ This module provides non-invasive ONNX export functionality for recommendation models.
5
+ It uses reflection to extract feature information from models and wraps dict-input models
6
+ to be compatible with ONNX's positional argument requirements.
7
+
8
+ Date: 2024
9
+ References:
10
+ - PyTorch ONNX Export: https://pytorch.org/docs/stable/onnx.html
11
+ - ONNX Runtime: https://onnxruntime.ai/docs/
12
+ Authors: Torch-RecHub Contributors
13
+ """
14
+
15
+ import os
16
+ import warnings
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
23
+
24
+
25
class ONNXWrapper(nn.Module):
    """Adapter exposing a dict-input model through positional tensor arguments.

    ONNX export cannot handle dict inputs, so this wrapper maps positional
    tensors back onto the feature-name dict that the wrapped model expects.

    Args:
        model: The original PyTorch model that accepts dict input.
        input_names: Ordered list of feature names corresponding to input positions.
        mode: Optional mode for dual-tower models ("user" or "item").

    Example:
        >>> wrapper = ONNXWrapper(dssm_model, ["user_id", "movie_id", "hist_movie_id"])
        >>> # Now can call: wrapper(user_id_tensor, movie_id_tensor, hist_tensor)
    """

    def __init__(self, model: nn.Module, input_names: List[str], mode: Optional[str] = None):
        super().__init__()
        self.model = model
        self.input_names = input_names
        # Remember the model's current mode so restore_mode() can undo any change.
        self._original_mode = getattr(model, 'mode', None)

        # Switch dual-tower models to the requested tower for the export trace.
        if mode is not None and hasattr(model, 'mode'):
            model.mode = mode

    def forward(self, *args) -> torch.Tensor:
        """Rebuild the feature dict from positional tensors and run the model."""
        expected = len(self.input_names)
        if len(args) != expected:
            raise ValueError(f"Expected {expected} inputs, got {len(args)}. "
                             f"Expected names: {self.input_names}")
        return self.model(dict(zip(self.input_names, args)))

    def restore_mode(self):
        """Put back the mode the wrapped model had before wrapping."""
        if hasattr(self.model, 'mode'):
            self.model.mode = self._original_mode
63
+
64
+
65
+ # Re-export from model_utils for backward compatibility
66
+ # The actual implementations are now in model_utils.py
67
+ from .model_utils import extract_feature_info, generate_dummy_input, generate_dummy_input_dict, generate_dynamic_axes
68
+
69
+
70
class ONNXExporter:
    """Main class for exporting Torch-RecHub models to ONNX format.

    This exporter handles the complexity of converting dict-input models to ONNX
    by automatically extracting feature information and wrapping the model.

    Args:
        model: The PyTorch recommendation model to export.
        device: Device for export operations (default: 'cpu').

    Example:
        >>> exporter = ONNXExporter(deepfm_model)
        >>> exporter.export("model.onnx")

        >>> # For dual-tower models
        >>> exporter = ONNXExporter(dssm_model)
        >>> exporter.export("user_tower.onnx", mode="user")
        >>> exporter.export("item_tower.onnx", mode="item")
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model
        self.device = device
        # Reflectively collected metadata; a dict with 'features',
        # 'user_features' and 'item_features' lists.
        self.feature_info = extract_feature_info(model)

    def _select_features(self, mode: Optional[str], required: bool = False) -> List[Any]:
        """Return the feature list for ``mode`` (None selects all features).

        Args:
            mode: "user", "item", or None.
            required: When True, raise ValueError if the selected tower has no
                features (export cannot proceed without inputs). get_input_info
                passes False to keep its original non-raising behavior.
        """
        if mode == "user":
            features = self.feature_info['user_features']
            if required and not features:
                raise ValueError("No user features found in model for mode='user'")
        elif mode == "item":
            features = self.feature_info['item_features']
            if required and not features:
                raise ValueError("No item features found in model for mode='item'")
        else:
            features = self.feature_info['features']
        return features

    def export(
        self,
        output_path: str,
        mode: Optional[str] = None,
        dummy_input: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: int = 2,
        seq_length: int = 10,
        opset_version: int = 14,
        dynamic_batch: bool = True,
        verbose: bool = False
    ) -> bool:
        """Export the model to ONNX format.

        Args:
            output_path: Path to save the ONNX model.
            mode: For dual-tower models, specify "user" or "item" to export
                only that tower. None exports the full model.
            dummy_input: Optional dict of example inputs. If not provided,
                dummy inputs will be generated automatically.
            batch_size: Batch size for generated dummy input (default: 2).
            seq_length: Sequence length for SequenceFeature (default: 10).
            opset_version: ONNX opset version (default: 14).
            dynamic_batch: Whether to enable dynamic batch size (default: True).
            verbose: Whether to print export details (default: False).

        Returns:
            True if export succeeded, False otherwise.

        Raises:
            RuntimeError: If ONNX export fails.
        """
        self.model.eval()
        self.model.to(self.device)

        # Determine which features to use based on mode; raises if the
        # requested tower has no features.
        features = self._select_features(mode, required=True)
        input_names = [f.name for f in features]

        # Wrap the dict-input model so ONNX sees positional tensor inputs.
        wrapper = ONNXWrapper(self.model, input_names, mode=mode)
        wrapper.eval()

        # Generate or use provided dummy input, ordered to match input_names.
        if dummy_input is not None:
            dummy_tuple = tuple(dummy_input[name].to(self.device) for name in input_names)
        else:
            dummy_tuple = generate_dummy_input(features, batch_size=batch_size, seq_length=seq_length, device=self.device)

        # Configure dynamic axes (batch dim for all inputs; sequence dim for
        # SequenceFeature inputs) when dynamic batching is requested.
        dynamic_axes = None
        if dynamic_batch:
            seq_feature_names = [f.name for f in features if isinstance(f, SequenceFeature)]
            dynamic_axes = generate_dynamic_axes(input_names=input_names, output_names=["output"], seq_features=seq_feature_names)

        # Ensure output directory exists. exist_ok=True avoids the race
        # between an existence check and the creation call.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        try:
            with torch.no_grad():
                torch.onnx.export(
                    wrapper,
                    dummy_tuple,
                    output_path,
                    input_names=input_names,
                    output_names=["output"],
                    dynamic_axes=dynamic_axes,
                    opset_version=opset_version,
                    do_constant_folding=True,
                    verbose=verbose,
                    # NOTE(review): the dynamo kwarg requires a recent torch;
                    # on older versions it raises TypeError, which is surfaced
                    # below as RuntimeError — confirm minimum torch version.
                    dynamo=False  # Use legacy exporter for dynamic_axes support
                )

            if verbose:
                print(f"Successfully exported ONNX model to: {output_path}")
                print(f"  Input names: {input_names}")
                print(f"  Opset version: {opset_version}")
                print(f"  Dynamic batch: {dynamic_batch}")

            return True

        except Exception as e:
            warnings.warn(f"ONNX export failed: {str(e)}")
            raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
        finally:
            # Restore the model's original mode even when export fails.
            wrapper.restore_mode()

    def get_input_info(self, mode: Optional[str] = None) -> Dict[str, Any]:
        """Get information about model inputs.

        Args:
            mode: For dual-tower models, "user" or "item".

        Returns:
            Dict with input names, types, and shapes.
        """
        features = self._select_features(mode)

        info = []
        for f in features:
            feat_info = {'name': f.name, 'type': type(f).__name__, 'embed_dim': f.embed_dim}
            if hasattr(f, 'vocab_size'):
                feat_info['vocab_size'] = f.vocab_size
            if hasattr(f, 'pooling'):
                feat_info['pooling'] = f.pooling
            info.append(feat_info)

        return {'mode': mode, 'inputs': info, 'input_names': [f.name for f in features]}
@@ -0,0 +1,271 @@
1
+ """
2
+ Model Visualization Utilities for Torch-RecHub.
3
+
4
+ This module provides model structure visualization using torchview library.
5
+ Requires optional dependencies: pip install torch-rechub[visualization]
6
+
7
+ Example:
8
+ >>> from torch_rechub.utils.visualization import visualize_model, display_graph
9
+ >>> graph = visualize_model(model, depth=4)
10
+ >>> display_graph(graph) # Display in Jupyter Notebook
11
+
12
+ >>> # Save to file
13
+ >>> visualize_model(model, save_path="model_arch.pdf")
14
+ """
15
+
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
# Probe for the optional torchview dependency; record why it is missing so
# error messages can explain what to install.
TORCHVIEW_AVAILABLE = False
TORCHVIEW_SKIP_REASON = "torchview not installed"

try:
    from torchview import draw_graph
except ImportError as e:
    TORCHVIEW_SKIP_REASON = f"torchview not available: {e}"
else:
    TORCHVIEW_AVAILABLE = True
30
+
31
+
32
+ def _is_jupyter_environment() -> bool:
33
+ """Check if running in Jupyter/IPython environment."""
34
+ try:
35
+ from IPython import get_ipython
36
+ shell = get_ipython()
37
+ if shell is None:
38
+ return False
39
+ # Check for Jupyter notebook or qtconsole
40
+ shell_class = shell.__class__.__name__
41
+ return shell_class in ('ZMQInteractiveShell', 'TerminalInteractiveShell')
42
+ except (ImportError, NameError):
43
+ return False
44
+
45
+
46
def display_graph(graph: Any, format: str = 'png') -> Any:
    """Display a torchview ComputationGraph in Jupyter Notebook.

    Provides an explicit display path that works reliably in Jupyter
    environments, in particular VSCode's Jupyter frontend.

    Parameters
    ----------
    graph : ComputationGraph
        A torchview ComputationGraph object returned by visualize_model().
    format : str, default='png'
        Output format, 'png' recommended for VSCode.

    Returns
    -------
    graphviz.Digraph
        The underlying graph object (also returned when IPython is absent).

    Raises
    ------
    ImportError
        If torchview is not installed.

    Examples
    --------
    >>> graph = visualize_model(model, depth=4)
    >>> display_graph(graph)  # Works in VSCode Jupyter
    """
    if not TORCHVIEW_AVAILABLE:
        raise ImportError(f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
                          "Install with: pip install torch-rechub[visualization]")

    # Ask graphviz to emit this format for Jupyter rendering; best-effort only.
    try:
        import graphviz
        graphviz.set_jupyter_format(format)
    except ImportError:
        pass

    visual = graph.visual_graph

    # Prefer explicit IPython display; outside IPython just hand back the graph.
    try:
        from IPython.display import display
    except ImportError:
        return visual
    display(visual)
    return visual
92
+
93
+
94
def visualize_model(
    model: nn.Module,
    input_data: Optional[Dict[str, torch.Tensor]] = None,
    batch_size: int = 2,
    seq_length: int = 10,
    depth: int = 3,
    show_shapes: bool = True,
    expand_nested: bool = True,
    save_path: Optional[str] = None,
    graph_name: str = "model",
    device: str = "cpu",
    dpi: int = 300,
    **kwargs
) -> Any:
    """Visualize a Torch-RecHub model's computation graph.

    This function generates a visual representation of the model architecture,
    showing layer connections, tensor shapes, and nested module structures.
    It automatically extracts feature information from the model to generate
    appropriate dummy inputs.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to visualize. Should be a Torch-RecHub model
        with feature attributes (e.g., DeepFM, DSSM, MMOE).
    input_data : dict, optional
        Dict of example inputs {feature_name: tensor}.
        If None, inputs are auto-generated based on model features.
    batch_size : int, default=2
        Batch size for auto-generated inputs.
    seq_length : int, default=10
        Sequence length for SequenceFeature inputs.
    depth : int, default=3
        Visualization depth - higher values show more detail.
        Set to -1 to show all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes on edges.
    expand_nested : bool, default=True
        Whether to expand nested nn.Module with dashed borders.
    save_path : str, optional
        Path to save the graph image. Supports .pdf, .svg, .png formats.
        If None, displays in Jupyter or opens system viewer.
    graph_name : str, default="model"
        Name for the computation graph.
    device : str, default="cpu"
        Device for model execution during tracing.
    dpi : int, default=300
        Resolution in dots per inch for output image.
        Higher values produce sharper images suitable for papers.
    **kwargs : dict
        Additional arguments passed to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.
        - Use `.visual_graph` property to get the graphviz.Digraph
        - Use `.resize_graph(scale=1.5)` to adjust graph size

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.
    ValueError
        If model has no recognizable feature attributes.

    Notes
    -----
    Default Display Behavior:
        When `save_path` is None (default):
        - In Jupyter/IPython: automatically displays the graph inline
        - In Python script: opens the graph with system default viewer

    Requires graphviz system package: apt/brew/choco install graphviz.
    For Jupyter display issues, try: graphviz.set_jupyter_format('png').

    Examples
    --------
    >>> from torch_rechub.models.ranking import DeepFM
    >>> from torch_rechub.utils.visualization import visualize_model
    >>>
    >>> # Auto-display in Jupyter or open in viewer
    >>> visualize_model(model, depth=4)  # No save_path needed
    >>>
    >>> # Save to high-DPI PNG for paper
    >>> visualize_model(model, save_path="model.png", dpi=300)
    """
    if not TORCHVIEW_AVAILABLE:
        raise ImportError(
            f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
            "Install with: pip install torch-rechub[visualization]\n"
            "Also ensure graphviz is installed on your system:\n"
            " - Ubuntu/Debian: sudo apt-get install graphviz\n"
            " - macOS: brew install graphviz\n"
            " - Windows: choco install graphviz"
        )

    # Import feature extraction utilities from model_utils (local import keeps
    # the torchview-availability error above as the first failure mode).
    from .model_utils import extract_feature_info, generate_dummy_input_dict

    model.eval()
    model.to(device)

    # Auto-generate input data if not provided
    if input_data is None:
        feature_info = extract_feature_info(model)
        features = feature_info['features']

        if not features:
            raise ValueError("Could not extract feature information from model. "
                             "Please provide input_data parameter manually.")

        # Generate dummy input dict
        input_data = generate_dummy_input_dict(features, batch_size=batch_size, seq_length=seq_length, device=device)
    else:
        # Ensure input tensors are on correct device
        input_data = {k: v.to(device) for k, v in input_data.items()}

    # IMPORTANT: Wrap input_data dict in a tuple to work around torchview's behavior
    #
    # torchview's forward_prop function checks the type of input_data:
    #   - If isinstance(x, (list, tuple)): model(*x)
    #   - If isinstance(x, Mapping): model(**x)  <- This unpacks dict as kwargs!
    #   - Else: model(x)
    #
    # torch-rechub models expect forward(self, x) where x is a complete dict.
    # By wrapping the dict in a tuple, torchview will call:
    #   model(*(input_dict,)) = model(input_dict)
    # which is exactly what our models expect.
    input_data_wrapped = (input_data,)

    # Call torchview.draw_graph without saving (we'll save manually with DPI)
    # NOTE(review): depth=-1 meaning "show all layers" relies on torchview's
    # handling of negative depth — confirm with the installed torchview version.
    graph = draw_graph(
        model,
        input_data=input_data_wrapped,
        graph_name=graph_name,
        depth=depth,
        device=device,
        expand_nested=expand_nested,
        show_shapes=show_shapes,
        save_graph=False,  # Don't save here, we'll save manually with DPI
        **kwargs
    )

    # Set DPI for high-quality output (must be set BEFORE rendering/saving)
    graph.visual_graph.graph_attr['dpi'] = str(dpi)

    # Handle save_path: manually save with DPI applied
    if save_path:
        import os
        directory = os.path.dirname(save_path) or "."
        filename = os.path.splitext(os.path.basename(save_path))[0]
        ext = os.path.splitext(save_path)[1].lstrip('.')
        # Default to pdf if no extension
        output_format = ext if ext else 'pdf'
        # Create directory if it doesn't exist
        if directory != "." and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
        # Render and save with DPI applied
        graph.visual_graph.render(
            filename=filename,
            directory=directory,
            format=output_format,
            cleanup=True  # Remove intermediate .gv file
        )

    # Handle default display behavior when save_path is None
    if save_path is None:
        if _is_jupyter_environment():
            # In Jupyter: display inline
            display_graph(graph)
        else:
            # In script: open with system viewer
            graph.visual_graph.view(cleanup=True)

    return graph