torch-rechub 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,6 +39,7 @@ class MatchTrainer(object):
39
39
  device="cpu",
40
40
  gpus=None,
41
41
  model_path="./",
42
+ model_logger=None,
42
43
  ):
43
44
  self.model = model # for uniform weights save method in one gpu or multi gpu
44
45
  if gpus is None:
@@ -73,10 +74,13 @@ class MatchTrainer(object):
73
74
  self.model_path = model_path
74
75
  # Initialize regularization loss
75
76
  self.reg_loss_fn = RegularizationLoss(**regularization_params)
77
+ self.model_logger = model_logger
76
78
 
77
79
  def train_one_epoch(self, data_loader, log_interval=10):
78
80
  self.model.train()
79
81
  total_loss = 0
82
+ epoch_loss = 0
83
+ batch_count = 0
80
84
  tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
81
85
  for i, (x_dict, y) in enumerate(tk0):
82
86
  x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
@@ -114,14 +118,26 @@ class MatchTrainer(object):
114
118
  loss.backward()
115
119
  self.optimizer.step()
116
120
  total_loss += loss.item()
121
+ epoch_loss += loss.item()
122
+ batch_count += 1
117
123
  if (i + 1) % log_interval == 0:
118
124
  tk0.set_postfix(loss=total_loss / log_interval)
119
125
  total_loss = 0
120
126
 
127
+ # Return average epoch loss
128
+ return epoch_loss / batch_count if batch_count > 0 else 0
129
+
121
130
  def fit(self, train_dataloader, val_dataloader=None):
131
+ for logger in self._iter_loggers():
132
+ logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.mode})
133
+
122
134
  for epoch_i in range(self.n_epoch):
123
135
  print('epoch:', epoch_i)
124
- self.train_one_epoch(train_dataloader)
136
+ train_loss = self.train_one_epoch(train_dataloader)
137
+
138
+ for logger in self._iter_loggers():
139
+ logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
140
+
125
141
  if self.scheduler is not None:
126
142
  if epoch_i % self.scheduler.step_size == 0:
127
143
  print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -130,12 +146,34 @@ class MatchTrainer(object):
130
146
  if val_dataloader:
131
147
  auc = self.evaluate(self.model, val_dataloader)
132
148
  print('epoch:', epoch_i, 'validation: auc:', auc)
149
+
150
+ for logger in self._iter_loggers():
151
+ logger.log_metrics({'val/auc': auc}, step=epoch_i)
152
+
133
153
  if self.early_stopper.stop_training(auc, self.model.state_dict()):
134
154
  print(f'validation: best auc: {self.early_stopper.best_auc}')
135
155
  self.model.load_state_dict(self.early_stopper.best_weights)
136
156
  break
157
+
137
158
  torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best auc model
138
159
 
160
+ for logger in self._iter_loggers():
161
+ logger.finish()
162
+
163
+ def _iter_loggers(self):
164
+ """Return logger instances as a list.
165
+
166
+ Returns
167
+ -------
168
+ list
169
+ Active logger instances. Empty when ``model_logger`` is ``None``.
170
+ """
171
+ if self.model_logger is None:
172
+ return []
173
+ if isinstance(self.model_logger, (list, tuple)):
174
+ return list(self.model_logger)
175
+ return [self.model_logger]
176
+
139
177
  def evaluate(self, model, data_loader):
140
178
  model.eval()
141
179
  targets, predicts = list(), list()
@@ -177,7 +215,7 @@ class MatchTrainer(object):
177
215
  predicts.append(y_pred.data)
178
216
  return torch.cat(predicts, dim=0)
179
217
 
