torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
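
The headline addition in 0.0.5 is a generative-recommendation stack: the HSTU and HLLM models under torch_rechub/models/generative/, a dedicated SeqTrainer (torch_rechub/trainers/seq_trainer.py, whose full diff is shown below), and new utilities for ONNX export (utils/onnx_export.py) and computation-graph visualization (utils/visualization.py). As a rough sketch of how the pieces fit together, assuming the new classes are re-exported from their package __init__ files and with illustrative (not actual) HSTU constructor arguments:

    import torch
    from torch_rechub.models.generative import HSTU
    from torch_rechub.trainers import SeqTrainer

    model = HSTU(item_num=10000, embed_dim=64, max_seq_len=50)  # hypothetical hyperparameters
    trainer = SeqTrainer(model, optimizer_params={'lr': 1e-3, 'weight_decay': 1e-5}, device='cpu')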
torch_rechub/trainers/seq_trainer.py (new file, +427)
@@ -0,0 +1,427 @@
+ """Sequence Generation Model Trainer."""
+
+ import os
+
+ import torch
+ import torch.nn as nn
+ import tqdm
+
+ from ..basic.callback import EarlyStopper
+ from ..basic.loss_func import NCELoss
+
+
+ class SeqTrainer(object):
+     """Trainer for sequence generation models.
+
+     Used to train sequence generation models such as HSTU.
+     Supports cross-entropy and NCE loss functions and generative evaluation metrics.
+
+     Args:
+         model (nn.Module): the model to train.
+         optimizer_fn (torch.optim): optimizer class, defaults to torch.optim.Adam.
+         optimizer_params (dict): optimizer parameters.
+         scheduler_fn (torch.optim.lr_scheduler): torch scheduler class.
+         scheduler_params (dict): scheduler parameters.
+         n_epoch (int): number of training epochs, default 10.
+         earlystop_patience (int): early-stopping patience, default 10.
+         device (str): device, 'cpu' or 'cuda', default 'cpu'.
+         gpus (list): ids of the GPUs to run on in parallel, default [].
+         model_path (str): directory in which to save the model, default './'.
+         loss_type (str): 'cross_entropy' or 'nce', default 'cross_entropy'.
+         loss_params (dict): keyword arguments for the loss function.
+
+     Methods:
+         fit: train the model.
+         evaluate: evaluate the model.
+         export_onnx: export the trained model to ONNX.
+         visualization: render the computation graph.
+
+     Example:
+         >>> trainer = SeqTrainer(
+         ...     model=model,
+         ...     optimizer_fn=torch.optim.Adam,
+         ...     optimizer_params={'lr': 1e-3, 'weight_decay': 1e-5},
+         ...     device='cuda'
+         ... )
+         >>> trainer.fit(
+         ...     train_dataloader=train_loader,
+         ...     val_dataloader=val_loader
+         ... )
+     """
+
+     def __init__(self, model, optimizer_fn=torch.optim.Adam, optimizer_params=None, scheduler_fn=None, scheduler_params=None, n_epoch=10, earlystop_patience=10, device='cpu', gpus=None, model_path='./', loss_type='cross_entropy', loss_params=None):
+         self.model = model  # for uniform weights save method in one gpu or multi gpu
+         if gpus is None:
+             gpus = []
+         self.gpus = gpus
+         if len(gpus) > 1:
+             print('parallel running on these gpus:', gpus)
+             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+         # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.device = torch.device(device)
+         self.model.to(self.device)
+         if optimizer_params is None:
+             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
+         self.scheduler = None
+         if scheduler_fn is not None:
+             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+
+         # Loss function: NCE with temperature scaling, or plain cross entropy
+         if loss_type == 'nce':
+             if loss_params is None:
+                 loss_params = {"temperature": 0.1, "ignore_index": 0}
+             self.loss_fn = NCELoss(**loss_params)
+         else:  # default to cross_entropy
+             if loss_params is None:
+                 loss_params = {"ignore_index": 0}
+             self.loss_fn = nn.CrossEntropyLoss(**loss_params)
+
+         self.n_epoch = n_epoch
+         self.early_stopper = EarlyStopper(patience=earlystop_patience)
+         self.model_path = model_path
+
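+     # Each batch yielded by the dataloaders used below is expected to be a
+     # 4-tuple (seq_tokens, seq_positions, seq_time_diffs, targets), e.g. a
+     # DataLoader over a torch.utils.data.TensorDataset of those four tensors.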
+     def fit(self, train_dataloader, val_dataloader=None):
+         """Train the model.
+
+         Args:
+             train_dataloader (DataLoader): training data loader.
+             val_dataloader (DataLoader): validation data loader.
+
+         Returns:
+             dict: training history.
+         """
+         history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
+
+         for epoch_i in range(self.n_epoch):
+             print('epoch:', epoch_i)
+             # Training phase
+             train_loss = self.train_one_epoch(train_dataloader)
+             history['train_loss'].append(train_loss)
+             if self.scheduler is not None:
+                 if epoch_i % self.scheduler.step_size == 0:
+                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                 self.scheduler.step()  # update lr in epoch level by scheduler
+
+             # Validation phase
+             if val_dataloader:
+                 val_loss, val_accuracy = self.evaluate(val_dataloader)
+                 history['val_loss'].append(val_loss)
+                 history['val_accuracy'].append(val_accuracy)
+
+                 print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")
+
+                 # Early stopping
+                 if self.early_stopper.stop_training(val_accuracy, self.model.state_dict()):
+                     print(f'validation: best accuracy: {self.early_stopper.best_auc}')
+                     self.model.load_state_dict(self.early_stopper.best_weights)
+                     break
+
+         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best model
+         return history
+
+     def train_one_epoch(self, data_loader, log_interval=10):
+         """Train the model for a single epoch.
+
+         Args:
+             data_loader (DataLoader): Training data loader.
+             log_interval (int): Interval (in steps) for logging average loss.
+
+         Returns:
+             float: Average training loss for this epoch.
+         """
+         self.model.train()
+         total_loss = 0
+         epoch_loss = 0.0
+         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+         for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
+             # Move tensors to the target device
+             seq_tokens = seq_tokens.to(self.device)
+             seq_positions = seq_positions.to(self.device)
+             seq_time_diffs = seq_time_diffs.to(self.device)
+             targets = targets.to(self.device).squeeze(-1)
+
+             # Forward pass
+             logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)
+
+             # Compute loss
+             # For next-item prediction we only use the last position in the sequence;
+             # logits[:, -1, :] selects the prediction at the last step for each sequence
+             last_logits = logits[:, -1, :]  # (B, V)
+
+             loss = self.loss_fn(last_logits, targets)
+
+             # Backward pass
+             self.model.zero_grad()
+             loss.backward()
+             self.optimizer.step()
+
+             total_loss += loss.item()
+             epoch_loss += loss.item()
+             if (i + 1) % log_interval == 0:
+                 tk0.set_postfix(loss=total_loss / log_interval)
+                 total_loss = 0
+         return epoch_loss / len(data_loader)
+
+     def evaluate(self, data_loader):
+         """Evaluate the model on a validation/test data loader.
+
+         Args:
+             data_loader (DataLoader): Validation or test data loader.
+
+         Returns:
+             tuple: ``(avg_loss, top1_accuracy)``.
+         """
+         self.model.eval()
+         total_loss = 0.0
+         total_correct = 0
+         total_samples = 0
+
+         with torch.no_grad():
+             for seq_tokens, seq_positions, seq_time_diffs, targets in tqdm.tqdm(data_loader, desc="evaluating", smoothing=0, mininterval=1.0):
+                 # Move tensors to the target device
+                 seq_tokens = seq_tokens.to(self.device)
+                 seq_positions = seq_positions.to(self.device)
+                 seq_time_diffs = seq_time_diffs.to(self.device)
+                 targets = targets.to(self.device).squeeze(-1)
+
+                 # Forward pass
+                 logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)
+
+                 # Compute loss using only the last position (next-item prediction)
+                 last_logits = logits[:, -1, :]  # (B, V)
+
+                 loss = self.loss_fn(last_logits, targets)
+                 total_loss += loss.item()
+
+                 # Compute top-1 accuracy
+                 predictions = torch.argmax(last_logits, dim=-1)  # (B,)
+                 correct = (predictions == targets).sum().item()
+                 total_correct += correct
+                 total_samples += targets.numel()
+
+         avg_loss = total_loss / len(data_loader)
+         accuracy = total_correct / total_samples
+
+         return avg_loss, accuracy
+
+     def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+         """Export the trained sequence generation model to ONNX format.
+
+         This method exports sequence generation models (e.g., HSTU) to ONNX format.
+         Unlike the other trainers, sequence models take positional arguments
+         (seq_tokens, seq_time_diffs) instead of a dict input, which makes ONNX
+         export more straightforward.
+
+         Args:
+             output_path (str): Path to save the ONNX model file.
+             batch_size (int): Batch size for dummy input (default: 2).
+             seq_length (int): Sequence length for dummy input (default: 50).
+             vocab_size (int, optional): Vocabulary size for generating dummy tokens.
+                 If None, will try to get it from model.vocab_size.
+             opset_version (int): ONNX opset version (default: 14).
+             dynamic_batch (bool): Enable dynamic batch size (default: True).
+             device (str, optional): Device for export ('cpu', 'cuda', etc.).
+                 If None, defaults to 'cpu' for maximum compatibility.
+             verbose (bool): Print export details (default: False).
+
+         Returns:
+             bool: True if the export succeeded.
+
+         Raises:
+             RuntimeError: If the export fails.
+
+         Example:
+             >>> trainer = SeqTrainer(hstu_model, ...)
+             >>> trainer.fit(train_dl, val_dl)
+             >>> trainer.export_onnx("hstu.onnx", vocab_size=10000)
+
+             >>> # Export on a specific device
+             >>> trainer.export_onnx("hstu.onnx", vocab_size=10000, device="cpu")
+         """
+         import warnings
+
+         # Use the provided device or default to 'cpu'
+         export_device = device if device is not None else 'cpu'
+
+         # Handle DataParallel-wrapped models
+         model = self.model.module if hasattr(self.model, 'module') else self.model
+         model.eval()
+         model.to(export_device)
+
+         # Get vocab_size from the model if not provided
+         if vocab_size is None:
+             if hasattr(model, 'vocab_size'):
+                 vocab_size = model.vocab_size
+             elif hasattr(model, 'item_num'):
+                 vocab_size = model.item_num
+             else:
+                 raise ValueError("vocab_size must be provided or model must have 'vocab_size' or 'item_num' attribute")
+
+         # Generate dummy inputs on the export device
+         dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=export_device)
+         dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=export_device)
+
+         # Configure dynamic axes
+         dynamic_axes = None
+         if dynamic_batch:
+             dynamic_axes = {
+                 "seq_tokens": {0: "batch_size", 1: "seq_length"},
+                 "seq_time_diffs": {0: "batch_size", 1: "seq_length"},
+                 "output": {0: "batch_size", 1: "seq_length"},
+             }
+
+         # Ensure the output directory exists
+         output_dir = os.path.dirname(output_path)
+         if output_dir and not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         try:
+             with torch.no_grad():
+                 torch.onnx.export(
+                     model,
+                     (dummy_seq_tokens, dummy_seq_time_diffs),
+                     output_path,
+                     input_names=["seq_tokens", "seq_time_diffs"],
+                     output_names=["output"],
+                     dynamic_axes=dynamic_axes,
+                     opset_version=opset_version,
+                     do_constant_folding=True,
+                     verbose=verbose,
+                     dynamo=False  # use the legacy exporter for dynamic_axes support
+                 )
+
+             if verbose:
+                 print(f"Successfully exported ONNX model to: {output_path}")
+                 print("  Input names: ['seq_tokens', 'seq_time_diffs']")
+                 print(f"  Vocab size: {vocab_size}")
+                 print(f"  Opset version: {opset_version}")
+                 print(f"  Dynamic batch: {dynamic_batch}")
+
+             return True
+
+         except Exception as e:
+             warnings.warn(f"ONNX export failed: {str(e)}")
+             raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
+
+     def visualization(self, seq_length=50, vocab_size=None, batch_size=2, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+         """Visualize the model's computation graph.
+
+         This method generates a visual representation of the sequence model
+         architecture, showing layer connections, tensor shapes, and nested
+         module structures.
+
+         Parameters
+         ----------
+         seq_length : int, default=50
+             Sequence length for dummy input.
+         vocab_size : int, optional
+             Vocabulary size for generating dummy tokens.
+             If None, will try to get it from model.vocab_size or model.item_num.
+         batch_size : int, default=2
+             Batch size for dummy input.
+         depth : int, default=3
+             Visualization depth; higher values show more detail.
+             Set to -1 to show all layers.
+         show_shapes : bool, default=True
+             Whether to display tensor shapes.
+         expand_nested : bool, default=True
+             Whether to expand nested modules.
+         save_path : str, optional
+             Path to save the graph image (.pdf, .svg, .png).
+             If None, displays in Jupyter or opens the system viewer.
+         graph_name : str, default="model"
+             Name for the graph.
+         device : str, optional
+             Device for model execution. If None, defaults to 'cpu'.
+         dpi : int, default=300
+             Resolution in dots per inch for the output image.
+             Higher values produce sharper images suitable for papers.
+         **kwargs : dict
+             Additional arguments passed to torchview.draw_graph().
+
+         Returns
+         -------
+         ComputationGraph
+             A torchview ComputationGraph object.
+
+         Raises
+         ------
+         ImportError
+             If torchview or graphviz is not installed.
+         ValueError
+             If vocab_size is not provided and cannot be inferred from the model.
+
+         Notes
+         -----
+         Default display behavior when `save_path` is None:
+             - In Jupyter/IPython: automatically displays the graph inline.
+             - In a Python script: opens the graph with the system default viewer.
+
+         Examples
+         --------
+         >>> trainer = SeqTrainer(hstu_model, ...)
+         >>> trainer.fit(train_dl, val_dl)
+         >>>
+         >>> # Auto-display in Jupyter (no save_path needed)
+         >>> trainer.visualization(depth=4, vocab_size=10000)
+         >>>
+         >>> # Save to a high-DPI PNG for papers
+         >>> trainer.visualization(save_path="model.png", dpi=300)
+         """
+         try:
+             from torchview import draw_graph
+         except ImportError:
+             raise ImportError(
+                 "Visualization requires torchview. "
+                 "Install with: pip install torch-rechub[visualization]\n"
+                 "Also ensure graphviz is installed on your system:\n"
+                 "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                 "  - macOS: brew install graphviz\n"
+                 "  - Windows: choco install graphviz"
+             )
+
+         from ..utils.visualization import _is_jupyter_environment, display_graph
+
+         # Handle DataParallel-wrapped models
+         model = self.model.module if hasattr(self.model, 'module') else self.model
+
+         # Use the provided device or default to 'cpu'
+         viz_device = device if device is not None else 'cpu'
+
+         # Get vocab_size from the model if not provided
+         if vocab_size is None:
+             if hasattr(model, 'vocab_size'):
+                 vocab_size = model.vocab_size
+             elif hasattr(model, 'item_num'):
+                 vocab_size = model.item_num
+             else:
+                 raise ValueError("vocab_size must be provided or model must have 'vocab_size' or 'item_num' attribute")
+
+         # Generate dummy inputs for the sequence model
+         dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=viz_device)
+         dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=viz_device)
+
+         # Move the model to the device
+         model = model.to(viz_device)
+         model.eval()
+
+         # Call torchview.draw_graph
+         graph = draw_graph(model, input_data=(dummy_seq_tokens, dummy_seq_time_diffs), graph_name=graph_name, depth=depth, device=viz_device, expand_nested=expand_nested, show_shapes=show_shapes, save_graph=False, **kwargs)
+
+         # Set DPI for high-quality output
+         graph.visual_graph.graph_attr['dpi'] = str(dpi)
+
+         # Handle save_path: save manually with the DPI setting applied
+         if save_path:
+             directory = os.path.dirname(save_path) or "."
+             filename = os.path.splitext(os.path.basename(save_path))[0]
+             ext = os.path.splitext(save_path)[1].lstrip('.')
+             output_format = ext if ext else 'pdf'
+             if directory != "." and not os.path.exists(directory):
+                 os.makedirs(directory, exist_ok=True)
+             graph.visual_graph.render(filename=filename, directory=directory, format=output_format, cleanup=True)
+
+         # Default display behavior when save_path is None
+         if save_path is None:
+             if _is_jupyter_environment():
+                 display_graph(graph)
+             else:
+                 graph.visual_graph.view(cleanup=True)
+
+         return graph
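
Taken together, a plausible end-to-end session with the new trainer looks like the sketch below (the dummy data shapes and the reuse of the HSTU `model` from the sketch above are illustrative assumptions; the final check additionally assumes onnxruntime is installed):

    import onnxruntime as ort
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Dummy data in the 4-tuple batch format the trainer expects.
    B, L, V = 512, 50, 10000
    seq_tokens = torch.randint(1, V, (B, L))
    seq_positions = torch.arange(L).repeat(B, 1)
    seq_time_diffs = torch.zeros(B, L)
    targets = torch.randint(1, V, (B, 1))
    dl = DataLoader(TensorDataset(seq_tokens, seq_positions, seq_time_diffs, targets), batch_size=64)

    trainer = SeqTrainer(model, loss_type='nce')  # NCE defaults: temperature=0.1, ignore_index=0
    history = trainer.fit(dl, val_dataloader=dl)

    # Export the model and sanity-check the ONNX graph.
    trainer.export_onnx("hstu.onnx", vocab_size=V)
    sess = ort.InferenceSession("hstu.onnx", providers=["CPUExecutionProvider"])
    (out,) = sess.run(None, {"seq_tokens": seq_tokens[:2].numpy(), "seq_time_diffs": seq_time_diffs[:2].numpy()})
    print(out.shape)  # (2, L, V), matching the forward-pass comments in the diff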