torch-rechub 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +213 -150
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +107 -0
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +52 -3
- torch_rechub/trainers/match_trainer.py +52 -3
- torch_rechub/trainers/mtl_trainer.py +61 -3
- torch_rechub/trainers/seq_trainer.py +93 -17
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +167 -137
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/METADATA +20 -5
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/RECORD +27 -17
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -39,6 +39,7 @@ class MatchTrainer(object):
|
|
|
39
39
|
device="cpu",
|
|
40
40
|
gpus=None,
|
|
41
41
|
model_path="./",
|
|
42
|
+
model_logger=None,
|
|
42
43
|
):
|
|
43
44
|
self.model = model # for uniform weights save method in one gpu or multi gpu
|
|
44
45
|
if gpus is None:
|
|
@@ -73,10 +74,13 @@ class MatchTrainer(object):
|
|
|
73
74
|
self.model_path = model_path
|
|
74
75
|
# Initialize regularization loss
|
|
75
76
|
self.reg_loss_fn = RegularizationLoss(**regularization_params)
|
|
77
|
+
self.model_logger = model_logger
|
|
76
78
|
|
|
77
79
|
def train_one_epoch(self, data_loader, log_interval=10):
|
|
78
80
|
self.model.train()
|
|
79
81
|
total_loss = 0
|
|
82
|
+
epoch_loss = 0
|
|
83
|
+
batch_count = 0
|
|
80
84
|
tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
|
|
81
85
|
for i, (x_dict, y) in enumerate(tk0):
|
|
82
86
|
x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
|
|
@@ -114,14 +118,26 @@ class MatchTrainer(object):
|
|
|
114
118
|
loss.backward()
|
|
115
119
|
self.optimizer.step()
|
|
116
120
|
total_loss += loss.item()
|
|
121
|
+
epoch_loss += loss.item()
|
|
122
|
+
batch_count += 1
|
|
117
123
|
if (i + 1) % log_interval == 0:
|
|
118
124
|
tk0.set_postfix(loss=total_loss / log_interval)
|
|
119
125
|
total_loss = 0
|
|
120
126
|
|
|
127
|
+
# Return average epoch loss
|
|
128
|
+
return epoch_loss / batch_count if batch_count > 0 else 0
|
|
129
|
+
|
|
121
130
|
def fit(self, train_dataloader, val_dataloader=None):
|
|
131
|
+
for logger in self._iter_loggers():
|
|
132
|
+
logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.mode})
|
|
133
|
+
|
|
122
134
|
for epoch_i in range(self.n_epoch):
|
|
123
135
|
print('epoch:', epoch_i)
|
|
124
|
-
self.train_one_epoch(train_dataloader)
|
|
136
|
+
train_loss = self.train_one_epoch(train_dataloader)
|
|
137
|
+
|
|
138
|
+
for logger in self._iter_loggers():
|
|
139
|
+
logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
|
|
140
|
+
|
|
125
141
|
if self.scheduler is not None:
|
|
126
142
|
if epoch_i % self.scheduler.step_size == 0:
|
|
127
143
|
print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
|
|
@@ -130,12 +146,34 @@ class MatchTrainer(object):
|
|
|
130
146
|
if val_dataloader:
|
|
131
147
|
auc = self.evaluate(self.model, val_dataloader)
|
|
132
148
|
print('epoch:', epoch_i, 'validation: auc:', auc)
|
|
149
|
+
|
|
150
|
+
for logger in self._iter_loggers():
|
|
151
|
+
logger.log_metrics({'val/auc': auc}, step=epoch_i)
|
|
152
|
+
|
|
133
153
|
if self.early_stopper.stop_training(auc, self.model.state_dict()):
|
|
134
154
|
print(f'validation: best auc: {self.early_stopper.best_auc}')
|
|
135
155
|
self.model.load_state_dict(self.early_stopper.best_weights)
|
|
136
156
|
break
|
|
157
|
+
|
|
137
158
|
torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best auc model
|
|
138
159
|
|
|
160
|
+
for logger in self._iter_loggers():
|
|
161
|
+
logger.finish()
|
|
162
|
+
|
|
163
|
+
def _iter_loggers(self):
|
|
164
|
+
"""Return logger instances as a list.
|
|
165
|
+
|
|
166
|
+
Returns
|
|
167
|
+
-------
|
|
168
|
+
list
|
|
169
|
+
Active logger instances. Empty when ``model_logger`` is ``None``.
|
|
170
|
+
"""
|
|
171
|
+
if self.model_logger is None:
|
|
172
|
+
return []
|
|
173
|
+
if isinstance(self.model_logger, (list, tuple)):
|
|
174
|
+
return list(self.model_logger)
|
|
175
|
+
return [self.model_logger]
|
|
176
|
+
|
|
139
177
|
def evaluate(self, model, data_loader):
|
|
140
178
|
model.eval()
|
|
141
179
|
targets, predicts = list(), list()
|
|
@@ -177,7 +215,7 @@ class MatchTrainer(object):
|
|
|
177
215
|
predicts.append(y_pred.data)
|
|
178
216
|
return torch.cat(predicts, dim=0)
|
|
179
217
|
|
|
180
|
-
def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
218
|
+
def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
|
|
181
219
|
"""Export the trained matching model to ONNX format.
|
|
182
220
|
|
|
183
221
|
This method exports matching/retrieval models (e.g., DSSM, YoutubeDNN, MIND)
|
|
@@ -199,6 +237,7 @@ class MatchTrainer(object):
|
|
|
199
237
|
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
200
238
|
If None, defaults to 'cpu' for maximum compatibility.
|
|
201
239
|
verbose (bool): Print export details (default: False).
|
|
240
|
+
onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
|
|
202
241
|
|
|
203
242
|
Returns:
|
|
204
243
|
bool: True if export succeeded, False otherwise.
|
|
@@ -232,7 +271,17 @@ class MatchTrainer(object):
|
|
|
232
271
|
|
|
233
272
|
try:
|
|
234
273
|
exporter = ONNXExporter(model, device=export_device)
|
|
235
|
-
return exporter.export(
|
|
274
|
+
return exporter.export(
|
|
275
|
+
output_path=output_path,
|
|
276
|
+
mode=mode,
|
|
277
|
+
dummy_input=dummy_input,
|
|
278
|
+
batch_size=batch_size,
|
|
279
|
+
seq_length=seq_length,
|
|
280
|
+
opset_version=opset_version,
|
|
281
|
+
dynamic_batch=dynamic_batch,
|
|
282
|
+
verbose=verbose,
|
|
283
|
+
onnx_export_kwargs=onnx_export_kwargs,
|
|
284
|
+
)
|
|
236
285
|
finally:
|
|
237
286
|
# Restore original mode
|
|
238
287
|
if hasattr(model, 'mode'):
|
|
@@ -47,6 +47,7 @@ class MTLTrainer(object):
|
|
|
47
47
|
device="cpu",
|
|
48
48
|
gpus=None,
|
|
49
49
|
model_path="./",
|
|
50
|
+
model_logger=None,
|
|
50
51
|
):
|
|
51
52
|
self.model = model
|
|
52
53
|
if gpus is None:
|
|
@@ -104,6 +105,7 @@ class MTLTrainer(object):
|
|
|
104
105
|
self.model_path = model_path
|
|
105
106
|
# Initialize regularization loss
|
|
106
107
|
self.reg_loss_fn = RegularizationLoss(**regularization_params)
|
|
108
|
+
self.model_logger = model_logger
|
|
107
109
|
|
|
108
110
|
def train_one_epoch(self, data_loader):
|
|
109
111
|
self.model.train()
|
|
@@ -163,21 +165,42 @@ class MTLTrainer(object):
|
|
|
163
165
|
def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
|
|
164
166
|
total_log = []
|
|
165
167
|
|
|
168
|
+
# Log hyperparameters once
|
|
169
|
+
for logger in self._iter_loggers():
|
|
170
|
+
logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self._current_lr(), 'adaptive_method': self.adaptive_method})
|
|
171
|
+
|
|
166
172
|
for epoch_i in range(self.n_epoch):
|
|
167
173
|
_log_per_epoch = self.train_one_epoch(train_dataloader)
|
|
168
174
|
|
|
175
|
+
# Collect metrics
|
|
176
|
+
logs = {f'train/task_{task_id}_loss': loss_val for task_id, loss_val in enumerate(_log_per_epoch)}
|
|
177
|
+
lr_value = self._current_lr()
|
|
178
|
+
if lr_value is not None:
|
|
179
|
+
logs['learning_rate'] = lr_value
|
|
180
|
+
|
|
169
181
|
if self.scheduler is not None:
|
|
170
182
|
if epoch_i % self.scheduler.step_size == 0:
|
|
171
183
|
print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
|
|
172
184
|
self.scheduler.step() # update lr in epoch level by scheduler
|
|
185
|
+
|
|
173
186
|
scores = self.evaluate(self.model, val_dataloader)
|
|
174
187
|
print('epoch:', epoch_i, 'validation scores: ', scores)
|
|
175
188
|
|
|
176
|
-
for score in scores:
|
|
189
|
+
for task_id, score in enumerate(scores):
|
|
190
|
+
logs[f'val/task_{task_id}_score'] = score
|
|
177
191
|
_log_per_epoch.append(score)
|
|
192
|
+
logs['auc'] = scores[self.earlystop_taskid]
|
|
193
|
+
|
|
194
|
+
if self.loss_weight:
|
|
195
|
+
for task_id, weight in enumerate(self.loss_weight):
|
|
196
|
+
logs[f'loss_weight/task_{task_id}'] = weight.item()
|
|
178
197
|
|
|
179
198
|
total_log.append(_log_per_epoch)
|
|
180
199
|
|
|
200
|
+
# Log metrics once per epoch
|
|
201
|
+
for logger in self._iter_loggers():
|
|
202
|
+
logger.log_metrics(logs, step=epoch_i)
|
|
203
|
+
|
|
181
204
|
if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
|
|
182
205
|
print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
|
|
183
206
|
self.model.load_state_dict(self.early_stopper.best_weights)
|
|
@@ -185,8 +208,33 @@ class MTLTrainer(object):
|
|
|
185
208
|
|
|
186
209
|
torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed))) # save best auc model
|
|
187
210
|
|
|
211
|
+
for logger in self._iter_loggers():
|
|
212
|
+
logger.finish()
|
|
213
|
+
|
|
188
214
|
return total_log
|
|
189
215
|
|
|
216
|
+
def _iter_loggers(self):
|
|
217
|
+
"""Return logger instances as a list.
|
|
218
|
+
|
|
219
|
+
Returns
|
|
220
|
+
-------
|
|
221
|
+
list
|
|
222
|
+
Active logger instances. Empty when ``model_logger`` is ``None``.
|
|
223
|
+
"""
|
|
224
|
+
if self.model_logger is None:
|
|
225
|
+
return []
|
|
226
|
+
if isinstance(self.model_logger, (list, tuple)):
|
|
227
|
+
return list(self.model_logger)
|
|
228
|
+
return [self.model_logger]
|
|
229
|
+
|
|
230
|
+
def _current_lr(self):
|
|
231
|
+
"""Fetch current learning rate regardless of adaptive method."""
|
|
232
|
+
if self.adaptive_method == "metabalance":
|
|
233
|
+
return self.share_optimizer.param_groups[0]['lr'] if hasattr(self, 'share_optimizer') else None
|
|
234
|
+
if hasattr(self, 'optimizer'):
|
|
235
|
+
return self.optimizer.param_groups[0]['lr']
|
|
236
|
+
return None
|
|
237
|
+
|
|
190
238
|
def evaluate(self, model, data_loader):
|
|
191
239
|
model.eval()
|
|
192
240
|
targets, predicts = list(), list()
|
|
@@ -213,7 +261,7 @@ class MTLTrainer(object):
|
|
|
213
261
|
predicts.extend(y_preds.tolist())
|
|
214
262
|
return predicts
|
|
215
263
|
|
|
216
|
-
def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
264
|
+
def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
|
|
217
265
|
"""Export the trained multi-task model to ONNX format.
|
|
218
266
|
|
|
219
267
|
This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
|
|
@@ -235,6 +283,7 @@ class MTLTrainer(object):
|
|
|
235
283
|
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
236
284
|
If None, defaults to 'cpu' for maximum compatibility.
|
|
237
285
|
verbose (bool): Print export details (default: False).
|
|
286
|
+
onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
|
|
238
287
|
|
|
239
288
|
Returns:
|
|
240
289
|
bool: True if export succeeded, False otherwise.
|
|
@@ -256,7 +305,16 @@ class MTLTrainer(object):
|
|
|
256
305
|
export_device = device if device is not None else 'cpu'
|
|
257
306
|
|
|
258
307
|
exporter = ONNXExporter(model, device=export_device)
|
|
259
|
-
return exporter.export(
|
|
308
|
+
return exporter.export(
|
|
309
|
+
output_path=output_path,
|
|
310
|
+
dummy_input=dummy_input,
|
|
311
|
+
batch_size=batch_size,
|
|
312
|
+
seq_length=seq_length,
|
|
313
|
+
opset_version=opset_version,
|
|
314
|
+
dynamic_batch=dynamic_batch,
|
|
315
|
+
verbose=verbose,
|
|
316
|
+
onnx_export_kwargs=onnx_export_kwargs,
|
|
317
|
+
)
|
|
260
318
|
|
|
261
319
|
def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
|
|
262
320
|
"""Visualize the model's computation graph.
|
|
@@ -46,7 +46,22 @@ class SeqTrainer(object):
|
|
|
46
46
|
... )
|
|
47
47
|
"""
|
|
48
48
|
|
|
49
|
-
def __init__(
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
model,
|
|
52
|
+
optimizer_fn=torch.optim.Adam,
|
|
53
|
+
optimizer_params=None,
|
|
54
|
+
scheduler_fn=None,
|
|
55
|
+
scheduler_params=None,
|
|
56
|
+
n_epoch=10,
|
|
57
|
+
earlystop_patience=10,
|
|
58
|
+
device='cpu',
|
|
59
|
+
gpus=None,
|
|
60
|
+
model_path='./',
|
|
61
|
+
loss_type='cross_entropy',
|
|
62
|
+
loss_params=None,
|
|
63
|
+
model_logger=None
|
|
64
|
+
):
|
|
50
65
|
self.model = model # for uniform weights save method in one gpu or multi gpu
|
|
51
66
|
if gpus is None:
|
|
52
67
|
gpus = []
|
|
@@ -74,9 +89,11 @@ class SeqTrainer(object):
|
|
|
74
89
|
loss_params = {"ignore_index": 0}
|
|
75
90
|
self.loss_fn = nn.CrossEntropyLoss(**loss_params)
|
|
76
91
|
|
|
92
|
+
self.loss_type = loss_type
|
|
77
93
|
self.n_epoch = n_epoch
|
|
78
94
|
self.early_stopper = EarlyStopper(patience=earlystop_patience)
|
|
79
95
|
self.model_path = model_path
|
|
96
|
+
self.model_logger = model_logger
|
|
80
97
|
|
|
81
98
|
def fit(self, train_dataloader, val_dataloader=None):
|
|
82
99
|
"""训练模型.
|
|
@@ -90,10 +107,18 @@ class SeqTrainer(object):
|
|
|
90
107
|
"""
|
|
91
108
|
history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
|
|
92
109
|
|
|
110
|
+
for logger in self._iter_loggers():
|
|
111
|
+
logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_type': self.loss_type})
|
|
112
|
+
|
|
93
113
|
for epoch_i in range(self.n_epoch):
|
|
94
114
|
print('epoch:', epoch_i)
|
|
95
115
|
# 训练阶段
|
|
96
|
-
self.train_one_epoch(train_dataloader)
|
|
116
|
+
train_loss = self.train_one_epoch(train_dataloader)
|
|
117
|
+
history['train_loss'].append(train_loss)
|
|
118
|
+
|
|
119
|
+
# Collect metrics
|
|
120
|
+
logs = {'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}
|
|
121
|
+
|
|
97
122
|
if self.scheduler is not None:
|
|
98
123
|
if epoch_i % self.scheduler.step_size == 0:
|
|
99
124
|
print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
|
|
@@ -105,6 +130,10 @@ class SeqTrainer(object):
|
|
|
105
130
|
history['val_loss'].append(val_loss)
|
|
106
131
|
history['val_accuracy'].append(val_accuracy)
|
|
107
132
|
|
|
133
|
+
logs['val/loss'] = val_loss
|
|
134
|
+
logs['val/accuracy'] = val_accuracy
|
|
135
|
+
logs['auc'] = val_accuracy # For compatibility with EarlyStopper
|
|
136
|
+
|
|
108
137
|
print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")
|
|
109
138
|
|
|
110
139
|
# 早停
|
|
@@ -113,9 +142,30 @@ class SeqTrainer(object):
|
|
|
113
142
|
self.model.load_state_dict(self.early_stopper.best_weights)
|
|
114
143
|
break
|
|
115
144
|
|
|
145
|
+
for logger in self._iter_loggers():
|
|
146
|
+
logger.log_metrics(logs, step=epoch_i)
|
|
147
|
+
|
|
116
148
|
torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best model
|
|
149
|
+
|
|
150
|
+
for logger in self._iter_loggers():
|
|
151
|
+
logger.finish()
|
|
152
|
+
|
|
117
153
|
return history
|
|
118
154
|
|
|
155
|
+
def _iter_loggers(self):
|
|
156
|
+
"""Return logger instances as a list.
|
|
157
|
+
|
|
158
|
+
Returns
|
|
159
|
+
-------
|
|
160
|
+
list
|
|
161
|
+
Active logger instances. Empty when ``model_logger`` is ``None``.
|
|
162
|
+
"""
|
|
163
|
+
if self.model_logger is None:
|
|
164
|
+
return []
|
|
165
|
+
if isinstance(self.model_logger, (list, tuple)):
|
|
166
|
+
return list(self.model_logger)
|
|
167
|
+
return [self.model_logger]
|
|
168
|
+
|
|
119
169
|
def train_one_epoch(self, data_loader, log_interval=10):
|
|
120
170
|
"""Train the model for a single epoch.
|
|
121
171
|
|
|
@@ -128,6 +178,8 @@ class SeqTrainer(object):
|
|
|
128
178
|
"""
|
|
129
179
|
self.model.train()
|
|
130
180
|
total_loss = 0
|
|
181
|
+
epoch_loss = 0
|
|
182
|
+
batch_count = 0
|
|
131
183
|
tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
|
|
132
184
|
for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
|
|
133
185
|
# Move tensors to the target device
|
|
@@ -152,10 +204,15 @@ class SeqTrainer(object):
|
|
|
152
204
|
self.optimizer.step()
|
|
153
205
|
|
|
154
206
|
total_loss += loss.item()
|
|
207
|
+
epoch_loss += loss.item()
|
|
208
|
+
batch_count += 1
|
|
155
209
|
if (i + 1) % log_interval == 0:
|
|
156
210
|
tk0.set_postfix(loss=total_loss / log_interval)
|
|
157
211
|
total_loss = 0
|
|
158
212
|
|
|
213
|
+
# Return average epoch loss
|
|
214
|
+
return epoch_loss / batch_count if batch_count > 0 else 0
|
|
215
|
+
|
|
159
216
|
def evaluate(self, data_loader):
|
|
160
217
|
"""Evaluate the model on a validation/test data loader.
|
|
161
218
|
|
|
@@ -198,7 +255,7 @@ class SeqTrainer(object):
|
|
|
198
255
|
|
|
199
256
|
return avg_loss, accuracy
|
|
200
257
|
|
|
201
|
-
def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
258
|
+
def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
|
|
202
259
|
"""Export the trained sequence generation model to ONNX format.
|
|
203
260
|
|
|
204
261
|
This method exports sequence generation models (e.g., HSTU) to ONNX format.
|
|
@@ -216,6 +273,7 @@ class SeqTrainer(object):
|
|
|
216
273
|
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
217
274
|
If None, defaults to 'cpu' for maximum compatibility.
|
|
218
275
|
verbose (bool): Print export details (default: False).
|
|
276
|
+
onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
|
|
219
277
|
|
|
220
278
|
Returns:
|
|
221
279
|
bool: True if export succeeded, False otherwise.
|
|
@@ -264,20 +322,38 @@ class SeqTrainer(object):
|
|
|
264
322
|
|
|
265
323
|
try:
|
|
266
324
|
with torch.no_grad():
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
325
|
+
import inspect
|
|
326
|
+
|
|
327
|
+
export_kwargs = {
|
|
328
|
+
"f": output_path,
|
|
329
|
+
"input_names": ["seq_tokens",
|
|
330
|
+
"seq_time_diffs"],
|
|
331
|
+
"output_names": ["output"],
|
|
332
|
+
"dynamic_axes": dynamic_axes,
|
|
333
|
+
"opset_version": opset_version,
|
|
334
|
+
"do_constant_folding": True,
|
|
335
|
+
"verbose": verbose,
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
if onnx_export_kwargs:
|
|
339
|
+
overlap = set(export_kwargs.keys()) & set(onnx_export_kwargs.keys())
|
|
340
|
+
overlap.discard("dynamo")
|
|
341
|
+
if overlap:
|
|
342
|
+
raise ValueError("onnx_export_kwargs contains keys that overlap with explicit args: "
|
|
343
|
+
f"{sorted(overlap)}. Please set them via export_onnx() parameters instead.")
|
|
344
|
+
export_kwargs.update(onnx_export_kwargs)
|
|
345
|
+
|
|
346
|
+
# Auto-pick exporter:
|
|
347
|
+
# - dynamic_axes present => prefer legacy exporter (dynamo=False) for dynamic batch/seq
|
|
348
|
+
# - otherwise prefer dynamo exporter (dynamo=True) on newer torch
|
|
349
|
+
sig = inspect.signature(torch.onnx.export)
|
|
350
|
+
if "dynamo" in sig.parameters:
|
|
351
|
+
if "dynamo" not in export_kwargs:
|
|
352
|
+
export_kwargs["dynamo"] = False if dynamic_axes is not None else True
|
|
353
|
+
else:
|
|
354
|
+
export_kwargs.pop("dynamo", None)
|
|
355
|
+
|
|
356
|
+
torch.onnx.export(model, (dummy_seq_tokens, dummy_seq_time_diffs), **export_kwargs)
|
|
281
357
|
|
|
282
358
|
if verbose:
|
|
283
359
|
print(f"Successfully exported ONNX model to: {output_path}")
|