180
- def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
218
+ def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
181
219
  """Export the trained matching model to ONNX format.
182
220
 
183
221
  This method exports matching/retrieval models (e.g., DSSM, YoutubeDNN, MIND)
@@ -199,6 +237,7 @@ class MatchTrainer(object):
199
237
  device (str, optional): Device for export ('cpu', 'cuda', etc.).
200
238
  If None, defaults to 'cpu' for maximum compatibility.
201
239
  verbose (bool): Print export details (default: False).
240
+ onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
202
241
 
203
242
  Returns:
204
243
  bool: True if export succeeded, False otherwise.
@@ -232,7 +271,17 @@ class MatchTrainer(object):
232
271
 
233
272
  try:
234
273
  exporter = ONNXExporter(model, device=export_device)
235
- return exporter.export(output_path=output_path, mode=mode, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
274
+ return exporter.export(
275
+ output_path=output_path,
276
+ mode=mode,
277
+ dummy_input=dummy_input,
278
+ batch_size=batch_size,
279
+ seq_length=seq_length,
280
+ opset_version=opset_version,
281
+ dynamic_batch=dynamic_batch,
282
+ verbose=verbose,
283
+ onnx_export_kwargs=onnx_export_kwargs,
284
+ )
236
285
  finally:
237
286
  # Restore original mode
238
287
  if hasattr(model, 'mode'):
@@ -47,6 +47,7 @@ class MTLTrainer(object):
47
47
  device="cpu",
48
48
  gpus=None,
49
49
  model_path="./",
50
+ model_logger=None,
50
51
  ):
51
52
  self.model = model
52
53
  if gpus is None:
@@ -104,6 +105,7 @@ class MTLTrainer(object):
104
105
  self.model_path = model_path
105
106
  # Initialize regularization loss
106
107
  self.reg_loss_fn = RegularizationLoss(**regularization_params)
108
+ self.model_logger = model_logger
107
109
 
108
110
  def train_one_epoch(self, data_loader):
109
111
  self.model.train()
@@ -163,21 +165,42 @@ class MTLTrainer(object):
163
165
  def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
164
166
  total_log = []
165
167
 
168
+ # Log hyperparameters once
169
+ for logger in self._iter_loggers():
170
+ logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self._current_lr(), 'adaptive_method': self.adaptive_method})
171
+
166
172
  for epoch_i in range(self.n_epoch):
167
173
  _log_per_epoch = self.train_one_epoch(train_dataloader)
168
174
 
175
+ # Collect metrics
176
+ logs = {f'train/task_{task_id}_loss': loss_val for task_id, loss_val in enumerate(_log_per_epoch)}
177
+ lr_value = self._current_lr()
178
+ if lr_value is not None:
179
+ logs['learning_rate'] = lr_value
180
+
169
181
  if self.scheduler is not None:
170
182
  if epoch_i % self.scheduler.step_size == 0:
171
183
  print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
172
184
  self.scheduler.step() # update lr in epoch level by scheduler
185
+
173
186
  scores = self.evaluate(self.model, val_dataloader)
174
187
  print('epoch:', epoch_i, 'validation scores: ', scores)
175
188
 
176
- for score in scores:
189
+ for task_id, score in enumerate(scores):
190
+ logs[f'val/task_{task_id}_score'] = score
177
191
  _log_per_epoch.append(score)
192
+ logs['auc'] = scores[self.earlystop_taskid]
193
+
194
+ if self.loss_weight:
195
+ for task_id, weight in enumerate(self.loss_weight):
196
+ logs[f'loss_weight/task_{task_id}'] = weight.item()
178
197
 
179
198
  total_log.append(_log_per_epoch)
180
199
 
200
+ # Log metrics once per epoch
201
+ for logger in self._iter_loggers():
202
+ logger.log_metrics(logs, step=epoch_i)
203
+
181
204
  if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
182
205
  print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
183
206
  self.model.load_state_dict(self.early_stopper.best_weights)
@@ -185,8 +208,33 @@ class MTLTrainer(object):
185
208
 
186
209
  torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed))) # save best auc model
187
210
 
211
+ for logger in self._iter_loggers():
212
+ logger.finish()
213
+
188
214
  return total_log
189
215
 
216
+ def _iter_loggers(self):
217
+ """Return logger instances as a list.
218
+
219
+ Returns
220
+ -------
221
+ list
222
+ Active logger instances. Empty when ``model_logger`` is ``None``.
223
+ """
224
+ if self.model_logger is None:
225
+ return []
226
+ if isinstance(self.model_logger, (list, tuple)):
227
+ return list(self.model_logger)
228
+ return [self.model_logger]
229
+
230
+ def _current_lr(self):
231
+ """Fetch current learning rate regardless of adaptive method."""
232
+ if self.adaptive_method == "metabalance":
233
+ return self.share_optimizer.param_groups[0]['lr'] if hasattr(self, 'share_optimizer') else None
234
+ if hasattr(self, 'optimizer'):
235
+ return self.optimizer.param_groups[0]['lr']
236
+ return None
237
+
190
238
  def evaluate(self, model, data_loader):
191
239
  model.eval()
192
240
  targets, predicts = list(), list()
@@ -213,7 +261,7 @@ class MTLTrainer(object):
213
261
  predicts.extend(y_preds.tolist())
214
262
  return predicts
215
263
 
216
- def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
264
+ def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
217
265
  """Export the trained multi-task model to ONNX format.
218
266
 
219
267
  This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
