torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +228 -159
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/data/dataset.py +18 -31
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +12 -2
- torch_rechub/trainers/match_trainer.py +13 -2
- torch_rechub/trainers/mtl_trainer.py +12 -2
- torch_rechub/trainers/seq_trainer.py +34 -15
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +191 -145
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/METADATA +34 -18
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/RECORD +24 -18
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""FAISS-based vector index implementation for the retrieval stage."""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import typing as ty
|
|
5
|
+
|
|
6
|
+
import faiss
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from torch_rechub.types import FilePath
|
|
10
|
+
|
|
11
|
+
from .base import BaseBuilder, BaseIndexer
|
|
12
|
+
|
|
13
|
+
# Type for indexing methods.
|
|
14
|
+
_FaissIndexType = ty.Literal["Flat", "HNSW", "IVF"]
|
|
15
|
+
|
|
16
|
+
# Type for indexing metrics.
|
|
17
|
+
_FaissMetric = ty.Literal["IP", "L2"]
|
|
18
|
+
|
|
19
|
+
# Default indexing method.
|
|
20
|
+
_DEFAULT_FAISS_INDEX_TYPE: _FaissIndexType = "Flat"
|
|
21
|
+
|
|
22
|
+
# Default indexing metric.
|
|
23
|
+
_DEFAULT_FAISS_METRIC: _FaissMetric = "L2"
|
|
24
|
+
|
|
25
|
+
# Default number of clusters to build an IVF index.
|
|
26
|
+
_DEFAULT_N_LISTS = 100
|
|
27
|
+
|
|
28
|
+
# Default max number of neighbors to build an HNSW index.
|
|
29
|
+
_DEFAULT_M = 32
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FaissBuilder(BaseBuilder):
    """Implement ``BaseBuilder`` for FAISS vector index construction."""

    def __init__(
        self,
        index_type: _FaissIndexType = _DEFAULT_FAISS_INDEX_TYPE,
        metric: _FaissMetric = _DEFAULT_FAISS_METRIC,
        *,
        m: int = _DEFAULT_M,
        nlists: int = _DEFAULT_N_LISTS,
        efSearch: ty.Optional[int] = None,
        nprobe: ty.Optional[int] = None,
    ) -> None:
        """
        Initialize a FAISS builder.

        Parameters
        ----------
        index_type : ``"Flat"``, ``"HNSW"``, or ``"IVF"``, optional
            The index type. Default to ``"Flat"``.
        metric : ``"IP"`` or ``"L2"``, optional
            The indexing metric. Default to ``"L2"``.
        m : int, optional
            Max number of neighbors to build an HNSW index.
        nlists : int, optional
            Number of clusters to build an IVF index.
        efSearch : int or None, optional
            Number of candidate nodes during an HNSW search.
        nprobe : int or None, optional
            Number of clusters probed during an IVF search.
        """
        # Resolve the factory DSL string and metric constant once, up front.
        self._index_type_dsl = _build_index_type_dsl(index_type, m=m, nlists=nlists)
        self._metric = _resolve_metric_type(metric)

        # Search-time tuning knobs; applied only when the built index exposes them.
        self._efSearch = efSearch
        self._nprobe = nprobe

    @contextlib.contextmanager
    def from_embeddings(
        self,
        embeddings: torch.Tensor,
    ) -> ty.Generator["FaissIndexer", None, None]:
        """Adhere to ``BaseBuilder.from_embeddings``."""
        index: faiss.Index = faiss.index_factory(
            embeddings.shape[1],
            self._index_type_dsl,
            self._metric,
        )

        if isinstance(index, faiss.IndexHNSW) and self._efSearch is not None:
            index.hnsw.efSearch = self._efSearch

        if isinstance(index, faiss.IndexIVF) and self._nprobe is not None:
            index.nprobe = self._nprobe

        # FAISS's Python bindings expect contiguous float32 numpy arrays, not
        # torch tensors; convert once (moving off GPU if needed), matching the
        # conversion already done in ``FaissIndexer.query``.
        data = embeddings.detach().cpu().float().contiguous().numpy()
        index.train(data)
        index.add(data)

        yield FaissIndexer(index)

    @contextlib.contextmanager
    def from_index_file(
        self,
        index_file: FilePath,
    ) -> ty.Generator["FaissIndexer", None, None]:
        """Adhere to ``BaseBuilder.from_index_file``."""
        # ``faiss.read_index`` requires a plain string path.
        yield FaissIndexer(faiss.read_index(str(index_file)))
111
|
+
|
|
112
|
+
|
|
113
|
+
class FaissIndexer(BaseIndexer):
    """FAISS-based implementation of ``BaseIndexer``."""

    def __init__(self, index: faiss.Index) -> None:
        """Initialize a FAISS indexer wrapping an already-built index."""
        self._index = index

    def query(
        self,
        embeddings: torch.Tensor,
        top_k: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Adhere to ``BaseIndexer.query``."""
        queries = embeddings.cpu().numpy()
        # FAISS returns (distances, ids); the interface contract is (ids, distances).
        distances, neighbor_ids = self._index.search(queries, top_k)
        return torch.from_numpy(neighbor_ids), torch.from_numpy(distances)

    def save(self, file_path: FilePath) -> None:
        """Adhere to ``BaseIndexer.save``."""
        faiss.write_index(self._index, str(file_path))
