torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +3 -1
  3. torch_rechub/basic/callback.py +2 -2
  4. torch_rechub/basic/features.py +38 -8
  5. torch_rechub/basic/initializers.py +92 -0
  6. torch_rechub/basic/layers.py +800 -46
  7. torch_rechub/basic/loss_func.py +223 -0
  8. torch_rechub/basic/metaoptimizer.py +76 -0
  9. torch_rechub/basic/metric.py +251 -0
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -0
  14. torch_rechub/models/matching/comirec.py +193 -0
  15. torch_rechub/models/matching/dssm.py +72 -0
  16. torch_rechub/models/matching/dssm_facebook.py +77 -0
  17. torch_rechub/models/matching/dssm_senet.py +87 -0
  18. torch_rechub/models/matching/gru4rec.py +85 -0
  19. torch_rechub/models/matching/mind.py +103 -0
  20. torch_rechub/models/matching/narm.py +82 -0
  21. torch_rechub/models/matching/sasrec.py +143 -0
  22. torch_rechub/models/matching/sine.py +148 -0
  23. torch_rechub/models/matching/stamp.py +81 -0
  24. torch_rechub/models/matching/youtube_dnn.py +75 -0
  25. torch_rechub/models/matching/youtube_sbc.py +98 -0
  26. torch_rechub/models/multi_task/__init__.py +5 -2
  27. torch_rechub/models/multi_task/aitm.py +83 -0
  28. torch_rechub/models/multi_task/esmm.py +19 -8
  29. torch_rechub/models/multi_task/mmoe.py +18 -12
  30. torch_rechub/models/multi_task/ple.py +41 -29
  31. torch_rechub/models/multi_task/shared_bottom.py +3 -2
  32. torch_rechub/models/ranking/__init__.py +13 -2
  33. torch_rechub/models/ranking/afm.py +65 -0
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -0
  36. torch_rechub/models/ranking/dcn.py +38 -0
  37. torch_rechub/models/ranking/dcn_v2.py +59 -0
  38. torch_rechub/models/ranking/deepffm.py +131 -0
  39. torch_rechub/models/ranking/deepfm.py +8 -7
  40. torch_rechub/models/ranking/dien.py +191 -0
  41. torch_rechub/models/ranking/din.py +31 -19
  42. torch_rechub/models/ranking/edcn.py +101 -0
  43. torch_rechub/models/ranking/fibinet.py +42 -0
  44. torch_rechub/models/ranking/widedeep.py +6 -6
  45. torch_rechub/trainers/__init__.py +4 -2
  46. torch_rechub/trainers/ctr_trainer.py +191 -0
  47. torch_rechub/trainers/match_trainer.py +239 -0
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +137 -23
  50. torch_rechub/trainers/seq_trainer.py +293 -0
  51. torch_rechub/utils/__init__.py +0 -0
  52. torch_rechub/utils/data.py +492 -0
  53. torch_rechub/utils/hstu_utils.py +198 -0
  54. torch_rechub/utils/match.py +457 -0
  55. torch_rechub/utils/mtl.py +136 -0
  56. torch_rechub/utils/onnx_export.py +353 -0
  57. torch_rechub-0.0.4.dist-info/METADATA +391 -0
  58. torch_rechub-0.0.4.dist-info/RECORD +62 -0
  59. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
  60. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
  61. torch_rechub/basic/utils.py +0 -168
  62. torch_rechub/trainers/trainer.py +0 -111
  63. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  64. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  65. torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,293 @@