@@ -235,6 +283,7 @@ class MTLTrainer(object):
235
283
  device (str, optional): Device for export ('cpu', 'cuda', etc.).
236
284
  If None, defaults to 'cpu' for maximum compatibility.
237
285
  verbose (bool): Print export details (default: False).
286
+ onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
238
287
 
239
288
  Returns:
240
289
  bool: True if export succeeded, False otherwise.
@@ -256,7 +305,16 @@ class MTLTrainer(object):
256
305
  export_device = device if device is not None else 'cpu'
257
306
 
258
307
  exporter = ONNXExporter(model, device=export_device)
259
- return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
308
+ return exporter.export(
309
+ output_path=output_path,
310
+ dummy_input=dummy_input,
311
+ batch_size=batch_size,
312
+ seq_length=seq_length,
313
+ opset_version=opset_version,
314
+ dynamic_batch=dynamic_batch,
315
+ verbose=verbose,
316
+ onnx_export_kwargs=onnx_export_kwargs,
317
+ )
260
318
 
261
319
  def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
262
320
  """Visualize the model's computation graph.
@@ -46,7 +46,22 @@ class SeqTrainer(object):
46
46
  ... )
47
47
  """
48
48
 
49
- def __init__(self, model, optimizer_fn=torch.optim.Adam, optimizer_params=None, scheduler_fn=None, scheduler_params=None, n_epoch=10, earlystop_patience=10, device='cpu', gpus=None, model_path='./', loss_type='cross_entropy', loss_params=None):
49
+ def __init__(
50
+ self,
51
+ model,
52
+ optimizer_fn=torch.optim.Adam,
53
+ optimizer_params=None,
54
+ scheduler_fn=None,
55
+ scheduler_params=None,
56
+ n_epoch=10,
57
+ earlystop_patience=10,
58
+ device='cpu',
59
+ gpus=None,
60
+ model_path='./',
61
+ loss_type='cross_entropy',
62
+ loss_params=None,
63
+ model_logger=None
64
+ ):
50
65
  self.model = model # for uniform weights save method in one gpu or multi gpu
51
66
  if gpus is None:
52
67
  gpus = []
@@ -74,9 +89,11 @@ class SeqTrainer(object):
74
89
  loss_params = {"ignore_index": 0}
75
90
  self.loss_fn = nn.CrossEntropyLoss(**loss_params)
76
91
 
92
+ self.loss_type = loss_type
77
93
  self.n_epoch = n_epoch
78
94
  self.early_stopper = EarlyStopper(patience=earlystop_patience)
79
95
  self.model_path = model_path
96
+ self.model_logger = model_logger
80
97
 
81
98
  def fit(self, train_dataloader, val_dataloader=None):
82
99
  """训练模型.
@@ -90,10 +107,18 @@ class SeqTrainer(object):
90
107
  """
91
108
  history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
92
109
 
110
+ for logger in self._iter_loggers():
111
+ logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_type': self.loss_type})
112
+
93
113
  for epoch_i in range(self.n_epoch):
94
114
  print('epoch:', epoch_i)
95
115
  # 训练阶段