|
|
134
|
+
|
|
135
|
+
# helper functions
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _build_index_type_dsl(index_type: _FaissIndexType, *, m: int, nlists: int) -> str:
|
|
139
|
+
"""Build the index_type DSL passed to ``faiss.index_factory``."""
|
|
140
|
+
if index_type == "HNSW":
|
|
141
|
+
return f"{index_type}{m},Flat"
|
|
142
|
+
|
|
143
|
+
if index_type == "IVF":
|
|
144
|
+
return f"{index_type}{nlists},Flat"
|
|
145
|
+
|
|
146
|
+
return "Flat"
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _resolve_metric_type(metric: _FaissMetric) -> int:
    """Resolve the metric type from a string literal to an integer."""
    # The two-valued ``_FaissMetric`` literal maps "L2" to the L2 metric and
    # everything else (i.e. "IP") to inner product.
    resolved = faiss.METRIC_L2 if metric == "L2" else faiss.METRIC_INNER_PRODUCT
    return ty.cast(int, resolved)
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Milvus-based vector index implementation for the retrieval stage."""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import typing as ty
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pymilvus as milvus
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from torch_rechub.types import FilePath
|
|
12
|
+
|
|
13
|
+
from .base import BaseBuilder, BaseIndexer
|
|
14
|
+
|
|
15
|
+
# Type for indexing methods.
|
|
16
|
+
_MilvusIndexType = ty.Literal["FLAT", "HNSW", "IVF_FLAT"]
|
|
17
|
+
|
|
18
|
+
# Type for indexing metrics.
|
|
19
|
+
_MilvusMetric = ty.Literal["COSINE", "IP", "L2"]
|
|
20
|
+
|
|
21
|
+
# Default indexing method.
|
|
22
|
+
_DEFAULT_MILVUS_INDEX_TYPE: _MilvusIndexType = "FLAT"
|
|
23
|
+
|
|
24
|
+
# Default indexing metric.
|
|
25
|
+
_DEFAULT_MILVUS_METRIC: _MilvusMetric = "COSINE"
|
|
26
|
+
|
|
27
|
+
# Default number of clusters to build an IVF index.
|
|
28
|
+
_DEFAULT_N_LIST = 128
|
|
29
|
+
|
|
30
|
+
# Default max number of neighbors to build an HNSW index.
|
|
31
|
+
_DEFAULT_M = 30
|
|
32
|
+
|
|
33
|
+
# Default name of Milvus database connection.
|
|
34
|
+
_DEFAULT_NAME = "rechub"
|
|
35
|
+
|
|
36
|
+
# Default host of Milvus instance.
|
|
37
|
+
_DEFAULT_HOST = "localhost"
|
|
38
|
+
|
|
39
|
+
# Default port of Milvus instance.
|
|
40
|
+
_DEFAULT_PORT = 19530
|
|
41
|
+
|
|
42
|
+
# Name of the embedding column in the Milvus database table.
|
|
43
|
+
_EMBEDDING_COLUMN = "embedding"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MilvusBuilder(BaseBuilder):
    """Implement ``BaseBuilder`` for Milvus vector index construction."""

    def __init__(
        self,
        d: int,
        index_type: _MilvusIndexType = _DEFAULT_MILVUS_INDEX_TYPE,
        metric: _MilvusMetric = _DEFAULT_MILVUS_METRIC,
        *,
        m: int = _DEFAULT_M,
        nlist: int = _DEFAULT_N_LIST,
        ef: ty.Optional[int] = None,
        nprobe: ty.Optional[int] = None,
        name: str = _DEFAULT_NAME,
        host: str = _DEFAULT_HOST,
        port: int = _DEFAULT_PORT,
    ) -> None:
        """
        Initialize a Milvus builder.

        Parameters
        ----------
        d : int
            The dimension of embeddings.
        index_type : ``"FLAT"``, ``"HNSW"``, or ``"IVF_FLAT"``, optional
            The index type. Default to ``"FLAT"``.
        metric : ``"COSINE"``, ``"IP"``, or ``"L2"``, optional
            The indexing metric. Default to ``"COSINE"``.
        m : int, optional
            Max number of neighbors to build an HNSW index.
        nlist : int, optional
            Number of clusters to build an IVF index.
        ef : int or None, optional
            Number of candidate nodes during an HNSW search.
        nprobe : int or None, optional
            Number of clusters probed during an IVF search.
        name : str, optional
            The name of the connection. Each name corresponds to one connection.
        host : str, optional
            The host of the Milvus instance. Defaults to "localhost".
        port : int, optional
            The port of the Milvus instance. Defaults to 19530.
        """
        self._d = d

        # Connection parameters.
        self._name = name
        self._host = host
        self._port = port

        # Build-time vs. query-time parameters depend on the index type.
        bparams: dict[str, ty.Any] = {}
        qparams: dict[str, ty.Any] = {}

        if index_type == "HNSW":
            bparams.update(M=m)
            if ef is not None:
                qparams.update(ef=ef)

        if index_type == "IVF_FLAT":
            bparams.update(nlist=nlist)
            if nprobe is not None:
                qparams.update(nprobe=nprobe)

        self._build_params = dict(
            index_type=index_type,
            metric_type=metric,
            params=bparams,
        )

        self._query_params = dict(
            metric_type=metric,
            params=qparams,
        )

    @contextlib.contextmanager
    def from_embeddings(
        self,
        embeddings: torch.Tensor,
    ) -> ty.Generator["MilvusIndexer", None, None]:
        """Adhere to ``BaseBuilder.from_embeddings``."""
        milvus.connections.connect(self._name, host=self._host, port=self._port)

        # Disconnect even if collection construction fails, so a failed build
        # does not leak the connection; drop the collection before that.
        try:
            collection = self._build_collection(embeddings)
            try:
                yield MilvusIndexer(collection, self._query_params)
            finally:
                collection.drop()
        finally:
            milvus.connections.disconnect(self._name)

    @contextlib.contextmanager
    def from_index_file(
        self,
        index_file: FilePath,
    ) -> ty.Generator["MilvusIndexer", None, None]:
        """Adhere to ``BaseBuilder.from_index_file``."""
        raise NotImplementedError("Milvus does not support index files!")

    def _build_collection(self, embeddings: torch.Tensor) -> milvus.Collection:
        """Build a Milvus collection with the current connection."""
        fields = [
            milvus.FieldSchema(
                name="id",
                dtype=milvus.DataType.INT64,
                is_primary=True,
            ),
            milvus.FieldSchema(
                name=_EMBEDDING_COLUMN,
                dtype=milvus.DataType.FLOAT_VECTOR,
                dim=self._d,
            ),
        ]

        # A fresh UUID suffix keeps concurrently-built collections distinct.
        collection = milvus.Collection(
            name=f"{self._name}_{uuid.uuid4().hex}",
            schema=milvus.CollectionSchema(fields=fields),
            using=self._name,
        )

        n, _ = embeddings.shape
        # Row ids are the input positions, so search results map straight back
        # to the order of ``embeddings``.
        collection.insert([np.arange(n, dtype=np.int64), embeddings.cpu().numpy()])
        collection.create_index(_EMBEDDING_COLUMN, index_params=self._build_params)
        collection.load()

        return collection
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class MilvusIndexer(BaseIndexer):
    """Milvus-based implementation of ``BaseIndexer``."""

    def __init__(
        self,
        collection: milvus.Collection,
        query_params: dict[str, ty.Any],
    ) -> None:
        """Initialize a Milvus indexer."""
        self._collection = collection
        self._query_params = query_params

    def query(
        self,
        embeddings: torch.Tensor,
        top_k: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Adhere to ``BaseIndexer.query``."""
        results = self._collection.search(
            data=embeddings.cpu().numpy(),
            anns_field=_EMBEDDING_COLUMN,
            param=self._query_params,
            limit=top_k,
        )

        n, _ = embeddings.shape
        # Milvus may return fewer than ``top_k`` hits per query (e.g. the
        # collection holds fewer entities); pad missing neighbors with id -1
        # and distance 0 instead of crashing on a length-mismatched assignment.
        nn_ids = np.full((n, top_k), -1, dtype=np.int64)
        nn_distances = np.zeros((n, top_k), dtype=np.float32)

        for i, result in enumerate(results):
            k = len(result.ids)
            nn_ids[i, :k] = result.ids
            nn_distances[i, :k] = result.distances

        return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)

    def save(self, file_path: FilePath) -> None:
        """Adhere to ``BaseIndexer.save``."""
        raise NotImplementedError("Milvus does not support index files!")
