torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +3 -1
- torch_rechub/basic/callback.py +2 -2
- torch_rechub/basic/features.py +38 -8
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +800 -46
- torch_rechub/basic/loss_func.py +223 -0
- torch_rechub/basic/metaoptimizer.py +76 -0
- torch_rechub/basic/metric.py +251 -0
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -0
- torch_rechub/models/matching/comirec.py +193 -0
- torch_rechub/models/matching/dssm.py +72 -0
- torch_rechub/models/matching/dssm_facebook.py +77 -0
- torch_rechub/models/matching/dssm_senet.py +87 -0
- torch_rechub/models/matching/gru4rec.py +85 -0
- torch_rechub/models/matching/mind.py +103 -0
- torch_rechub/models/matching/narm.py +82 -0
- torch_rechub/models/matching/sasrec.py +143 -0
- torch_rechub/models/matching/sine.py +148 -0
- torch_rechub/models/matching/stamp.py +81 -0
- torch_rechub/models/matching/youtube_dnn.py +75 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -2
- torch_rechub/models/multi_task/aitm.py +83 -0
- torch_rechub/models/multi_task/esmm.py +19 -8
- torch_rechub/models/multi_task/mmoe.py +18 -12
- torch_rechub/models/multi_task/ple.py +41 -29
- torch_rechub/models/multi_task/shared_bottom.py +3 -2
- torch_rechub/models/ranking/__init__.py +13 -2
- torch_rechub/models/ranking/afm.py +65 -0
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +59 -0
- torch_rechub/models/ranking/deepffm.py +131 -0
- torch_rechub/models/ranking/deepfm.py +8 -7
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +31 -19
- torch_rechub/models/ranking/edcn.py +101 -0
- torch_rechub/models/ranking/fibinet.py +42 -0
- torch_rechub/models/ranking/widedeep.py +6 -6
- torch_rechub/trainers/__init__.py +4 -2
- torch_rechub/trainers/ctr_trainer.py +191 -0
- torch_rechub/trainers/match_trainer.py +239 -0
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +137 -23
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +492 -0
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -0
- torch_rechub/utils/mtl.py +136 -0
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub/trainers/trainer.py +0 -111
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
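The hunk below expands the largest single addition, torch_rechub/trainers/seq_trainer.py. As a side note, the file-level summary above can be reproduced locally with the standard library alone, since wheels are ordinary zip archives; the sketch below is illustrative, assumes both wheel files have already been downloaded, and uses assumed local filenames.

import zipfile

# Assumed local filenames; adjust to wherever the wheels were downloaded.
OLD_WHL = "torch_rechub-0.0.1-py3-none-any.whl"
NEW_WHL = "torch_rechub-0.0.4-py3-none-any.whl"

with zipfile.ZipFile(OLD_WHL) as old_whl, zipfile.ZipFile(NEW_WHL) as new_whl:
    old_files = set(old_whl.namelist())
    new_files = set(new_whl.namelist())

# Files only in 0.0.4, only in 0.0.1, and present in both (candidates for a content diff).
print("added:  ", sorted(new_files - old_files))
print("removed:", sorted(old_files - new_files))
print("in both:", len(old_files & new_files))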
@@ -0,0 +1,293 @@
+"""Sequence Generation Model Trainer."""
+
+import os
+
+import torch
+import torch.nn as nn
+import tqdm
+
+from ..basic.callback import EarlyStopper
+from ..basic.loss_func import NCELoss
+
+
+class SeqTrainer(object):
+    """Trainer for sequence generation models.
+
+    Used to train sequence generation models such as HSTU.
+    Supports the CrossEntropyLoss loss function and generative evaluation metrics.
+
+    Args:
+        model (nn.Module): the model to train.
+        optimizer_fn (torch.optim): optimizer class, defaults to torch.optim.Adam.
+        optimizer_params (dict): optimizer parameters.
+        scheduler_fn (torch.optim.lr_scheduler): torch learning-rate scheduler class.
+        scheduler_params (dict): scheduler parameters.
+        n_epoch (int): number of training epochs, defaults to 10.
+        earlystop_patience (int): early-stopping patience, defaults to 10.
+        device (str): device, 'cpu' or 'cuda', defaults to 'cpu'.
+        gpus (list): list of GPU ids for multi-GPU training, defaults to [].
+        model_path (str): path where the model is saved, defaults to './'.
+        loss_type (str): loss type, 'cross_entropy' (default) or 'nce'.
+        loss_params (dict): keyword arguments passed to the loss function.
+
+    Methods:
+        fit: train the model.
+        evaluate: evaluate the model.
+
+    Example:
+        >>> trainer = SeqTrainer(
+        ...     model=model,
+        ...     optimizer_fn=torch.optim.Adam,
+        ...     optimizer_params={'lr': 1e-3, 'weight_decay': 1e-5},
+        ...     device='cuda'
+        ... )
+        >>> trainer.fit(
+        ...     train_dataloader=train_loader,
+        ...     val_dataloader=val_loader
+        ... )
+    """
+
+    def __init__(self, model, optimizer_fn=torch.optim.Adam, optimizer_params=None, scheduler_fn=None, scheduler_params=None, n_epoch=10, earlystop_patience=10, device='cpu', gpus=None, model_path='./', loss_type='cross_entropy', loss_params=None):
+        self.model = model  # for uniform weights save method in one gpu or multi gpu
+        if gpus is None:
+            gpus = []
+        self.gpus = gpus
+        if len(gpus) > 1:
+            print('parallel running on these gpus:', gpus)
+            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+        # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(device)
+        self.model.to(self.device)
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
+        self.scheduler = None
+        if scheduler_fn is not None:
+            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+
+        # Loss function
+        if loss_type == 'nce':
+            if loss_params is None:
+                loss_params = {"temperature": 0.1, "ignore_index": 0}
+            self.loss_fn = NCELoss(**loss_params)
+        else:  # default to cross_entropy
+            if loss_params is None:
+                loss_params = {"ignore_index": 0}
+            self.loss_fn = nn.CrossEntropyLoss(**loss_params)
+
+        self.n_epoch = n_epoch
+        self.early_stopper = EarlyStopper(patience=earlystop_patience)
+        self.model_path = model_path
+
+    def fit(self, train_dataloader, val_dataloader=None):
+        """Train the model.
+
+        Args:
+            train_dataloader (DataLoader): training data loader.
+            val_dataloader (DataLoader): validation data loader.
+
+        Returns:
+            dict: training history.
+        """
+        history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
+
+        for epoch_i in range(self.n_epoch):
+            print('epoch:', epoch_i)
+            # Training phase
+            self.train_one_epoch(train_dataloader)
+            if self.scheduler is not None:
+                if epoch_i % self.scheduler.step_size == 0:
+                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                self.scheduler.step()  # update lr in epoch level by scheduler
+
+            # Validation phase
+            if val_dataloader:
+                val_loss, val_accuracy = self.evaluate(val_dataloader)
+                history['val_loss'].append(val_loss)
+                history['val_accuracy'].append(val_accuracy)
+
+                print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")
+
+                # Early stopping
+                if self.early_stopper.stop_training(val_accuracy, self.model.state_dict()):
+                    print(f'validation: best accuracy: {self.early_stopper.best_auc}')
+                    self.model.load_state_dict(self.early_stopper.best_weights)
+                    break
+
+        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best model
+        return history
+
+    def train_one_epoch(self, data_loader, log_interval=10):
+        """Train the model for a single epoch.
+
+        Args:
+            data_loader (DataLoader): Training data loader.
+            log_interval (int): Interval (in steps) for logging the running average loss.
+        """
+        self.model.train()
+        total_loss = 0
+        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+        for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
+            # Move tensors to the target device
+            seq_tokens = seq_tokens.to(self.device)
+            seq_positions = seq_positions.to(self.device)
+            seq_time_diffs = seq_time_diffs.to(self.device)
+            targets = targets.to(self.device).squeeze(-1)
+
+            # Forward pass
+            logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)
+
+            # Compute loss
+            # For next-item prediction we only use the last position in the sequence
+            # logits[:, -1, :] selects the prediction at the last step for each sequence
+            last_logits = logits[:, -1, :]  # (B, V)
+
+            loss = self.loss_fn(last_logits, targets)
+
+            # Backward pass
+            self.model.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            total_loss += loss.item()
+            if (i + 1) % log_interval == 0:
+                tk0.set_postfix(loss=total_loss / log_interval)
+                total_loss = 0
+
+    def evaluate(self, data_loader):
+        """Evaluate the model on a validation/test data loader.
+
+        Args:
+            data_loader (DataLoader): Validation or test data loader.
+
+        Returns:
+            tuple: ``(avg_loss, top1_accuracy)``.
+        """
+        self.model.eval()
+        total_loss = 0.0
+        total_correct = 0
+        total_samples = 0
+
+        with torch.no_grad():
+            for seq_tokens, seq_positions, seq_time_diffs, targets in tqdm.tqdm(data_loader, desc="evaluating", smoothing=0, mininterval=1.0):
+                # Move tensors to the target device
+                seq_tokens = seq_tokens.to(self.device)
+                seq_positions = seq_positions.to(self.device)
+                seq_time_diffs = seq_time_diffs.to(self.device)
+                targets = targets.to(self.device).squeeze(-1)
+
+                # Forward pass
+                logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)
+
+                # Compute loss using only the last position (next-item prediction)
+                last_logits = logits[:, -1, :]  # (B, V)
+
+                loss = self.loss_fn(last_logits, targets)
+                total_loss += loss.item()
+
+                # Compute top-1 accuracy
+                predictions = torch.argmax(last_logits, dim=-1)  # (B,)
+                correct = (predictions == targets).sum().item()
+                total_correct += correct
+                total_samples += targets.numel()
+
+        avg_loss = total_loss / len(data_loader)
+        accuracy = total_correct / total_samples
+
+        return avg_loss, accuracy
+
+    def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+        """Export the trained sequence generation model to ONNX format.
+
+        This method exports sequence generation models (e.g., HSTU) to ONNX format.
+        Unlike other trainers, sequence models use positional arguments (seq_tokens, seq_time_diffs)
+        instead of dict input, making ONNX export more straightforward.
+
+        Args:
+            output_path (str): Path to save the ONNX model file.
+            batch_size (int): Batch size for dummy input (default: 2).
+            seq_length (int): Sequence length for dummy input (default: 50).
+            vocab_size (int, optional): Vocabulary size for generating dummy tokens.
+                If None, will try to get from model.vocab_size.
+            opset_version (int): ONNX opset version (default: 14).
+            dynamic_batch (bool): Enable dynamic batch size (default: True).
+            device (str, optional): Device for export ('cpu', 'cuda', etc.).
+                If None, defaults to 'cpu' for maximum compatibility.
+            verbose (bool): Print export details (default: False).
+
+        Returns:
+            bool: True if the export succeeded; a RuntimeError is raised on failure.
+
+        Example:
+            >>> trainer = SeqTrainer(hstu_model, ...)
+            >>> trainer.fit(train_dl, val_dl)
+            >>> trainer.export_onnx("hstu.onnx", vocab_size=10000)
+
+            >>> # Export on specific device
+            >>> trainer.export_onnx("hstu.onnx", vocab_size=10000, device="cpu")
+        """
+        import warnings
+
+        # Use provided device or default to 'cpu'
+        export_device = device if device is not None else 'cpu'
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+        model.eval()
+        model.to(export_device)
+
+        # Get vocab_size from model if not provided
+        if vocab_size is None:
+            if hasattr(model, 'vocab_size'):
+                vocab_size = model.vocab_size
+            elif hasattr(model, 'item_num'):
+                vocab_size = model.item_num
+            else:
+                raise ValueError("vocab_size must be provided or model must have 'vocab_size' or 'item_num' attribute")
+
+        # Generate dummy inputs on the export device
+        dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=export_device)
+        dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=export_device)
+
+        # Configure dynamic axes
+        dynamic_axes = None
+        if dynamic_batch:
+            dynamic_axes = {"seq_tokens": {0: "batch_size", 1: "seq_length"}, "seq_time_diffs": {0: "batch_size", 1: "seq_length"}, "output": {0: "batch_size", 1: "seq_length"}}
+
+        # Ensure output directory exists
+        output_dir = os.path.dirname(output_path)
+        if output_dir and not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        try:
+            with torch.no_grad():
+                torch.onnx.export(
+                    model,
+                    (dummy_seq_tokens, dummy_seq_time_diffs),
+                    output_path,
+                    input_names=["seq_tokens", "seq_time_diffs"],
+                    output_names=["output"],
+                    dynamic_axes=dynamic_axes,
+                    opset_version=opset_version,
+                    do_constant_folding=True,
+                    verbose=verbose,
+                    dynamo=False  # use the legacy exporter for dynamic_axes support
+                )
+
+            if verbose:
+                print(f"Successfully exported ONNX model to: {output_path}")
+                print(" Input names: ['seq_tokens', 'seq_time_diffs']")
+                print(f" Vocab size: {vocab_size}")
+                print(f" Opset version: {opset_version}")
+                print(f" Dynamic batch: {dynamic_batch}")
+
+            return True
+
+        except Exception as e:
+            warnings.warn(f"ONNX export failed: {str(e)}")
+            raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
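To make the new trainer's expected inputs concrete, here is a minimal, self-contained sketch of driving SeqTrainer end to end. The ToyModel and the synthetic batches are illustrative stand-ins, not part of torch-rechub: any module whose forward(seq_tokens, seq_time_diffs) returns (batch, seq_len, vocab) logits, fed by batches that unpack as (seq_tokens, seq_positions, seq_time_diffs, targets), fits the training loop shown above. The final export_onnx call assumes a torch version whose torch.onnx.export accepts the dynamo flag used by the trainer.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from torch_rechub.trainers.seq_trainer import SeqTrainer

VOCAB, SEQ_LEN = 1000, 20

class ToyModel(nn.Module):
    """Stand-in model: forward(seq_tokens, seq_time_diffs) -> (B, L, V) logits."""

    def __init__(self, vocab_size, dim=32):
        super().__init__()
        self.vocab_size = vocab_size  # read by export_onnx when vocab_size is not given
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, seq_tokens, seq_time_diffs):
        x = self.emb(seq_tokens) + seq_time_diffs.unsqueeze(-1)
        return self.head(x)  # (B, L, V)

# Synthetic batches in the layout the trainer unpacks:
# (seq_tokens, seq_positions, seq_time_diffs, targets)
tokens = torch.randint(1, VOCAB, (256, SEQ_LEN))
positions = torch.arange(SEQ_LEN).repeat(256, 1)
time_diffs = torch.zeros(256, SEQ_LEN)
targets = torch.randint(1, VOCAB, (256, 1))
loader = DataLoader(TensorDataset(tokens, positions, time_diffs, targets), batch_size=32)

trainer = SeqTrainer(ToyModel(VOCAB), n_epoch=1, device="cpu")
trainer.fit(loader, val_dataloader=loader)
trainer.export_onnx("toy_seq_model.onnx", seq_length=SEQ_LEN)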