torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +3 -1
  3. torch_rechub/basic/callback.py +2 -2
  4. torch_rechub/basic/features.py +38 -8
  5. torch_rechub/basic/initializers.py +92 -0
  6. torch_rechub/basic/layers.py +800 -46
  7. torch_rechub/basic/loss_func.py +223 -0
  8. torch_rechub/basic/metaoptimizer.py +76 -0
  9. torch_rechub/basic/metric.py +251 -0
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -0
  14. torch_rechub/models/matching/comirec.py +193 -0
  15. torch_rechub/models/matching/dssm.py +72 -0
  16. torch_rechub/models/matching/dssm_facebook.py +77 -0
  17. torch_rechub/models/matching/dssm_senet.py +87 -0
  18. torch_rechub/models/matching/gru4rec.py +85 -0
  19. torch_rechub/models/matching/mind.py +103 -0
  20. torch_rechub/models/matching/narm.py +82 -0
  21. torch_rechub/models/matching/sasrec.py +143 -0
  22. torch_rechub/models/matching/sine.py +148 -0
  23. torch_rechub/models/matching/stamp.py +81 -0
  24. torch_rechub/models/matching/youtube_dnn.py +75 -0
  25. torch_rechub/models/matching/youtube_sbc.py +98 -0
  26. torch_rechub/models/multi_task/__init__.py +5 -2
  27. torch_rechub/models/multi_task/aitm.py +83 -0
  28. torch_rechub/models/multi_task/esmm.py +19 -8
  29. torch_rechub/models/multi_task/mmoe.py +18 -12
  30. torch_rechub/models/multi_task/ple.py +41 -29
  31. torch_rechub/models/multi_task/shared_bottom.py +3 -2
  32. torch_rechub/models/ranking/__init__.py +13 -2
  33. torch_rechub/models/ranking/afm.py +65 -0
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -0
  36. torch_rechub/models/ranking/dcn.py +38 -0
  37. torch_rechub/models/ranking/dcn_v2.py +59 -0
  38. torch_rechub/models/ranking/deepffm.py +131 -0
  39. torch_rechub/models/ranking/deepfm.py +8 -7
  40. torch_rechub/models/ranking/dien.py +191 -0
  41. torch_rechub/models/ranking/din.py +31 -19
  42. torch_rechub/models/ranking/edcn.py +101 -0
  43. torch_rechub/models/ranking/fibinet.py +42 -0
  44. torch_rechub/models/ranking/widedeep.py +6 -6
  45. torch_rechub/trainers/__init__.py +4 -2
  46. torch_rechub/trainers/ctr_trainer.py +191 -0
  47. torch_rechub/trainers/match_trainer.py +239 -0
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +137 -23
  50. torch_rechub/trainers/seq_trainer.py +293 -0
  51. torch_rechub/utils/__init__.py +0 -0
  52. torch_rechub/utils/data.py +492 -0
  53. torch_rechub/utils/hstu_utils.py +198 -0
  54. torch_rechub/utils/match.py +457 -0
  55. torch_rechub/utils/mtl.py +136 -0
  56. torch_rechub/utils/onnx_export.py +353 -0
  57. torch_rechub-0.0.4.dist-info/METADATA +391 -0
  58. torch_rechub-0.0.4.dist-info/RECORD +62 -0
  59. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
  60. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
  61. torch_rechub/basic/utils.py +0 -168
  62. torch_rechub/trainers/trainer.py +0 -111
  63. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  64. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  65. torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,353 @@
+"""
+ONNX Export Utilities for Torch-RecHub models.
+
+This module provides non-invasive ONNX export functionality for recommendation models.
+It uses reflection to extract feature information from models and wraps dict-input models
+to be compatible with ONNX's positional argument requirements.
+
+Date: 2024
+References:
+    - PyTorch ONNX Export: https://pytorch.org/docs/stable/onnx.html
+    - ONNX Runtime: https://onnxruntime.ai/docs/
+Authors: Torch-RecHub Contributors
+"""
+
+import os
+import warnings
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
+
+
+class ONNXWrapper(nn.Module):
+    """Wraps a dict-input model to accept positional arguments for ONNX compatibility.
+
+    ONNX does not support dict as input, so this wrapper converts positional arguments
+    back to dict format before passing to the original model.
+
+    Args:
+        model: The original PyTorch model that accepts dict input.
+        input_names: Ordered list of feature names corresponding to input positions.
+        mode: Optional mode for dual-tower models ("user" or "item").
+
+    Example:
+        >>> wrapper = ONNXWrapper(dssm_model, ["user_id", "movie_id", "hist_movie_id"])
+        >>> # Now can call: wrapper(user_id_tensor, movie_id_tensor, hist_tensor)
+    """
+
+    def __init__(self, model: nn.Module, input_names: List[str], mode: Optional[str] = None):
+        super().__init__()
+        self.model = model
+        self.input_names = input_names
+        self._original_mode = getattr(model, 'mode', None)
+
+        # Set mode for dual-tower models
+        if mode is not None and hasattr(model, 'mode'):
+            model.mode = mode
+
+    def forward(self, *args) -> torch.Tensor:
+        """Convert positional args to dict and call original model."""
+        if len(args) != len(self.input_names):
+            raise ValueError(f"Expected {len(self.input_names)} inputs, got {len(args)}. "
+                             f"Expected names: {self.input_names}")
+        x_dict = {name: arg for name, arg in zip(self.input_names, args)}
+        return self.model(x_dict)
+
+    def restore_mode(self):
+        """Restore the original mode of the model."""
+        if hasattr(self.model, 'mode'):
+            self.model.mode = self._original_mode
+
+
+def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
+    """Extract feature information from a model using reflection.
+
+    This function inspects model attributes to find feature lists without
+    modifying the model code. Supports various model architectures.
+
+    Args:
+        model: The recommendation model to inspect.
+
+    Returns:
+        Dict containing:
+            - 'features': List of unique Feature objects
+            - 'input_names': List of feature names in order
+            - 'input_types': Dict mapping feature name to feature type
+            - 'user_features': List of user-side features (for dual-tower models)
+            - 'item_features': List of item-side features (for dual-tower models)
+    """
+    # Common feature attribute names across different model types
+    feature_attrs = [
+        'features',  # MMOE, DCN, etc.
+        'deep_features',  # DeepFM, WideDeep
+        'fm_features',  # DeepFM
+        'wide_features',  # WideDeep
+        'linear_features',  # DeepFFM
+        'cross_features',  # DeepFFM
+        'user_features',  # DSSM, YoutubeDNN, MIND
+        'item_features',  # DSSM, YoutubeDNN, MIND
+        'history_features',  # DIN, MIND
+        'target_features',  # DIN
+        'neg_item_feature',  # YoutubeDNN, MIND
+    ]
+
+    all_features = []
+    user_features = []
+    item_features = []
+
+    for attr in feature_attrs:
+        if hasattr(model, attr):
+            feat_list = getattr(model, attr)
+            if isinstance(feat_list, list) and len(feat_list) > 0:
+                all_features.extend(feat_list)
+                # Track user/item features for dual-tower models
+                if 'user' in attr or 'history' in attr:
+                    user_features.extend(feat_list)
+                elif 'item' in attr:
+                    item_features.extend(feat_list)
+
+    # Deduplicate features by name while preserving order
+    seen = set()
+    unique_features = []
+    for f in all_features:
+        if hasattr(f, 'name') and f.name not in seen:
+            seen.add(f.name)
+            unique_features.append(f)
+
+    # Deduplicate user/item features
+    seen_user = set()
+    unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
+    seen_item = set()
+    unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
+
+    return {
+        'features': unique_features,
+        'input_names': [f.name for f in unique_features],
+        'input_types': {
+            f.name: type(f).__name__ for f in unique_features
+        },
+        'user_features': unique_user,
+        'item_features': unique_item,
+    }
+
+
+def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
+    """Generate dummy input tensors for ONNX export based on feature definitions.
+
+    Args:
+        features: List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
+        batch_size: Batch size for dummy input (default: 2).
+        seq_length: Sequence length for SequenceFeature (default: 10).
+        device: Device to create tensors on (default: 'cpu').
+
+    Returns:
+        Tuple of tensors in the order of input features.
+
+    Example:
+        >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
+        >>> dummy = generate_dummy_input(features, batch_size=4)
+        >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
+    """
+    inputs = []
+    for feat in features:
+        if isinstance(feat, SequenceFeature):
+            # Sequence features have shape [batch_size, seq_length]
+            tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
+        elif isinstance(feat, SparseFeature):
+            # Sparse features have shape [batch_size]
+            tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
+        elif isinstance(feat, DenseFeature):
+            # Dense features have shape [batch_size, embed_dim]
+            tensor = torch.randn(batch_size, feat.embed_dim, device=device)
+        else:
+            raise TypeError(f"Unsupported feature type: {type(feat)}")
+        inputs.append(tensor)
+    return tuple(inputs)
+
+
+def generate_dynamic_axes(input_names: List[str], output_names: List[str] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: List[str] = None) -> Dict[str, Dict[int, str]]:
+    """Generate dynamic axes configuration for ONNX export.
+
+    Args:
+        input_names: List of input tensor names.
+        output_names: List of output tensor names (default: ["output"]).
+        batch_dim: Dimension index for batch size (default: 0).
+        include_seq_dim: Whether to include sequence dimension as dynamic (default: True).
+        seq_features: List of feature names that are sequences (default: None, in which
+            case no sequence dimensions are marked dynamic).
+
+    Returns:
+        Dynamic axes dict for torch.onnx.export.
+    """
+    if output_names is None:
+        output_names = ["output"]
+
+    dynamic_axes = {}
+
+    # Input axes
+    for name in input_names:
+        dynamic_axes[name] = {batch_dim: "batch_size"}
+        # Add sequence dimension for sequence features
+        if include_seq_dim and seq_features and name in seq_features:
+            dynamic_axes[name][1] = "seq_length"
+
+    # Output axes
+    for name in output_names:
+        dynamic_axes[name] = {batch_dim: "batch_size"}
+
+    return dynamic_axes
+
+
+class ONNXExporter:
+    """Main class for exporting Torch-RecHub models to ONNX format.
+
+    This exporter handles the complexity of converting dict-input models to ONNX
+    by automatically extracting feature information and wrapping the model.
+
+    Args:
+        model: The PyTorch recommendation model to export.
+        device: Device for export operations (default: 'cpu').
+
+    Example:
+        >>> exporter = ONNXExporter(deepfm_model)
+        >>> exporter.export("model.onnx")
+
+        >>> # For dual-tower models
+        >>> exporter = ONNXExporter(dssm_model)
+        >>> exporter.export("user_tower.onnx", mode="user")
+        >>> exporter.export("item_tower.onnx", mode="item")
+    """
+
+    def __init__(self, model: nn.Module, device: str = 'cpu'):
+        self.model = model
+        self.device = device
+        self.feature_info = extract_feature_info(model)
+
+    def export(
+        self,
+        output_path: str,
+        mode: Optional[str] = None,
+        dummy_input: Optional[Dict[str, torch.Tensor]] = None,
+        batch_size: int = 2,
+        seq_length: int = 10,
+        opset_version: int = 14,
+        dynamic_batch: bool = True,
+        verbose: bool = False
+    ) -> bool:
+        """Export the model to ONNX format.
+
+        Args:
+            output_path: Path to save the ONNX model.
+            mode: For dual-tower models, specify "user" or "item" to export
+                only that tower. None exports the full model.
+            dummy_input: Optional dict of example inputs. If not provided,
+                dummy inputs will be generated automatically.
+            batch_size: Batch size for generated dummy input (default: 2).
+            seq_length: Sequence length for SequenceFeature (default: 10).
+            opset_version: ONNX opset version (default: 14).
+            dynamic_batch: Whether to enable dynamic batch size (default: True).
+            verbose: Whether to print export details (default: False).
+
+        Returns:
+            True if export succeeded (a RuntimeError is raised on failure).
+
+        Raises:
+            RuntimeError: If ONNX export fails.
+        """
+        self.model.eval()
+        self.model.to(self.device)
+
+        # Determine which features to use based on mode
+        if mode == "user":
+            features = self.feature_info['user_features']
+            if not features:
+                raise ValueError("No user features found in model for mode='user'")
+        elif mode == "item":
+            features = self.feature_info['item_features']
+            if not features:
+                raise ValueError("No item features found in model for mode='item'")
+        else:
+            features = self.feature_info['features']
+
+        input_names = [f.name for f in features]
+
+        # Create wrapped model
+        wrapper = ONNXWrapper(self.model, input_names, mode=mode)
+        wrapper.eval()
+
+        # Generate or use provided dummy input
+        if dummy_input is not None:
+            dummy_tuple = tuple(dummy_input[name].to(self.device) for name in input_names)
+        else:
+            dummy_tuple = generate_dummy_input(features, batch_size=batch_size, seq_length=seq_length, device=self.device)
+
+        # Configure dynamic axes
+        dynamic_axes = None
+        if dynamic_batch:
+            seq_feature_names = [f.name for f in features if isinstance(f, SequenceFeature)]
+            dynamic_axes = generate_dynamic_axes(input_names=input_names, output_names=["output"], seq_features=seq_feature_names)
+
+        # Ensure output directory exists
+        output_dir = os.path.dirname(output_path)
+        if output_dir and not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        try:
+            with torch.no_grad():
+                torch.onnx.export(
+                    wrapper,
+                    dummy_tuple,
+                    output_path,
+                    input_names=input_names,
+                    output_names=["output"],
+                    dynamic_axes=dynamic_axes,
+                    opset_version=opset_version,
+                    do_constant_folding=True,
+                    verbose=verbose,
+                    dynamo=False  # Use legacy exporter for dynamic_axes support
+                )
+
+            if verbose:
+                print(f"Successfully exported ONNX model to: {output_path}")
+                print(f"  Input names: {input_names}")
+                print(f"  Opset version: {opset_version}")
+                print(f"  Dynamic batch: {dynamic_batch}")
+
+            return True
+
+        except Exception as e:
+            warnings.warn(f"ONNX export failed: {str(e)}")
+            raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
+        finally:
+            # Restore original mode
+            wrapper.restore_mode()
+
+    def get_input_info(self, mode: Optional[str] = None) -> Dict[str, Any]:
+        """Get information about model inputs.
+
+        Args:
+            mode: For dual-tower models, "user" or "item".
+
+        Returns:
+            Dict with input names, types, and shapes.
+        """
+        if mode == "user":
+            features = self.feature_info['user_features']
+        elif mode == "item":
+            features = self.feature_info['item_features']
+        else:
+            features = self.feature_info['features']
+
+        info = []
+        for f in features:
+            feat_info = {'name': f.name, 'type': type(f).__name__, 'embed_dim': f.embed_dim}
+            if hasattr(f, 'vocab_size'):
+                feat_info['vocab_size'] = f.vocab_size
+            if hasattr(f, 'pooling'):
+                feat_info['pooling'] = f.pooling
+            info.append(feat_info)
+
+        return {'mode': mode, 'inputs': info, 'input_names': [f.name for f in features]}
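
A minimal usage sketch of the new exporter, assuming an already-trained torch-rechub model whose feature-list attributes (features, user_features, item_features, ...) match the names probed by extract_feature_info(); ranking_model, two_tower_model, and the output paths are illustrative placeholders, not part of the diff:

    # Sketch: export a trained torch-rechub model with the ONNXExporter added above.
    from torch_rechub.utils.onnx_export import ONNXExporter

    # Single-graph export for a ranking model (e.g. DeepFM, DCN).
    exporter = ONNXExporter(ranking_model, device='cpu')
    print(exporter.get_input_info()['input_names'])  # feature order = ONNX input order
    exporter.export("deepfm.onnx", batch_size=2, dynamic_batch=True, verbose=True)

    # Tower-by-tower export for a dual-tower model (e.g. DSSM, YoutubeDNN):
    # mode="user"/"item" selects the corresponding feature subset and tower.
    tower_exporter = ONNXExporter(two_tower_model)
    tower_exporter.export("user_tower.onnx", mode="user")
    tower_exporter.export("item_tower.onnx", mode="item")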
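
Because export() only raises on failure and does not compare outputs, a quick parity check against the original PyTorch model is a reasonable follow-up. A sketch assuming onnxruntime and numpy are installed, reusing the helpers from this module; ranking_model and "deepfm.onnx" are the same placeholders as above:

    import numpy as np
    import onnxruntime as ort
    import torch

    from torch_rechub.utils.onnx_export import ONNXWrapper, extract_feature_info, generate_dummy_input

    info = extract_feature_info(ranking_model)  # same reflection the exporter uses
    dummy = generate_dummy_input(info['features'], batch_size=4)

    # Reference output from PyTorch, via the same positional wrapper used for export.
    wrapper = ONNXWrapper(ranking_model, info['input_names'])
    wrapper.eval()
    with torch.no_grad():
        torch_out = wrapper(*dummy).numpy()

    # ONNX Runtime output on identical inputs; "output" matches output_names in export().
    sess = ort.InferenceSession("deepfm.onnx", providers=["CPUExecutionProvider"])
    feeds = {name: t.numpy() for name, t in zip(info['input_names'], dummy)}
    ort_out = sess.run(["output"], feeds)[0]

    print("max abs diff:", np.max(np.abs(torch_out - ort_out)))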