1
+ """Sequence Generation Model Trainer."""
2
+
3
+ import os
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import tqdm
8
+
9
+ from ..basic.callback import EarlyStopper
10
+ from ..basic.loss_func import NCELoss
11
+
12
+
13
class SeqTrainer(object):
    """Trainer for sequence generation models.

    Trains sequence generative models such as HSTU.
    Supports cross-entropy / NCE loss functions and generative evaluation
    metrics (next-item top-1 accuracy).

    Args:
        model (nn.Module): the model to train.
        optimizer_fn (torch.optim): optimizer class, defaults to torch.optim.Adam.
        optimizer_params (dict): optimizer parameters.
        scheduler_fn (torch.optim.lr_scheduler): torch scheduler class.
        scheduler_params (dict): scheduler parameters.
        n_epoch (int): number of training epochs, default 10.
        earlystop_patience (int): early-stopping patience, default 10.
        device (str): 'cpu' or 'cuda', default 'cpu'.
        gpus (list): list of gpu ids for multi-GPU training, default [].
        model_path (str): directory used to save the model, default './'.
        loss_type (str): 'cross_entropy' (default) or 'nce'.
        loss_params (dict): keyword arguments forwarded to the loss function.

    Methods:
        fit: train the model.
        evaluate: evaluate the model.
        export_onnx: export the model to ONNX.

    Example:
        >>> trainer = SeqTrainer(
        ...     model=model,
        ...     optimizer_fn=torch.optim.Adam,
        ...     optimizer_params={'lr': 1e-3, 'weight_decay': 1e-5},
        ...     device='cuda'
        ... )
        >>> trainer.fit(
        ...     train_loader=train_loader,
        ...     val_loader=val_loader
        ... )
    """

    def __init__(self, model, optimizer_fn=torch.optim.Adam, optimizer_params=None, scheduler_fn=None, scheduler_params=None, n_epoch=10, earlystop_patience=10, device='cpu', gpus=None, model_path='./', loss_type='cross_entropy', loss_params=None):
        self.model = model  # for uniform weights save method in one gpu or multi gpu
        if gpus is None:
            gpus = []
        self.gpus = gpus
        if len(gpus) > 1:
            print('parallel running on these gpus:', gpus)
            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
        # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)
        self.model.to(self.device)
        if optimizer_params is None:
            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
        self.scheduler = None
        if scheduler_fn is not None:
            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)

        # Loss function selection. ignore_index=0 assumes index 0 is the
        # padding token and should not contribute to the loss.
        if loss_type == 'nce':
            if loss_params is None:
                loss_params = {"temperature": 0.1, "ignore_index": 0}
            self.loss_fn = NCELoss(**loss_params)
        else:  # default to cross_entropy
            if loss_params is None:
                loss_params = {"ignore_index": 0}
            self.loss_fn = nn.CrossEntropyLoss(**loss_params)

        self.n_epoch = n_epoch
        self.early_stopper = EarlyStopper(patience=earlystop_patience)
        self.model_path = model_path

    def fit(self, train_dataloader, val_dataloader=None):
        """Train the model.

        Args:
            train_dataloader (DataLoader): training data loader.
            val_dataloader (DataLoader): validation data loader (optional).

        Returns:
            dict: training history with 'train_loss', 'val_loss' and
                'val_accuracy' lists.
        """
        history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

        for epoch_i in range(self.n_epoch):
            print('epoch:', epoch_i)
            # Training phase. Record the epoch-average loss — previously
            # history['train_loss'] was declared but never populated.
            train_loss = self.train_one_epoch(train_dataloader)
            history['train_loss'].append(train_loss)
            if self.scheduler is not None:
                # Only StepLR-style schedulers expose `step_size`; fall back
                # to logging the lr every epoch for other scheduler types
                # instead of raising AttributeError.
                if epoch_i % getattr(self.scheduler, 'step_size', 1) == 0:
                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                self.scheduler.step()  # update lr in epoch level by scheduler

            # Validation phase.
            if val_dataloader:
                val_loss, val_accuracy = self.evaluate(val_dataloader)
                history['val_loss'].append(val_loss)
                history['val_accuracy'].append(val_accuracy)

                print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")

                # Early stopping on validation accuracy; restore best weights.
                if self.early_stopper.stop_training(val_accuracy, self.model.state_dict()):
                    print(f'validation: best accuracy: {self.early_stopper.best_auc}')
                    self.model.load_state_dict(self.early_stopper.best_weights)
                    break

        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best model
        return history

    def train_one_epoch(self, data_loader, log_interval=10):
        """Train the model for a single epoch.

        Args:
            data_loader (DataLoader): training data loader.
            log_interval (int): interval (in steps) for logging average loss.

        Returns:
            float: average training loss for this epoch.
        """
        self.model.train()
        epoch_loss = 0.0  # accumulated over the whole epoch (return value)
        window_loss = 0.0  # accumulated over the current logging window only
        n_batches = 0
        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
        for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
            # Move tensors to the target device
            seq_tokens = seq_tokens.to(self.device)
            seq_positions = seq_positions.to(self.device)
            seq_time_diffs = seq_time_diffs.to(self.device)
            targets = targets.to(self.device).squeeze(-1)

            # Forward pass
            logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)

            # For next-item prediction we only use the last position in the
            # sequence: logits[:, -1, :] selects the prediction at the last
            # step for each sequence.
            last_logits = logits[:, -1, :]  # (B, V)

            loss = self.loss_fn(last_logits, targets)

            # Backward pass
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            window_loss += batch_loss
            n_batches += 1
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=window_loss / log_interval)
                window_loss = 0.0
        # max(..., 1) guards against an empty data loader.
        return epoch_loss / max(n_batches, 1)

    def evaluate(self, data_loader):
        """Evaluate the model on a validation/test data loader.

        Args:
            data_loader (DataLoader): validation or test data loader.

        Returns:
            tuple: ``(avg_loss, top1_accuracy)``.
        """
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for seq_tokens, seq_positions, seq_time_diffs, targets in tqdm.tqdm(data_loader, desc="evaluating", smoothing=0, mininterval=1.0):
                # Move tensors to the target device
                seq_tokens = seq_tokens.to(self.device)
                seq_positions = seq_positions.to(self.device)
                seq_time_diffs = seq_time_diffs.to(self.device)
                targets = targets.to(self.device).squeeze(-1)

                # Forward pass
                logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)

                # Compute loss using only the last position (next-item prediction)
                last_logits = logits[:, -1, :]  # (B, V)

                loss = self.loss_fn(last_logits, targets)
                total_loss += loss.item()

                # Compute top-1 accuracy
                predictions = torch.argmax(last_logits, dim=-1)  # (B,)
                correct = (predictions == targets).sum().item()
                total_correct += correct
                total_samples += targets.numel()

        avg_loss = total_loss / len(data_loader)
        accuracy = total_correct / total_samples

        return avg_loss, accuracy

    def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
        """Export the trained sequence generation model to ONNX format.

        This method exports sequence generation models (e.g., HSTU) to ONNX format.
        Unlike other trainers, sequence models use positional arguments (seq_tokens, seq_time_diffs)
        instead of dict input, making ONNX export more straightforward.

        Args:
            output_path (str): Path to save the ONNX model file.
            batch_size (int): Batch size for dummy input (default: 2).
            seq_length (int): Sequence length for dummy input (default: 50).
            vocab_size (int, optional): Vocabulary size for generating dummy tokens.
                If None, will try to get from model.vocab_size.
            opset_version (int): ONNX opset version (default: 14).
            dynamic_batch (bool): Enable dynamic batch size (default: True).
            device (str, optional): Device for export ('cpu', 'cuda', etc.).
                If None, defaults to 'cpu' for maximum compatibility.
            verbose (bool): Print export details (default: False).

        Returns:
            bool: True if export succeeded, False otherwise.

        Raises:
            ValueError: if vocab_size is not provided and cannot be inferred.
            RuntimeError: if the ONNX export itself fails.

        Example:
            >>> trainer = SeqTrainer(hstu_model, ...)
            >>> trainer.fit(train_dl, val_dl)
            >>> trainer.export_onnx("hstu.onnx", vocab_size=10000)

            >>> # Export on specific device
            >>> trainer.export_onnx("hstu.onnx", vocab_size=10000, device="cpu")
        """
        import warnings

        # Use provided device or default to 'cpu'
        export_device = device if device is not None else 'cpu'

        # Handle DataParallel wrapped model
        model = self.model.module if hasattr(self.model, 'module') else self.model
        model.eval()
        model.to(export_device)

        # Get vocab_size from model if not provided
        if vocab_size is None:
            if hasattr(model, 'vocab_size'):
                vocab_size = model.vocab_size
            elif hasattr(model, 'item_num'):
                vocab_size = model.item_num
            else:
                raise ValueError("vocab_size must be provided or model must have 'vocab_size' or 'item_num' attribute")

        # Generate dummy inputs on the export device
        dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=export_device)
        dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=export_device)

        # Configure dynamic axes
        dynamic_axes = None
        if dynamic_batch:
            dynamic_axes = {"seq_tokens": {0: "batch_size", 1: "seq_length"}, "seq_time_diffs": {0: "batch_size", 1: "seq_length"}, "output": {0: "batch_size", 1: "seq_length"}}

        # Ensure output directory exists (os is imported at module level)
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)

        try:
            with torch.no_grad():
                torch.onnx.export(
                    model,
                    (dummy_seq_tokens,
                     dummy_seq_time_diffs),
                    output_path,
                    input_names=["seq_tokens",
                                 "seq_time_diffs"],
                    output_names=["output"],
                    dynamic_axes=dynamic_axes,
                    opset_version=opset_version,
                    do_constant_folding=True,
                    verbose=verbose,
                    dynamo=False  # Use legacy exporter for dynamic_axes support
                )

            if verbose:
                print(f"Successfully exported ONNX model to: {output_path}")
                print(" Input names: ['seq_tokens', 'seq_time_diffs']")
                print(f" Vocab size: {vocab_size}")
                print(f" Opset version: {opset_version}")
                print(f" Dynamic batch: {dynamic_batch}")

            return True

        except Exception as e:
            warnings.warn(f"ONNX export failed: {str(e)}")
            raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
File without changes