96
- self.train_one_epoch(train_dataloader)
116
+ train_loss = self.train_one_epoch(train_dataloader)
117
+ history['train_loss'].append(train_loss)
118
+
119
+ # Collect metrics
120
+ logs = {'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}
121
+
97
122
  if self.scheduler is not None:
98
123
  if epoch_i % self.scheduler.step_size == 0:
99
124
  print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -105,6 +130,10 @@ class SeqTrainer(object):
105
130
  history['val_loss'].append(val_loss)
106
131
  history['val_accuracy'].append(val_accuracy)
107
132
 
133
+ logs['val/loss'] = val_loss
134
+ logs['val/accuracy'] = val_accuracy
135
+ logs['auc'] = val_accuracy # For compatibility with EarlyStopper
136
+
108
137
  print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")
109
138
 
110
139
  # 早停
@@ -113,9 +142,30 @@ class SeqTrainer(object):
113
142
  self.model.load_state_dict(self.early_stopper.best_weights)
114
143
  break
115
144
 
145
+ for logger in self._iter_loggers():
146
+ logger.log_metrics(logs, step=epoch_i)
147
+
116
148
  torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best model
149
+
150
+ for logger in self._iter_loggers():
151
+ logger.finish()
152
+
117
153
  return history
118
154
 
155
+ def _iter_loggers(self):
156
+ """Return logger instances as a list.
157
+
158
+ Returns
159
+ -------
160
+ list
161
+ Active logger instances. Empty when ``model_logger`` is ``None``.
162
+ """
163
+ if self.model_logger is None:
164
+ return []
165
+ if isinstance(self.model_logger, (list, tuple)):
166
+ return list(self.model_logger)
167
+ return [self.model_logger]
168
+
119
169
  def train_one_epoch(self, data_loader, log_interval=10):
120
170
  """Train the model for a single epoch.
121
171
 
@@ -128,6 +178,8 @@ class SeqTrainer(object):
128
178
  """
129
179
  self.model.train()
130
180
  total_loss = 0
181
+ epoch_loss = 0
182
+ batch_count = 0
131
183
  tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
132
184
  for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
133
185
  # Move tensors to the target device
@@ -152,10 +204,15 @@ class SeqTrainer(object):
152
204
  self.optimizer.step()
153
205
 
154
206
  total_loss += loss.item()
207
+ epoch_loss += loss.item()
208
+ batch_count += 1
155
209
  if (i + 1) % log_interval == 0:
156
210
  tk0.set_postfix(loss=total_loss / log_interval)
157
211
  total_loss = 0
158
212
 
213
+ # Return average epoch loss
214
+ return epoch_loss / batch_count if batch_count > 0 else 0
215
+
159
216
  def evaluate(self, data_loader):
160
217
  """Evaluate the model on a validation/test data loader.
161
218
 
@@ -198,7 +255,7 @@ class SeqTrainer(object):
198
255
 
199
256
  return avg_loss, accuracy
200
257
 
201
- def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
258
+ def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
202
259
  """Export the trained sequence generation model to ONNX format.
203
260
 
204
261
  This method exports sequence generation models (e.g., HSTU) to ONNX format.
@@ -216,6 +273,7 @@ class SeqTrainer(object):
216
273
  device (str, optional): Device for export ('cpu', 'cuda', etc.).
217
274
  If None, defaults to 'cpu' for maximum compatibility.
218
275
  verbose (bool): Print export details (default: False).
276
+ onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
219
277
 
220
278
  Returns:
221
279
  bool: True if export succeeded, False otherwise.
@@ -264,20 +322,38 @@ class SeqTrainer(object):
264
322
 
265
323
  try:
266
324
  with torch.no_grad():
267
- torch.onnx.export(
268
- model,
269
- (dummy_seq_tokens,
270
- dummy_seq_time_diffs),
271
- output_path,
272
- input_names=["seq_tokens",
273
- "seq_time_diffs"],
274
- output_names=["output"],
275
- dynamic_axes=dynamic_axes,
276
- opset_version=opset_version,
277
- do_constant_folding=True,
278
- verbose=verbose,
279
- dynamo=False # Use legacy exporter for dynamic_axes support
280
- )
325
+ import inspect
326
+
327
+ export_kwargs = {
328
+ "f": output_path,
329
+ "input_names": ["seq_tokens",
330
+ "seq_time_diffs"],
331
+ "output_names": ["output"],
332
+ "dynamic_axes": dynamic_axes,
333
+ "opset_version": opset_version,
334
+ "do_constant_folding": True,
335
+ "verbose": verbose,
336
+ }
337
+
338
+ if onnx_export_kwargs:
339
+ overlap = set(export_kwargs.keys()) & set(onnx_export_kwargs.keys())
340
+ overlap.discard("dynamo")
341
+ if overlap:
342
+ raise ValueError("onnx_export_kwargs contains keys that overlap with explicit args: "
343
+ f"{sorted(overlap)}. Please set them via export_onnx() parameters instead.")
344
+ export_kwargs.update(onnx_export_kwargs)
345
+
346
+ # Auto-pick exporter:
347
+ # - dynamic_axes present => prefer legacy exporter (dynamo=False) for dynamic batch/seq
348
+ # - otherwise prefer dynamo exporter (dynamo=True) on newer torch
349
+ sig = inspect.signature(torch.onnx.export)
350
+ if "dynamo" in sig.parameters:
351
+ if "dynamo" not in export_kwargs:
352
+ export_kwargs["dynamo"] = False if dynamic_axes is not None else True
353
+ else:
354
+ export_kwargs.pop("dynamo", None)
355
+
356
+ torch.onnx.export(model, (dummy_seq_tokens, dummy_seq_time_diffs), **export_kwargs)
281
357
 
282
358
  if verbose:
283
359
  print(f"Successfully exported ONNX model to: {output_path}")
torch_rechub/types.py ADDED
@@ -0,0 +1,5 @@
1
+ import os
2
+ import typing as ty
3
+
4
+ #: Type for path to a file.
5
+ FilePath = ty.Union[str, os.PathLike]