torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two versions.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +3 -1
- torch_rechub/basic/callback.py +2 -2
- torch_rechub/basic/features.py +38 -8
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +800 -46
- torch_rechub/basic/loss_func.py +223 -0
- torch_rechub/basic/metaoptimizer.py +76 -0
- torch_rechub/basic/metric.py +251 -0
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -0
- torch_rechub/models/matching/comirec.py +193 -0
- torch_rechub/models/matching/dssm.py +72 -0
- torch_rechub/models/matching/dssm_facebook.py +77 -0
- torch_rechub/models/matching/dssm_senet.py +87 -0
- torch_rechub/models/matching/gru4rec.py +85 -0
- torch_rechub/models/matching/mind.py +103 -0
- torch_rechub/models/matching/narm.py +82 -0
- torch_rechub/models/matching/sasrec.py +143 -0
- torch_rechub/models/matching/sine.py +148 -0
- torch_rechub/models/matching/stamp.py +81 -0
- torch_rechub/models/matching/youtube_dnn.py +75 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -2
- torch_rechub/models/multi_task/aitm.py +83 -0
- torch_rechub/models/multi_task/esmm.py +19 -8
- torch_rechub/models/multi_task/mmoe.py +18 -12
- torch_rechub/models/multi_task/ple.py +41 -29
- torch_rechub/models/multi_task/shared_bottom.py +3 -2
- torch_rechub/models/ranking/__init__.py +13 -2
- torch_rechub/models/ranking/afm.py +65 -0
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +59 -0
- torch_rechub/models/ranking/deepffm.py +131 -0
- torch_rechub/models/ranking/deepfm.py +8 -7
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +31 -19
- torch_rechub/models/ranking/edcn.py +101 -0
- torch_rechub/models/ranking/fibinet.py +42 -0
- torch_rechub/models/ranking/widedeep.py +6 -6
- torch_rechub/trainers/__init__.py +4 -2
- torch_rechub/trainers/ctr_trainer.py +191 -0
- torch_rechub/trainers/match_trainer.py +239 -0
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +137 -23
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +492 -0
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -0
- torch_rechub/utils/mtl.py +136 -0
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub/trainers/trainer.py +0 -111
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
torch_rechub/utils/onnx_export.py

@@ -0,0 +1,353 @@
+"""
+ONNX Export Utilities for Torch-RecHub models.
+
+This module provides non-invasive ONNX export functionality for recommendation models.
+It uses reflection to extract feature information from models and wraps dict-input models
+to be compatible with ONNX's positional argument requirements.
+
+Date: 2024
+References:
+    - PyTorch ONNX Export: https://pytorch.org/docs/stable/onnx.html
+    - ONNX Runtime: https://onnxruntime.ai/docs/
+Authors: Torch-RecHub Contributors
+"""
+
+import os
+import warnings
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
+
+
+class ONNXWrapper(nn.Module):
+    """Wraps a dict-input model to accept positional arguments for ONNX compatibility.
+
+    ONNX does not support dict as input, so this wrapper converts positional arguments
+    back to dict format before passing to the original model.
+
+    Args:
+        model: The original PyTorch model that accepts dict input.
+        input_names: Ordered list of feature names corresponding to input positions.
+        mode: Optional mode for dual-tower models ("user" or "item").
+
+    Example:
+        >>> wrapper = ONNXWrapper(dssm_model, ["user_id", "movie_id", "hist_movie_id"])
+        >>> # Now can call: wrapper(user_id_tensor, movie_id_tensor, hist_tensor)
+    """
+
+    def __init__(self, model: nn.Module, input_names: List[str], mode: Optional[str] = None):
+        super().__init__()
+        self.model = model
+        self.input_names = input_names
+        self._original_mode = getattr(model, 'mode', None)
+
+        # Set mode for dual-tower models
+        if mode is not None and hasattr(model, 'mode'):
+            model.mode = mode
+
+    def forward(self, *args) -> torch.Tensor:
+        """Convert positional args to dict and call original model."""
+        if len(args) != len(self.input_names):
+            raise ValueError(f"Expected {len(self.input_names)} inputs, got {len(args)}. "
+                             f"Expected names: {self.input_names}")
+        x_dict = {name: arg for name, arg in zip(self.input_names, args)}
+        return self.model(x_dict)
+
+    def restore_mode(self):
+        """Restore the original mode of the model."""
+        if hasattr(self.model, 'mode'):
+            self.model.mode = self._original_mode
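This wrapper is the core trick that makes dict-input models traceable: ONNX export needs positional tensor arguments, so the wrapper rebuilds the dict inside forward. Below is a minimal sketch (not part of the diff) of wrapping a toy dict-input model, assuming the module is importable as torch_rechub.utils.onnx_export in 0.0.4; the toy model is purely illustrative.

import torch
import torch.nn as nn
from torch_rechub.utils.onnx_export import ONNXWrapper


class ToyDictModel(nn.Module):
    """Toy stand-in for a torch-rechub model: forward takes a dict of named tensors."""

    def __init__(self):
        super().__init__()
        self.user_emb = nn.Embedding(1000, 8)
        self.item_emb = nn.Embedding(500, 8)

    def forward(self, x_dict):
        u = self.user_emb(x_dict["user_id"])      # [batch, 8]
        i = self.item_emb(x_dict["item_id"])      # [batch, 8]
        return torch.sigmoid((u * i).sum(dim=1))  # [batch]


model = ToyDictModel().eval()
wrapper = ONNXWrapper(model, input_names=["user_id", "item_id"])

user_id = torch.randint(0, 1000, (4,))
item_id = torch.randint(0, 500, (4,))
# Positional call; the wrapper rebuilds {"user_id": ..., "item_id": ...} internally.
print(wrapper(user_id, item_id).shape)  # torch.Size([4])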
+
+
+def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
+    """Extract feature information from a model using reflection.
+
+    This function inspects model attributes to find feature lists without
+    modifying the model code. Supports various model architectures.
+
+    Args:
+        model: The recommendation model to inspect.
+
+    Returns:
+        Dict containing:
+        - 'features': List of unique Feature objects
+        - 'input_names': List of feature names in order
+        - 'input_types': Dict mapping feature name to feature type
+        - 'user_features': List of user-side features (for dual-tower models)
+        - 'item_features': List of item-side features (for dual-tower models)
+    """
+    # Common feature attribute names across different model types
+    feature_attrs = [
+        'features',  # MMOE, DCN, etc.
+        'deep_features',  # DeepFM, WideDeep
+        'fm_features',  # DeepFM
+        'wide_features',  # WideDeep
+        'linear_features',  # DeepFFM
+        'cross_features',  # DeepFFM
+        'user_features',  # DSSM, YoutubeDNN, MIND
+        'item_features',  # DSSM, YoutubeDNN, MIND
+        'history_features',  # DIN, MIND
+        'target_features',  # DIN
+        'neg_item_feature',  # YoutubeDNN, MIND
+    ]
+
+    all_features = []
+    user_features = []
+    item_features = []
+
+    for attr in feature_attrs:
+        if hasattr(model, attr):
+            feat_list = getattr(model, attr)
+            if isinstance(feat_list, list) and len(feat_list) > 0:
+                all_features.extend(feat_list)
+                # Track user/item features for dual-tower models
+                if 'user' in attr or 'history' in attr:
+                    user_features.extend(feat_list)
+                elif 'item' in attr:
+                    item_features.extend(feat_list)
+
+    # Deduplicate features by name while preserving order
+    seen = set()
+    unique_features = []
+    for f in all_features:
+        if hasattr(f, 'name') and f.name not in seen:
+            seen.add(f.name)
+            unique_features.append(f)
+
+    # Deduplicate user/item features
+    seen_user = set()
+    unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
+    seen_item = set()
+    unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
+
+    return {
+        'features': unique_features,
+        'input_names': [f.name for f in unique_features],
+        'input_types': {
+            f.name: type(f).__name__ for f in unique_features
+        },
+        'user_features': unique_user,
+        'item_features': unique_item,
+    }
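To illustrate the reflection, here is a sketch (not part of the diff) of what extract_feature_info might return for a hypothetical two-tower model object. The SparseFeature/SequenceFeature constructor arguments shown are assumptions about torch_rechub.basic.features; only the attribute names matter for the reflection.

import torch.nn as nn
from torch_rechub.basic.features import SparseFeature, SequenceFeature
from torch_rechub.utils.onnx_export import extract_feature_info


class FakeTwoTower(nn.Module):
    """Hypothetical model exposing the attribute names the reflection looks for."""

    def __init__(self):
        super().__init__()
        self.user_features = [SparseFeature("user_id", vocab_size=1000, embed_dim=16)]
        self.history_features = [SequenceFeature("hist_movie_id", vocab_size=500, embed_dim=16, pooling="mean")]
        self.item_features = [SparseFeature("movie_id", vocab_size=500, embed_dim=16)]


info = extract_feature_info(FakeTwoTower())
print(info["input_names"])   # ['user_id', 'movie_id', 'hist_movie_id'] (attr scan order)
print(info["input_types"])   # {'user_id': 'SparseFeature', 'movie_id': 'SparseFeature', 'hist_movie_id': 'SequenceFeature'}
print([f.name for f in info["user_features"]])  # ['user_id', 'hist_movie_id'] (history counts as user side)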
+
+
+def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
+    """Generate dummy input tensors for ONNX export based on feature definitions.
+
+    Args:
+        features: List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
+        batch_size: Batch size for dummy input (default: 2).
+        seq_length: Sequence length for SequenceFeature (default: 10).
+        device: Device to create tensors on (default: 'cpu').
+
+    Returns:
+        Tuple of tensors in the order of input features.
+
+    Example:
+        >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
+        >>> dummy = generate_dummy_input(features, batch_size=4)
+        >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
+    """
+    inputs = []
+    for feat in features:
+        if isinstance(feat, SequenceFeature):
+            # Sequence features have shape [batch_size, seq_length]
+            tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
+        elif isinstance(feat, SparseFeature):
+            # Sparse features have shape [batch_size]
+            tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
+        elif isinstance(feat, DenseFeature):
+            # Dense features have shape [batch_size, embed_dim]
+            tensor = torch.randn(batch_size, feat.embed_dim, device=device)
+        else:
+            raise TypeError(f"Unsupported feature type: {type(feat)}")
+        inputs.append(tensor)
+    return tuple(inputs)
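A small sketch (not part of the diff) of generating placeholder inputs for a mixed feature list, with the same constructor assumptions as above; DenseFeature is assumed to default to embed_dim=1, which produces the [batch, 1] shape shown.

from torch_rechub.basic.features import DenseFeature, SparseFeature, SequenceFeature
from torch_rechub.utils.onnx_export import generate_dummy_input

features = [
    SparseFeature("user_id", vocab_size=1000, embed_dim=16),
    DenseFeature("age"),
    SequenceFeature("hist_movie_id", vocab_size=500, embed_dim=16, pooling="mean"),
]
dummy = generate_dummy_input(features, batch_size=4, seq_length=20)
print([t.shape for t in dummy])
# [torch.Size([4]), torch.Size([4, 1]), torch.Size([4, 20])]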
+
+
+def generate_dynamic_axes(input_names: List[str], output_names: List[str] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: List[str] = None) -> Dict[str, Dict[int, str]]:
+    """Generate dynamic axes configuration for ONNX export.
+
+    Args:
+        input_names: List of input tensor names.
+        output_names: List of output tensor names (default: ["output"]).
+        batch_dim: Dimension index for batch size (default: 0).
+        include_seq_dim: Whether to include sequence dimension as dynamic (default: True).
+        seq_features: List of feature names that are sequences (default: auto-detect).
+
+    Returns:
+        Dynamic axes dict for torch.onnx.export.
+    """
+    if output_names is None:
+        output_names = ["output"]
+
+    dynamic_axes = {}
+
+    # Input axes
+    for name in input_names:
+        dynamic_axes[name] = {batch_dim: "batch_size"}
+        # Add sequence dimension for sequence features
+        if include_seq_dim and seq_features and name in seq_features:
+            dynamic_axes[name][1] = "seq_length"
+
+    # Output axes
+    for name in output_names:
+        dynamic_axes[name] = {batch_dim: "batch_size"}
+
+    return dynamic_axes
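For reference, a sketch (not part of the diff) of the mapping this helper produces for one scalar input and one sequence input; the feature names are illustrative.

from torch_rechub.utils.onnx_export import generate_dynamic_axes

axes = generate_dynamic_axes(
    input_names=["user_id", "hist_movie_id"],
    output_names=["output"],
    seq_features=["hist_movie_id"],
)
print(axes)
# {'user_id': {0: 'batch_size'},
#  'hist_movie_id': {0: 'batch_size', 1: 'seq_length'},
#  'output': {0: 'batch_size'}}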
+
+
+class ONNXExporter:
+    """Main class for exporting Torch-RecHub models to ONNX format.
+
+    This exporter handles the complexity of converting dict-input models to ONNX
+    by automatically extracting feature information and wrapping the model.
+
+    Args:
+        model: The PyTorch recommendation model to export.
+        device: Device for export operations (default: 'cpu').
+
+    Example:
+        >>> exporter = ONNXExporter(deepfm_model)
+        >>> exporter.export("model.onnx")
+
+        >>> # For dual-tower models
+        >>> exporter = ONNXExporter(dssm_model)
+        >>> exporter.export("user_tower.onnx", mode="user")
+        >>> exporter.export("item_tower.onnx", mode="item")
+    """
+
+    def __init__(self, model: nn.Module, device: str = 'cpu'):
+        self.model = model
+        self.device = device
+        self.feature_info = extract_feature_info(model)
+
+    def export(
+        self,
+        output_path: str,
+        mode: Optional[str] = None,
+        dummy_input: Optional[Dict[str, torch.Tensor]] = None,
+        batch_size: int = 2,
+        seq_length: int = 10,
+        opset_version: int = 14,
+        dynamic_batch: bool = True,
+        verbose: bool = False
+    ) -> bool:
+        """Export the model to ONNX format.
+
+        Args:
+            output_path: Path to save the ONNX model.
+            mode: For dual-tower models, specify "user" or "item" to export
+                only that tower. None exports the full model.
+            dummy_input: Optional dict of example inputs. If not provided,
+                dummy inputs will be generated automatically.
+            batch_size: Batch size for generated dummy input (default: 2).
+            seq_length: Sequence length for SequenceFeature (default: 10).
+            opset_version: ONNX opset version (default: 14).
+            dynamic_batch: Whether to enable dynamic batch size (default: True).
+            verbose: Whether to print export details (default: False).
+
+        Returns:
+            True if export succeeded, False otherwise.
+
+        Raises:
+            RuntimeError: If ONNX export fails.
+        """
+        self.model.eval()
+        self.model.to(self.device)
+
+        # Determine which features to use based on mode
+        if mode == "user":
+            features = self.feature_info['user_features']
+            if not features:
+                raise ValueError("No user features found in model for mode='user'")
+        elif mode == "item":
+            features = self.feature_info['item_features']
+            if not features:
+                raise ValueError("No item features found in model for mode='item'")
+        else:
+            features = self.feature_info['features']
+
+        input_names = [f.name for f in features]
+
+        # Create wrapped model
+        wrapper = ONNXWrapper(self.model, input_names, mode=mode)
+        wrapper.eval()
+
+        # Generate or use provided dummy input
+        if dummy_input is not None:
+            dummy_tuple = tuple(dummy_input[name].to(self.device) for name in input_names)
+        else:
+            dummy_tuple = generate_dummy_input(features, batch_size=batch_size, seq_length=seq_length, device=self.device)
+
+        # Configure dynamic axes
+        dynamic_axes = None
+        if dynamic_batch:
+            seq_feature_names = [f.name for f in features if isinstance(f, SequenceFeature)]
+            dynamic_axes = generate_dynamic_axes(input_names=input_names, output_names=["output"], seq_features=seq_feature_names)
+
+        # Ensure output directory exists
+        output_dir = os.path.dirname(output_path)
+        if output_dir and not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        try:
+            with torch.no_grad():
+                torch.onnx.export(
+                    wrapper,
+                    dummy_tuple,
+                    output_path,
+                    input_names=input_names,
+                    output_names=["output"],
+                    dynamic_axes=dynamic_axes,
+                    opset_version=opset_version,
+                    do_constant_folding=True,
+                    verbose=verbose,
+                    dynamo=False  # Use legacy exporter for dynamic_axes support
+                )
+
+            if verbose:
+                print(f"Successfully exported ONNX model to: {output_path}")
+                print(f" Input names: {input_names}")
+                print(f" Opset version: {opset_version}")
+                print(f" Dynamic batch: {dynamic_batch}")
+
+            return True
+
+        except Exception as e:
+            warnings.warn(f"ONNX export failed: {str(e)}")
+            raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
+        finally:
+            # Restore original mode
+            wrapper.restore_mode()
+
+    def get_input_info(self, mode: Optional[str] = None) -> Dict[str, Any]:
+        """Get information about model inputs.
+
+        Args:
+            mode: For dual-tower models, "user" or "item".
+
+        Returns:
+            Dict with input names, types, and shapes.
+        """
+        if mode == "user":
+            features = self.feature_info['user_features']
+        elif mode == "item":
+            features = self.feature_info['item_features']
+        else:
+            features = self.feature_info['features']
+
+        info = []
+        for f in features:
+            feat_info = {'name': f.name, 'type': type(f).__name__, 'embed_dim': f.embed_dim}
+            if hasattr(f, 'vocab_size'):
+                feat_info['vocab_size'] = f.vocab_size
+            if hasattr(f, 'pooling'):
+                feat_info['pooling'] = f.pooling
+            info.append(feat_info)
+
+        return {'mode': mode, 'inputs': info, 'input_names': [f.name for f in features]}
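Putting the pieces together, a hedged end-to-end sketch (not part of the diff): export an already-trained dict-input model and sanity-check the resulting file with onnxruntime. Here model stands for any torch-rechub model built elsewhere (e.g. a DeepFM instance), the feature names in the feed are illustrative, and onnxruntime is an optional extra dependency rather than something this module requires.

import numpy as np
import onnxruntime as ort
from torch_rechub.utils.onnx_export import ONNXExporter

exporter = ONNXExporter(model)                    # model: a trained dict-input model
print(exporter.get_input_info()["input_names"])   # inspect the expected input order
exporter.export("deepfm.onnx", opset_version=14, dynamic_batch=True, verbose=True)

# Run the exported graph with ONNX Runtime; integer ids must be int64 to match the trace.
sess = ort.InferenceSession("deepfm.onnx", providers=["CPUExecutionProvider"])
feed = {
    "user_id": np.array([1, 2], dtype=np.int64),   # illustrative feature names
    "movie_id": np.array([10, 20], dtype=np.int64),
}
print(sess.run(None, feed)[0])                     # scores from the exported graph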