torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +228 -159
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/data/dataset.py +18 -31
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +12 -2
- torch_rechub/trainers/match_trainer.py +13 -2
- torch_rechub/trainers/mtl_trainer.py +12 -2
- torch_rechub/trainers/seq_trainer.py +34 -15
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +191 -145
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/METADATA +34 -18
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/RECORD +24 -18
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -26,32 +26,30 @@ except ImportError:
 
 
 def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
-    """Extract feature information from a torch-rechub model
-
-    This function inspects model attributes to find feature lists without
-    modifying the model code. Supports various model architectures.
+    """Extract feature information from a torch-rechub model via reflection.
 
     Parameters
     ----------
     model : nn.Module
-
+        Model to inspect.
 
     Returns
     -------
     dict
-
-
-
-
-
-
+        {
+            'features': list of unique Feature objects,
+            'input_names': ordered feature names,
+            'input_types': map name -> feature type,
+            'user_features': user-side features (dual-tower),
+            'item_features': item-side features (dual-tower),
+        }
 
     Examples
     --------
     >>> from torch_rechub.models.ranking import DeepFM
     >>> model = DeepFM(deep_features, fm_features, mlp_params)
     >>> info = extract_feature_info(model)
-    >>>
+    >>> info['input_names']  # ['user_id', 'item_id', ...]
     """
     # Common feature attribute names across different model types
     feature_attrs = [
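
For orientation, a minimal consumer-side sketch of the metadata returned above; the import path follows the file list (torch_rechub/utils/model_utils.py), and `model` stands for any already-built torch-rechub model, so treat the details as assumptions:

    from torch_rechub.utils.model_utils import extract_feature_info

    info = extract_feature_info(model)                     # model: e.g. a trained DSSM or DeepFM
    names = info["input_names"]                            # ordered feature names
    ftype = info["input_types"]["user_id"]                 # feature type of one input (per the docstring example)
    user_cols = [f.name for f in info["user_features"]]    # user-tower features (dual-tower models only)
    item_cols = [f.name for f in info["item_features"]]    # item-tower features (dual-tower models only)
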
@@ -12,6 +12,7 @@ References:
 Authors: Torch-RecHub Contributors
 """
 
+import inspect
 import os
 import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -23,19 +24,24 @@ from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
 
 
 class ONNXWrapper(nn.Module):
-    """
-
-    ONNX
-
-
-
-
-
-
-
-
-
-
+    """Wrap a dict-input model to accept positional args for ONNX.
+
+    ONNX disallows dict inputs; this wrapper maps positional args back to dict
+    before calling the original model.
+
+    Parameters
+    ----------
+    model : nn.Module
+        Original dict-input model.
+    input_names : list[str]
+        Ordered feature names matching positional inputs.
+    mode : {'user', 'item'}, optional
+        For dual-tower models, set tower mode.
+
+    Examples
+    --------
+    >>> wrapper = ONNXWrapper(dssm_model, ["user_id", "movie_id", "hist_movie_id"])
+    >>> wrapper(user_id_tensor, movie_id_tensor, hist_tensor)
     """
 
     def __init__(self, model: nn.Module, input_names: List[str], mode: Optional[str] = None):
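
To make the positional-to-dict mapping concrete, a rough sketch with a toy dict-input module; the toy model and tensors are illustrative, only ONNXWrapper's signature comes from this diff:

    import torch
    import torch.nn as nn
    from torch_rechub.utils.onnx_export import ONNXWrapper

    class ToyDictModel(nn.Module):
        """Stand-in for a rec model whose forward expects x = {'user_id': ..., 'item_id': ...}."""

        def __init__(self):
            super().__init__()
            self.user_emb = nn.Embedding(100, 8)
            self.item_emb = nn.Embedding(200, 8)

        def forward(self, x):
            return (self.user_emb(x["user_id"]) * self.item_emb(x["item_id"])).sum(dim=-1)

    wrapper = ONNXWrapper(ToyDictModel(), input_names=["user_id", "item_id"])
    # Positional tensors are re-packed into the dict the wrapped model expects,
    # which is what makes torch.onnx.export usable for dict-input models.
    scores = wrapper(torch.tensor([1, 2]), torch.tensor([3, 4]))
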
@@ -102,27 +108,49 @@ class ONNXExporter:
         seq_length: int = 10,
         opset_version: int = 14,
         dynamic_batch: bool = True,
-        verbose: bool = False
+        verbose: bool = False,
+        onnx_export_kwargs: Optional[Dict[str,
+                                          Any]] = None,
     ) -> bool:
-        """Export
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        """Export model to ONNX format.
+
+        Parameters
+        ----------
+        output_path : str
+            Destination path.
+        mode : {'user', 'item'}, optional
+            For dual-tower, export specific tower; None exports full model.
+        dummy_input : dict[str, Tensor], optional
+            Example inputs; auto-generated if None.
+        batch_size : int, default=2
+            Batch size for dummy input generation.
+        seq_length : int, default=10
+            Sequence length for SequenceFeature.
+        opset_version : int, default=14
+            ONNX opset.
+        dynamic_batch : bool, default=True
+            Enable dynamic batch axes.
+        verbose : bool, default=False
+            Print export details.
+        onnx_export_kwargs : dict, optional
+            Extra keyword args forwarded to ``torch.onnx.export`` (e.g. ``operator_export_type``,
+            ``keep_initializers_as_inputs``, ``do_constant_folding``).
+            Notes:
+            - If you pass keys that overlap with the explicit parameters above
+              (like ``opset_version`` / ``dynamic_axes`` / ``input_names``), this function
+              will raise a ``ValueError`` to avoid ambiguous behavior.
+            - Some kwargs (like ``dynamo``) are only available in newer PyTorch; unsupported
+              keys will be ignored for compatibility.
+
+        Returns
+        -------
+        bool
+            True if export succeeds.
+
+        Raises
+        ------
+        RuntimeError
+            If ONNX export fails.
         """
         self.model.eval()
         self.model.to(self.device)
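
A hedged usage sketch of the new `onnx_export_kwargs` path; the ONNXExporter constructor is not shown in this diff, so `ONNXExporter(model)` below is an assumption:

    from torch_rechub.utils.onnx_export import ONNXExporter

    exporter = ONNXExporter(model)            # constructor args assumed; model: a trained dual-tower model
    ok = exporter.export(
        output_path="dssm_user.onnx",
        mode="user",                          # export only the user tower
        opset_version=14,
        dynamic_batch=True,
        onnx_export_kwargs={
            "keep_initializers_as_inputs": False,
            # passing e.g. "opset_version" here would raise ValueError (overlaps an explicit parameter)
        },
    )
    assert ok
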
@@ -164,18 +192,43 @@ class ONNXExporter:
 
         try:
             with torch.no_grad():
-
-
-
-
-
-
-
-
-
-
-
-
+                export_kwargs: Dict[str,
+                                    Any] = {
+                    "f": output_path,
+                    "input_names": input_names,
+                    "output_names": ["output"],
+                    "dynamic_axes": dynamic_axes,
+                    "opset_version": opset_version,
+                    "do_constant_folding": True,
+                    "verbose": verbose,
+                }
+
+                if onnx_export_kwargs:
+                    # Prevent silent conflicts with explicit arguments
+                    overlap = set(export_kwargs.keys()) & set(onnx_export_kwargs.keys())
+                    # allow user to set 'dynamo' even if we inject it later
+                    overlap.discard("dynamo")
+                    if overlap:
+                        raise ValueError("onnx_export_kwargs contains keys that overlap with explicit args: "
+                                         f"{sorted(overlap)}. Please set them via export() parameters instead.")
+                    export_kwargs.update(onnx_export_kwargs)
+
+                # Auto-pick exporter:
+                # - When dynamic axes are requested, prefer legacy exporter (dynamo=False),
+                #   because the dynamo exporter may not honor `dynamic_axes` consistently
+                #   across torch versions.
+                # - When no dynamic axes are requested, prefer dynamo exporter (dynamo=True)
+                #   for better operator coverage in newer torch.
+                #
+                # In older torch versions, 'dynamo' kwarg does not exist.
+                sig = inspect.signature(torch.onnx.export)
+                if "dynamo" in sig.parameters:
+                    if "dynamo" not in export_kwargs:
+                        export_kwargs["dynamo"] = False if dynamic_axes is not None else True
+                else:
+                    export_kwargs.pop("dynamo", None)
+
+                torch.onnx.export(wrapper, dummy_tuple, **export_kwargs)
 
             if verbose:
                 print(f"Successfully exported ONNX model to: {output_path}")
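
The compatibility trick above (probing the callee's signature before passing a version-dependent keyword such as `dynamo`) generalizes; a small self-contained sketch of the same pattern, independent of torch, with all names hypothetical:

    import inspect

    def call_with_supported_kwargs(fn, **kwargs):
        """Drop keyword arguments the callee does not accept, based on its signature."""
        params = inspect.signature(fn).parameters
        if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            kwargs = {k: v for k, v in kwargs.items() if k in params}
        return fn(**kwargs)

    def export_old(path):                 # older API: no 'dynamo' flag
        return ("legacy", path)

    def export_new(path, dynamo=False):   # newer API
        return ("dynamo" if dynamo else "legacy", path)

    print(call_with_supported_kwargs(export_old, path="m.onnx", dynamo=True))  # ('legacy', 'm.onnx')
    print(call_with_supported_kwargs(export_new, path="m.onnx", dynamo=True))  # ('dynamo', 'm.onnx')
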
@@ -0,0 +1,128 @@
+"""
+ONNX Quantization Utilities.
+
+This module provides a lightweight API to quantize exported ONNX models:
+- INT8 dynamic quantization (recommended for MLP-heavy rec models on CPU)
+- FP16 conversion (recommended for GPU inference)
+
+The functions are optional-dependency friendly:
+- INT8 quantization requires: onnxruntime
+- FP16 conversion requires: onnx + onnxconverter-common
+"""
+
+from __future__ import annotations
+
+import inspect
+import os
+from typing import Any, Dict, Optional
+
+
+def _ensure_parent_dir(path: str) -> None:
+    parent = os.path.dirname(os.path.abspath(path))
+    if parent and not os.path.exists(parent):
+        os.makedirs(parent, exist_ok=True)
+
+
+def quantize_model(
+    input_path: str,
+    output_path: str,
+    mode: str = "int8",
+    *,
+    # INT8(dynamic) params
+    per_channel: bool = False,
+    reduce_range: bool = False,
+    weight_type: str = "qint8",
+    optimize_model: bool = False,
+    op_types_to_quantize: Optional[list[str]] = None,
+    nodes_to_quantize: Optional[list[str]] = None,
+    nodes_to_exclude: Optional[list[str]] = None,
+    extra_options: Optional[Dict[str,
+                                 Any]] = None,
+    # FP16 params
+    keep_io_types: bool = True,
+) -> str:
+    """Quantize an ONNX model.
+
+    Parameters
+    ----------
+    input_path : str
+        Input ONNX model path (FP32).
+    output_path : str
+        Output ONNX model path.
+    mode : str, default="int8"
+        Quantization mode:
+        - "int8" / "dynamic_int8": ONNX Runtime dynamic quantization (weights INT8).
+        - "fp16": convert float tensors to float16.
+    per_channel : bool, default=False
+        Enable per-channel quantization for weights (INT8).
+    reduce_range : bool, default=False
+        Use reduced quantization range (INT8), sometimes helpful on certain CPUs.
+    weight_type : {"qint8", "quint8"}, default="qint8"
+        Weight quant type for dynamic quantization.
+    optimize_model : bool, default=False
+        Run ORT graph optimization before quantization.
+    op_types_to_quantize / nodes_to_quantize / nodes_to_exclude / extra_options
+        Advanced options forwarded to ``onnxruntime.quantization.quantize_dynamic``.
+    keep_io_types : bool, default=True
+        For FP16 conversion, keep model input/output types as float32 for compatibility.
+
+    Returns
+    -------
+    str
+        The output_path.
+    """
+    mode_norm = (mode or "").strip().lower()
+    _ensure_parent_dir(output_path)
+
+    if mode_norm in ("int8", "dynamic_int8", "dynamic"):
+        try:
+            from onnxruntime.quantization import QuantType, quantize_dynamic
+        except Exception as e:  # pragma: no cover
+            raise ImportError("INT8 quantization requires onnxruntime. Install with: pip install -U \"torch-rechub[onnx]\"") from e
+
+        wt = (weight_type or "").strip().lower()
+        if wt in ("qint8", "int8", "signed"):
+            qt = QuantType.QInt8
+        elif wt in ("quint8", "uint8", "unsigned"):
+            qt = QuantType.QUInt8
+        else:
+            raise ValueError("weight_type must be one of {'qint8','quint8'}")
+
+        q_kwargs: Dict[str,
+                       Any] = {
+            "model_input": input_path,
+            "model_output": output_path,
+            "per_channel": per_channel,
+            "reduce_range": reduce_range,
+            "weight_type": qt,
+            "optimize_model": optimize_model,
+            "op_types_to_quantize": op_types_to_quantize,
+            "nodes_to_quantize": nodes_to_quantize,
+            "nodes_to_exclude": nodes_to_exclude,
+            "extra_options": extra_options,
+        }
+
+        # Compatibility: different onnxruntime versions expose different kwargs.
+        sig = inspect.signature(quantize_dynamic)
+        q_kwargs = {k: v for k, v in q_kwargs.items() if k in sig.parameters and v is not None}
+
+        quantize_dynamic(**q_kwargs)
+        return output_path
+
+    if mode_norm in ("fp16", "float16"):
+        try:
+            import onnx
+        except Exception as e:  # pragma: no cover
+            raise ImportError("FP16 conversion requires onnx. Install with: pip install -U \"torch-rechub[onnx]\"") from e
+
+        try:
+            from onnxconverter_common import float16
+        except Exception as e:  # pragma: no cover
+            raise ImportError("FP16 conversion requires onnxconverter-common. Install with: pip install -U onnxconverter-common") from e
+
+        model = onnx.load(input_path)
+        model_fp16 = float16.convert_float_to_float16(model, keep_io_types=keep_io_types)
+        onnx.save(model_fp16, output_path)
+        return output_path
+
+    raise ValueError("mode must be one of {'int8','dynamic_int8','fp16'}")
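
A brief usage sketch of the new quantization API, assuming an FP32 model has already been exported to deepfm.onnx (the file names are placeholders):

    from torch_rechub.utils.quantization import quantize_model

    # INT8 dynamic quantization for CPU serving (weights stored as signed int8).
    quantize_model("deepfm.onnx", "deepfm.int8.onnx", mode="int8", per_channel=True)

    # FP16 conversion for GPU serving; keep float32 inputs/outputs for client compatibility.
    quantize_model("deepfm.onnx", "deepfm.fp16.onnx", mode="fp16", keep_io_types=True)
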
@@ -44,27 +44,19 @@ def _is_jupyter_environment() -> bool:
 
 
 def display_graph(graph: Any, format: str = 'png') -> Any:
-    """Display a torchview ComputationGraph in Jupyter
-
-    This function provides a reliable way to display visualization graphs
-    in Jupyter environments, especially VSCode Jupyter.
+    """Display a torchview ComputationGraph in Jupyter.
 
     Parameters
     ----------
     graph : ComputationGraph
-
+        Returned by :func:`visualize_model`.
     format : str, default='png'
-        Output format
+        Output format; 'png' recommended for VSCode.
 
     Returns
     -------
     graphviz.Digraph or None
-
-
-    Examples
-    --------
-    >>> graph = visualize_model(model, depth=4)
-    >>> display_graph(graph) # Works in VSCode Jupyter
+        Displayed graph object, or None if display fails.
     """
     if not TORCHVIEW_AVAILABLE:
         raise ImportError(f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
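
The doctest removed above still illustrates the intended flow; restated as a hedged sketch (requires the visualization extra: torchview + graphviz; the import path is assumed from the file list):

    from torch_rechub.utils.visualization import visualize_model, display_graph

    graph = visualize_model(model, depth=4)   # model: any torch-rechub nn.Module
    display_graph(graph, format="png")        # renders inline in Jupyter, including VSCode Jupyter
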
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torch-rechub
-Version: 0.0.6
+Version: 0.2.0
 Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
 Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
 Project-URL: Documentation, https://www.torch-rechub.com
@@ -28,8 +28,10 @@ Requires-Dist: scikit-learn>=0.24.0
 Requires-Dist: torch>=1.10.0
 Requires-Dist: tqdm>=4.60.0
 Requires-Dist: transformers>=4.46.3
+Provides-Extra: annoy
+Requires-Dist: annoy>=1.17.2; extra == 'annoy'
 Provides-Extra: bigdata
-Requires-Dist: pyarrow
+Requires-Dist: pyarrow<23,>=21; extra == 'bigdata'
 Provides-Extra: dev
 Requires-Dist: bandit>=1.7.0; extra == 'dev'
 Requires-Dist: flake8>=3.8.0; extra == 'dev'
@@ -41,8 +43,13 @@ Requires-Dist: pytest-cov>=2.0; extra == 'dev'
 Requires-Dist: pytest>=6.0; extra == 'dev'
 Requires-Dist: toml>=0.10.2; extra == 'dev'
 Requires-Dist: yapf==0.43.0; extra == 'dev'
+Provides-Extra: faiss
+Requires-Dist: faiss-cpu==1.13.0; extra == 'faiss'
+Provides-Extra: milvus
+Requires-Dist: pymilvus>=2.6.5; extra == 'milvus'
 Provides-Extra: onnx
 Requires-Dist: onnx>=1.14.0; extra == 'onnx'
+Requires-Dist: onnxconverter-common>=1.14.0; extra == 'onnx'
 Requires-Dist: onnxruntime>=1.14.0; extra == 'onnx'
 Provides-Extra: tracking
 Requires-Dist: swanlab>=0.1.0; extra == 'tracking'
@@ -53,9 +60,11 @@ Requires-Dist: graphviz>=0.20; extra == 'visualization'
 Requires-Dist: torchview>=0.2.6; extra == 'visualization'
 Description-Content-Type: text/markdown
 
-
+<div align="center">
 
-
+
+
+# Torch-RecHub: 轻量、高效、易用的 PyTorch 推荐系统框架
 
 [](LICENSE)
 
@@ -71,21 +80,13 @@ Description-Content-Type: text/markdown
 
 [English](README_en.md) | 简体中文
 
-
+
 
-
+</div>
 
-
+**在线文档:** https://datawhalechina.github.io/torch-rechub/zh/
 
-
-
-| 特性 | Torch-RecHub | 其他框架 |
-| ------------- | --------------------------- | ---------- |
-| 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
-| 模型覆盖 | **30+** 主流模型 | 有限 |
-| 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
-| ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
-| 学习曲线 | 极低 | 陡峭 |
+**Torch-RecHub** —— **10 行代码实现工业级推荐系统**。30+ 主流模型开箱即用，支持一键 ONNX 部署，让你专注于业务而非工程。
 
 ## ✨ 特性
 
@@ -102,7 +103,6 @@ Description-Content-Type: text/markdown
 ## 📖 目录
 
 - [🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架](#-torch-rechub---轻量高效易用的-pytorch-推荐系统框架)
-- [🎯 为什么选择 Torch-RecHub?](#-为什么选择-torch-rechub)
 - [✨ 特性](#-特性)
 - [📖 目录](#-目录)
 - [🔧 安装](#-安装)
@@ -214,6 +214,8 @@ torch-rechub/ # 根目录
 
 本框架目前支持 **30+** 主流推荐模型：
 
+<details>
+
 ### 排序模型 (Ranking Models) - 13个
 
 | 模型 | 论文 | 简介 |
@@ -229,7 +231,11 @@ torch-rechub/ # 根目录
 | **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
 | **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
 | **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
-| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络
+| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
+
+</details>
+
+<details>
 
 ### 召回模型 (Matching Models) - 12个
 
@@ -246,6 +252,10 @@ torch-rechub/ # 根目录
 | **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
 | **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
 
+</details>
+
+<details>
+
 ### 多任务模型 (Multi-Task Models) - 5个
 
 | 模型 | 论文 | 简介 |
@@ -256,6 +266,10 @@ torch-rechub/ # 根目录
 | **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
 | **SharedBottom** | - | 经典多任务共享底层 |
 
+</details>
+
+<details>
+
 ### 生成式推荐 (Generative Recommendation) - 2个
 
 | 模型 | 论文 | 简介 |
@@ -263,6 +277,8 @@ torch-rechub/ # 根目录
 | **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | 层级序列转换单元，支撑 Meta 万亿参数推荐系统 |
 | **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐，融合 LLM 语义理解能力 |
 
+</details>
+
 ## 📊 支持的数据集
 
 框架内置了对以下常见数据集格式的支持或提供了处理脚本：
@@ -1,21 +1,22 @@
 torch_rechub/__init__.py,sha256=XUwV85oz-uIokuE9qj3nmbUQg3EY8dZcDMohlob3suw,245
+torch_rechub/types.py,sha256=oWPz577qJO0PuSvmJsbST5lAFzuTjxWS9J7fj_IntNw,97
 torch_rechub/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 torch_rechub/basic/activation.py,sha256=hIZDCe7cAgV3bX2UnvUrkO8pQs4iXxkQGD0J4GejbVg,1600
 torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G0,946
 torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
 torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
-torch_rechub/basic/layers.py,sha256=
-torch_rechub/basic/loss_func.py,sha256=
+torch_rechub/basic/layers.py,sha256=0qNeoIzgcSfmlVoQkyjT6yEnLklcKmQG44wBypAn2rY,39148
+torch_rechub/basic/loss_func.py,sha256=a-j1gan4eYUk5zstWwKeaPZ99eJkZPGWS82LNhT6Jbc,7756
 torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
 torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
 torch_rechub/basic/tracking.py,sha256=7-aoyKJxyqb8GobpjRjFsgPYWsBDOV44BYOC_vMoCto,6608
 torch_rechub/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 torch_rechub/data/convert.py,sha256=clGFEbDSDpdZBvscWatfjtuXMZUzgy1kiEAg4w_q7VM,2241
-torch_rechub/data/dataset.py,sha256=
+torch_rechub/data/dataset.py,sha256=2avvPcw2KQK5Xe6F_-kjGQMRiV957pGajU5klrs1NMA,3496
 torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 torch_rechub/models/generative/__init__.py,sha256=TsCdVIhOcalQwqKZKjEuNbHKyIjyclapKGNwYfFR7TM,135
 torch_rechub/models/generative/hllm.py,sha256=6Vrp5Bh0fTFHCn7C-3EqzOyc7UunOyEY9TzAKGHrW-8,9669
-torch_rechub/models/generative/hstu.py,sha256=
+torch_rechub/models/generative/hstu.py,sha256=MkyodycbrZo15jr0m6mY1XrqYYhwT9ez3g-oZ3TlF5Q,7588
 torch_rechub/models/matching/__init__.py,sha256=fjWOzJB8loPGy8rJMG-6G-NUISp7k3sD_1FdsKGw1as,471
 torch_rechub/models/matching/comirec.py,sha256=8KB5rg2FWlZaG73CBI7_J8-J-XpjTPnblwh5OsbtAbc,9439
 torch_rechub/models/matching/dssm.py,sha256=1Q0JYpt1h_7NWlLN5a_RbCoUSubZwpYTVEXccSn44eg,3003
@@ -48,21 +49,26 @@ torch_rechub/models/ranking/din.py,sha256=HsOCEErea3KwEiyWw4M_aX_LMC_-Sqs1C_zeRL
 torch_rechub/models/ranking/edcn.py,sha256=6f_S8I6Ir16kCIU54R4EfumWfUFOND5KDKUPHMgsVU0,4997
 torch_rechub/models/ranking/fibinet.py,sha256=fmEJ9WkO8Mn0RtK_8aRHlnQFh_jMBPO0zODoHZPWmDA,2234
 torch_rechub/models/ranking/widedeep.py,sha256=eciRvWRBHLlctabLLS5NB7k3MnqrWXCBdpflOU6jMB0,1636
+torch_rechub/serving/__init__.py,sha256=F2UfH7yqwqbt0aOqVqxoPk96fxIt9s6ppZ-98Hq3qkY,1456
+torch_rechub/serving/annoy.py,sha256=G3ilENVAfKKnQUHUc53QOwqequxkf2oMnU-0Yo-WbHQ,3863
+torch_rechub/serving/base.py,sha256=27dG_zPBbwV0ojj0lDmOeIg-WHOwV-pjqKLimVI5vmg,2989
+torch_rechub/serving/faiss.py,sha256=kroqICeIxfZg8hPZiWZXmFtUpQSj9JLheFxorzdV3aw,4479
+torch_rechub/serving/milvus.py,sha256=EnhD-zbtmp3KAS-lkZYFCQjXeKe7J2-LM3-iIUhLg0Y,6529
 torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
-torch_rechub/trainers/ctr_trainer.py,sha256=
-torch_rechub/trainers/match_trainer.py,sha256=
-torch_rechub/trainers/
-torch_rechub/trainers/
-torch_rechub/trainers/seq_trainer.py,sha256=pyY70kAjTWdKrnAYZynql1PPNtveYDLMB_1hbpCHa48,19217
+torch_rechub/trainers/ctr_trainer.py,sha256=6vU2_-HCY1MBHwmT8p68rkoYFjbdFZgZ3zTyHxPIcGs,14407
+torch_rechub/trainers/match_trainer.py,sha256=oASggXTvFd-93ltvt2uhB1TFPSYP_H-EGdA8Zurw64A,16648
+torch_rechub/trainers/mtl_trainer.py,sha256=J8ztmZN-4f2ELruN2lAGLlC1quo9Y-yH9Yu30MXBqJE,18562
+torch_rechub/trainers/seq_trainer.py,sha256=48s8YfY0PN5HETm0Dj09xDKrCT9S8wqykK4q1OtMTRo,20358
 torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-torch_rechub/utils/data.py,sha256=
-torch_rechub/utils/hstu_utils.py,sha256=
+torch_rechub/utils/data.py,sha256=Qt_HpwiU6n4wikJizRflAS5acr33YJN-t1Ar86U8UIQ,19715
+torch_rechub/utils/hstu_utils.py,sha256=QKX2V6dmbK6kwNEETSE0oEpbHz-FbIhB4PvbQC9Lx5w,5656
 torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
-torch_rechub/utils/model_utils.py,sha256=
+torch_rechub/utils/model_utils.py,sha256=f8dx9uVCN8kfwYSJm_Mg5jZ2_gNMItPzTyccpVf_zA4,8219
 torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
-torch_rechub/utils/onnx_export.py,sha256=
-torch_rechub/utils/
-torch_rechub
-torch_rechub-0.0.
-torch_rechub-0.0.
-torch_rechub-0.0.
+torch_rechub/utils/onnx_export.py,sha256=02-UI4C0ACccP4nP5moVn6tPr4SSFaKdym0aczJs_jI,10739
+torch_rechub/utils/quantization.py,sha256=ett0VpmQz6c14-zvRuoOwctQurmQFLfF7Dj565L7iqE,4847
+torch_rechub/utils/visualization.py,sha256=cfaq3_ZYcqxb4R7V_be-RebPAqKDedAJSwjYoUm55AU,9201
+torch_rechub-0.2.0.dist-info/METADATA,sha256=FGmR2swqnS6uViykJd4BFHyQ2d9itA42r4t0XXkPgq8,18098
+torch_rechub-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+torch_rechub-0.2.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
+torch_rechub-0.2.0.dist-info/RECORD,,
{torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL: File without changes
{torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE: File without changes