|
|
@@ -185,7 +185,7 @@ class CTRTrainer(object):
|
|
|
185
185
|
predicts.extend(y_pred.tolist())
|
|
186
186
|
return predicts
|
|
187
187
|
|
|
188
|
-
def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
188
|
+
def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
|
|
189
189
|
"""Export the trained model to ONNX format.
|
|
190
190
|
|
|
191
191
|
This method exports the ranking model (e.g., DeepFM, WideDeep, DCN) to ONNX format
|
|
@@ -202,6 +202,7 @@ class CTRTrainer(object):
|
|
|
202
202
|
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
203
203
|
If None, defaults to 'cpu' for maximum compatibility.
|
|
204
204
|
verbose (bool): Print export details (default: False).
|
|
205
|
+
onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
|
|
205
206
|
|
|
206
207
|
Returns:
|
|
207
208
|
bool: True if export succeeded, False otherwise.
|
|
@@ -227,7 +228,16 @@ class CTRTrainer(object):
|
|
|
227
228
|
export_device = device if device is not None else 'cpu'
|
|
228
229
|
|
|
229
230
|
exporter = ONNXExporter(model, device=export_device)
|
|
230
|
-
return exporter.export(
|
|
231
|
+
return exporter.export(
|
|
232
|
+
output_path=output_path,
|
|
233
|
+
dummy_input=dummy_input,
|
|
234
|
+
batch_size=batch_size,
|
|
235
|
+
seq_length=seq_length,
|
|
236
|
+
opset_version=opset_version,
|
|
237
|
+
dynamic_batch=dynamic_batch,
|
|
238
|
+
verbose=verbose,
|
|
239
|
+
onnx_export_kwargs=onnx_export_kwargs,
|
|
240
|
+
)
|
|
231
241
|
|
|
232
242
|
def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
|
|
233
243
|
"""Visualize the model's computation graph.
|
|
@@ -215,7 +215,7 @@ class MatchTrainer(object):
|
|
|
215
215
|
predicts.append(y_pred.data)
|
|
216
216
|
return torch.cat(predicts, dim=0)
|
|
217
217
|
|
|
218
|
-
def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
218
|
+
def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
|
|
219
219
|
"""Export the trained matching model to ONNX format.
|
|
220
220
|
|
|
221
221
|
This method exports matching/retrieval models (e.g., DSSM, YoutubeDNN, MIND)
|
|
@@ -237,6 +237,7 @@ class MatchTrainer(object):
|
|
|
237
237
|
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
238
238
|
If None, defaults to 'cpu' for maximum compatibility.
|
|
239
239
|
verbose (bool): Print export details (default: False).
|
|
240
|
+
onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
|
|
240
241
|
|
|
241
242
|
Returns:
|
|
242
243
|
bool: True if export succeeded, False otherwise.
|
|
@@ -270,7 +271,17 @@ class MatchTrainer(object):
|
|
|
270
271
|
|
|
271
272
|
try:
|
|
272
273
|
exporter = ONNXExporter(model, device=export_device)
|
|
273
|
-
return exporter.export(
|
|
274
|
+
return exporter.export(
|
|
275
|
+
output_path=output_path,
|
|
276
|
+
mode=mode,
|
|
277
|
+
dummy_input=dummy_input,
|
|
278
|
+
batch_size=batch_size,
|
|
279
|
+
seq_length=seq_length,
|
|
280
|
+
opset_version=opset_version,
|
|
281
|
+
dynamic_batch=dynamic_batch,
|
|
282
|
+
verbose=verbose,
|
|
283
|
+
onnx_export_kwargs=onnx_export_kwargs,
|
|
284
|
+
)
|
|
274
285
|
finally:
|
|
275
286
|
# Restore original mode
|
|
276
287
|
if hasattr(model, 'mode'):
|
|
@@ -261,7 +261,7 @@ class MTLTrainer(object):
|
|
|
261
261
|
predicts.extend(y_preds.tolist())
|
|
262
262
|
return predicts
|
|
263
263
|
|
|
264
|
-
def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
264
|
+
def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
|
|
265
265
|
"""Export the trained multi-task model to ONNX format.
|
|
266
266
|
|
|
267
267
|
This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
|
|
@@ -283,6 +283,7 @@ class MTLTrainer(object):
|
|
|
283
283
|
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
284
284
|
If None, defaults to 'cpu' for maximum compatibility.
|
|
285
285
|
verbose (bool): Print export details (default: False).
|
|
286
|
+
onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
|
|
286
287
|
|
|
287
288
|
Returns:
|
|
288
289
|
bool: True if export succeeded, False otherwise.
|
|
@@ -304,7 +305,16 @@ class MTLTrainer(object):
|
|
|
304
305
|
export_device = device if device is not None else 'cpu'
|
|
305
306
|
|
|
306
307
|
exporter = ONNXExporter(model, device=export_device)
|
|
307
|
-
return exporter.export(
|
|
308
|
+
return exporter.export(
|
|
309
|
+
output_path=output_path,
|
|
310
|
+
dummy_input=dummy_input,
|
|
311
|
+
batch_size=batch_size,
|
|
312
|
+
seq_length=seq_length,
|
|
313
|
+
opset_version=opset_version,
|
|
314
|
+
dynamic_batch=dynamic_batch,
|
|
315
|
+
verbose=verbose,
|
|
316
|
+
onnx_export_kwargs=onnx_export_kwargs,
|
|
317
|
+
)
|
|
308
318
|
|
|
309
319
|
def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
|
|
310
320
|
"""Visualize the model's computation graph.
|
|
@@ -255,7 +255,7 @@ class SeqTrainer(object):
|
|
|
255
255
|
|
|
256
256
|
return avg_loss, accuracy
|
|
257
257
|
|
|
258
|
-
def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
258
|
+
def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
|
|
259
259
|
"""Export the trained sequence generation model to ONNX format.
|
|
260
260
|
|
|
261
261
|
This method exports sequence generation models (e.g., HSTU) to ONNX format.
|
|
@@ -273,6 +273,7 @@ class SeqTrainer(object):
|
|
|
273
273
|
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
274
274
|
If None, defaults to 'cpu' for maximum compatibility.
|
|
275
275
|
verbose (bool): Print export details (default: False).
|
|
276
|
+
onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
|
|
276
277
|
|
|
277
278
|
Returns:
|
|
278
279
|
bool: True if export succeeded, False otherwise.
|
|
@@ -321,20 +322,38 @@ class SeqTrainer(object):
|
|
|
321
322
|
|
|
322
323
|
try:
|
|
323
324
|
with torch.no_grad():
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
325
|
+
import inspect
|
|
326
|
+
|
|
327
|
+
export_kwargs = {
|
|
328
|
+
"f": output_path,
|
|
329
|
+
"input_names": ["seq_tokens",
|
|
330
|
+
"seq_time_diffs"],
|
|
331
|
+
"output_names": ["output"],
|
|
332
|
+
"dynamic_axes": dynamic_axes,
|
|
333
|
+
"opset_version": opset_version,
|
|
334
|
+
"do_constant_folding": True,
|
|
335
|
+
"verbose": verbose,
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
if onnx_export_kwargs:
|
|
339
|
+
overlap = set(export_kwargs.keys()) & set(onnx_export_kwargs.keys())
|
|
340
|
+
overlap.discard("dynamo")
|
|
341
|
+
if overlap:
|
|
342
|
+
raise ValueError("onnx_export_kwargs contains keys that overlap with explicit args: "
|
|
343
|
+
f"{sorted(overlap)}. Please set them via export_onnx() parameters instead.")
|
|
344
|
+
export_kwargs.update(onnx_export_kwargs)
|
|
345
|
+
|
|
346
|
+
# Auto-pick exporter:
|
|
347
|
+
# - dynamic_axes present => prefer legacy exporter (dynamo=False) for dynamic batch/seq
|
|
348
|
+
# - otherwise prefer dynamo exporter (dynamo=True) on newer torch
|
|
349
|
+
sig = inspect.signature(torch.onnx.export)
|
|
350
|
+
if "dynamo" in sig.parameters:
|
|
351
|
+
if "dynamo" not in export_kwargs:
|
|
352
|
+
export_kwargs["dynamo"] = False if dynamic_axes is not None else True
|
|
353
|
+
else:
|
|
354
|
+
export_kwargs.pop("dynamo", None)
|
|
355
|
+
|
|
356
|
+
torch.onnx.export(model, (dummy_seq_tokens, dummy_seq_time_diffs), **export_kwargs)
|
|
338
357
|
|
|
339
358
|
if verbose:
|
|
340
359
|
print(f"Successfully exported ONNX model to: {output_path}")
|