nextrec 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/features.py +4 -0
- nextrec/basic/layers.py +103 -24
- nextrec/basic/metrics.py +71 -1
- nextrec/basic/model.py +285 -186
- nextrec/data/data_processing.py +1 -3
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/generative/__init__.py +16 -0
- nextrec/models/generative/hstu.py +110 -57
- nextrec/models/generative/rqvae.py +826 -0
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/config.py +38 -1
- nextrec/utils/embedding.py +28 -0
- nextrec/utils/initializer.py +4 -4
- nextrec/utils/synthetic_data.py +19 -0
- nextrec-0.4.7.dist-info/METADATA +376 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/RECORD +26 -25
- nextrec-0.4.5.dist-info/METADATA +0 -357
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/WHEEL +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.5"
+__version__ = "0.4.7"
nextrec/basic/callback.py
CHANGED
@@ -1,35 +1,413 @@
 """
-…
+Callback System for Training Process
 
 Date: create on 27/10/2025
+Checkpoint: edit on 17/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
 import copy
+import logging
+from typing import Optional
+from pathlib import Path
+import torch
+import pickle
+from nextrec import __version__
 
 
-class …
-…
+class Callback:
+    """
+    Base callback.
+
+    Notes (DDP):
+    - In distributed training, the training loop runs on every rank.
+    - For callbacks with side effects (saving, logging, etc.), set
+      ``run_on_main_process_only=True`` to avoid multi-rank duplication.
+    """
+
+    run_on_main_process_only: bool = False
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        pass
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        pass
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_validation_begin(self, logs: Optional[dict] = None):
+        pass
+
+    def on_validation_end(self, logs: Optional[dict] = None):
+        pass
+
+    def set_model(self, model):
+        self.model = model
+
+    def set_params(self, params: dict):
+        self.params = params
+
+    def should_run(self) -> bool:
+        if not getattr(self, "run_on_main_process_only", False):
+            return True
+        model = getattr(self, "model", None)
+        if model is None:
+            return True
+        return bool(getattr(model, "is_main_process", True))
+
+
+class CallbackList:
+    """Container for managing multiple callbacks."""
+
+    def __init__(self, callbacks: Optional[list[Callback]] = None):
+        self.callbacks = callbacks or []
+
+    def append(self, callback: Callback):
+        self.callbacks.append(callback)
+
+    def set_model(self, model):
+        for callback in self.callbacks:
+            callback.set_model(model)
+
+    def set_params(self, params: dict):
+        for callback in self.callbacks:
+            callback.set_params(params)
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_train_begin(logs)
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_train_end(logs)
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_epoch_begin(epoch, logs)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_epoch_end(epoch, logs)
+
+    def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_batch_begin(batch, logs)
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_batch_end(batch, logs)
+
+    def on_validation_begin(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_validation_begin(logs)
+
+    def on_validation_end(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_validation_end(logs)
+
+
+class EarlyStopper(Callback):
+
+    def __init__(
+        self,
+        monitor: str = "val_auc",
+        patience: int = 20,
+        mode: str = "max",
+        min_delta: float = 0.0,
+        restore_best_weights: bool = True,
+        verbose: int = 1,
+    ):
+        super().__init__()
+        self.monitor = monitor
         self.patience = patience
-        self.…
-        self.…
+        self.mode = mode
+        self.min_delta = abs(min_delta)
+        self.restore_best_weights = restore_best_weights
+        self.verbose = verbose
+
+        self.wait = 0
+        self.stopped_epoch = 0
         self.best_weights = None
+        self.best_epoch = 0
+
+        if mode == "min":
+            self.best_value = float("inf")
+            self.monitor_op = lambda current, best: current < (best - self.min_delta)
+        elif mode == "max":
+            self.best_value = float("-inf")
+            self.monitor_op = lambda current, best: current > (best + self.min_delta)
+        else:
+            raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.best_weights = None
+        self.best_epoch = 0
+        if self.mode == "min":
+            self.best_value = float("inf")
+        else:
+            self.best_value = float("-inf")
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        logs = logs or {}
+        current = logs.get(self.monitor)
+
+        if current is None:
+            if self.verbose > 0:
+                logging.warning(
+                    f"Early stopping conditioned on metric `{self.monitor}` "
+                    f"which is not available. Available metrics are: {','.join(list(logs.keys()))}"
+                )
+            return
+
+        if self.monitor_op(current, self.best_value):
+            self.best_value = current
+            self.best_epoch = epoch
+            self.wait = 0
+            if self.restore_best_weights:
+                self.best_weights = copy.deepcopy(self.model.state_dict())
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                self.stopped_epoch = epoch
+                if hasattr(self.model, "stop_training"):
+                    self.model.stop_training = True
+                if self.verbose > 0:
+                    logging.info(
+                        f"Early stopping triggered at epoch {epoch + 1}. "
+                        f"Best {self.monitor}: {self.best_value:.6f} at epoch {self.best_epoch + 1}"
+                    )
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        if self.restore_best_weights and self.best_weights is not None:
+            if self.verbose > 0:
+                logging.info(
+                    f"Restoring model weights from epoch {self.best_epoch + 1} "
+                    f"with best {self.monitor}: {self.best_value:.6f}"
+                )
+            self.model.load_state_dict(self.best_weights)
+
+
+class CheckpointSaver(Callback):
+    """Callback to save model checkpoints during training.
+
+    Args:
+        save_path: Path to save checkpoints.
+        monitor: Metric name to monitor for saving best model.
+        mode: One of {'min', 'max'}.
+        save_best_only: If True, only save when the model is considered the "best".
+        save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
+        verbose: Verbosity mode.
+    """
+
+    def __init__(
+        self,
+        save_path: str | Path,
+        monitor: str = "val_auc",
+        mode: str = "max",
+        save_best_only: bool = False,
+        save_freq: str | int = "epoch",
+        verbose: int = 1,
+        run_on_main_process_only: bool = True,
+    ):
+        super().__init__()
+        self.run_on_main_process_only = run_on_main_process_only
+        self.save_path = Path(save_path)
+        self.monitor = monitor
         self.mode = mode
+        self.save_best_only = save_best_only
+        self.save_freq = save_freq
+        self.verbose = verbose
 
-…
-…
-…
-…
-…
-…
-        elif self.mode == "min":
-            if val_metrics < self.best_metrics:
-                self.best_metrics = val_metrics
-                self.trial_counter = 0
-                self.best_weights = copy.deepcopy(weights)
-                return False
-            elif self.trial_counter + 1 < self.patience:
-                self.trial_counter += 1
-                return False
+        if mode == "min":
+            self.best_value = float("inf")
+            self.monitor_op = lambda current, best: current < best
+        elif mode == "max":
+            self.best_value = float("-inf")
+            self.monitor_op = lambda current, best: current > best
         else:
-            …
+            raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        if self.mode == "min":
+            self.best_value = float("inf")
+        else:
+            self.best_value = float("-inf")
+
+        # Create directory if it doesn't exist
+        self.save_path.parent.mkdir(parents=True, exist_ok=True)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        logs = logs or {}
+
+        # Check if we should save this epoch
+        should_save = False
+        if self.save_freq == "epoch":
+            should_save = True
+        elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
+            should_save = True
+
+        if not should_save and self.save_best_only:
+            should_save = False
+
+        # Check if this is the best model
+        current = logs.get(self.monitor)
+        is_best = False
+
+        if current is not None and self.monitor_op(current, self.best_value):
+            self.best_value = current
+            is_best = True
+            should_save = True
+
+        if should_save:
+            if not self.save_best_only or is_best:
+                checkpoint_path = (
+                    self.save_path.parent
+                    / f"{self.save_path.stem}_epoch_{epoch + 1}{self.save_path.suffix}"
+                )
+                self.save_checkpoint(checkpoint_path, epoch, logs)
+
+            if is_best:
+                # Use save_path directly without adding _best suffix since it may already contain it
+                self.save_checkpoint(self.save_path, epoch, logs)
+                if self.verbose > 0:
+                    logging.info(
+                        f"Saved best model to {self.save_path} with {self.monitor}: {current:.6f}"
+                    )
+
+    def save_checkpoint(self, path: Path, epoch: int, logs: dict):
+
+        # Get the actual model (unwrap DDP if needed)
+        model_to_save = (
+            self.model.ddp_model.module
+            if getattr(self.model, "ddp_model", None) is not None
+            else self.model
+        )
+
+        # Save only state_dict to match BaseModel.save_model() format
+        torch.save(model_to_save.state_dict(), path)
+
+        # Also save features_config.pkl if it doesn't exist
+        config_path = path.parent / "features_config.pkl"
+        if not config_path.exists():
+            features_config = {
+                "all_features": self.model.all_features,
+                "target": self.model.target_columns,
+                "id_columns": self.model.id_columns,
+                "version": __version__,
+            }
+            with open(config_path, "wb") as f:
+                pickle.dump(features_config, f)
+
+        if self.verbose > 1:
+            logging.info(f"Saved checkpoint to {path}")
+
+
+class LearningRateScheduler(Callback):
+    """Callback for learning rate scheduling.
+
+    Args:
+        scheduler: Learning rate scheduler instance or name.
+        verbose: Verbosity mode.
+    """
+
+    def __init__(self, scheduler=None, verbose: int = 0):
+        super().__init__()
+        self.scheduler = scheduler
+        self.verbose = verbose
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        if self.scheduler is None and hasattr(self.model, "scheduler_fn"):
+            self.scheduler = self.model.scheduler_fn
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        if self.scheduler is not None:
+            # Get current lr before step
+            if hasattr(self.model, "optimizer_fn"):
+                old_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+
+            # Step the scheduler
+            if hasattr(self.scheduler, "step"):
+                # Some schedulers need metrics
+                if "val_loss" in (logs or {}) and hasattr(self.scheduler, "mode"):
+                    self.scheduler.step(logs["val_loss"])
+                else:
+                    self.scheduler.step()
+
+            # Log new lr
+            if self.verbose > 0 and hasattr(self.model, "optimizer_fn"):
+                if getattr(self.model, "is_main_process", True):
+                    new_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+                    if new_lr != old_lr:
+                        logging.info(
+                            f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
+                        )
+
+
+class MetricsLogger(Callback):
+    """Callback for logging training metrics.
+
+    Args:
+        log_freq: Frequency of logging ('epoch', 'batch', or integer for every N epochs/batches).
+        verbose: Verbosity mode.
+    """
+
+    def __init__(self, log_freq: str | int = "epoch", verbose: int = 1):
+        super().__init__()
+        self.run_on_main_process_only = True
+        self.log_freq = log_freq
+        self.verbose = verbose
+        self.batch_count = 0
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        if self.verbose > 0 and (
+            self.log_freq == "epoch"
+            or (isinstance(self.log_freq, int) and (epoch + 1) % self.log_freq == 0)
+        ):
+            logs = logs or {}
+            metrics_str = " - ".join(
+                [
+                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                    for k, v in logs.items()
+                ]
+            )
+            logging.info(f"Epoch {epoch + 1}: {metrics_str}")
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        self.batch_count += 1
+        if self.verbose > 1 and self.log_freq == "batch":
+            logs = logs or {}
+            metrics_str = " - ".join(
+                [
+                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                    for k, v in logs.items()
+                ]
+            )
+            logging.info(f"Batch {batch}: {metrics_str}")
nextrec/basic/features.py
CHANGED
@@ -33,6 +33,8 @@ class SequenceFeature(BaseFeature):
         l1_reg: float = 0.0,
         l2_reg: float = 1e-5,
         trainable: bool = True,
+        pretrained_weight: torch.Tensor | None = None,
+        freeze_pretrained: bool = False,
     ):
         self.name = name
         self.vocab_size = vocab_size
@@ -47,6 +49,8 @@ class SequenceFeature(BaseFeature):
         self.l1_reg = l1_reg
         self.l2_reg = l2_reg
         self.trainable = trainable
+        self.pretrained_weight = pretrained_weight
+        self.freeze_pretrained = freeze_pretrained
 
 
 class SparseFeature(BaseFeature):
nextrec/basic/layers.py
CHANGED
@@ -496,12 +496,18 @@ class HadamardInteractionLayer(nn.Module):
 
 
 class MultiHeadSelfAttention(nn.Module):
+    """
+    Multi-Head Self-Attention layer with Flash Attention support.
+    Uses PyTorch 2.0+ scaled_dot_product_attention when available for better performance.
+    """
+
     def __init__(
         self,
         embedding_dim: int,
         num_heads: int = 2,
         dropout: float = 0.0,
         use_residual: bool = True,
+        use_layer_norm: bool = False,
     ):
         super().__init__()
         if embedding_dim % num_heads != 0:
@@ -512,45 +518,100 @@ class MultiHeadSelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
+        self.dropout_rate = dropout
+
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        self.W_O = nn.Linear(embedding_dim, embedding_dim, bias=False)
+
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        if use_layer_norm:
+            self.layer_norm = nn.LayerNorm(embedding_dim)
+        else:
+            self.layer_norm = None
+
         self.dropout = nn.Dropout(dropout)
+        # Check if Flash Attention is available
+        self.use_flash_attention = hasattr(F, "scaled_dot_product_attention")
 
-    def forward(
-…
-…
+    def forward(
+        self, x: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: [batch_size, seq_len, embedding_dim]
+            attention_mask: [batch_size, seq_len] or [batch_size, seq_len, seq_len], boolean mask where True indicates valid positions
+        Returns:
+            output: [batch_size, seq_len, embedding_dim]
+        """
+        batch_size, seq_len, _ = x.shape
+        Q = self.W_Q(x)  # [batch_size, seq_len, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-…
-…
-…
-        )
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+
+        # Split into multiple heads: [batch_size, num_heads, seq_len, head_dim]
+        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if self.use_flash_attention:
+            # Use PyTorch 2.0+ Flash Attention
+            if attention_mask is not None:
+                # Convert mask to [batch_size, 1, seq_len, seq_len] format
+                if attention_mask.dim() == 2:
+                    # [B, L] -> [B, 1, 1, L]
+                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                elif attention_mask.dim() == 3:
+                    # [B, L, L] -> [B, 1, L, L]
+                    attention_mask = attention_mask.unsqueeze(1)
+            attention_output = F.scaled_dot_product_attention(
+                Q,
+                K,
+                V,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout_rate if self.training else 0.0,
+            )
+            # Handle potential NaN values
+            attention_output = torch.nan_to_num(attention_output, nan=0.0)
+        else:
+            # Fallback to standard attention
+            scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+
+            if attention_mask is not None:
+                # Process mask for standard attention
+                if attention_mask.dim() == 2:
+                    # [B, L] -> [B, 1, 1, L]
+                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                elif attention_mask.dim() == 3:
+                    # [B, L, L] -> [B, 1, L, L]
+                    attention_mask = attention_mask.unsqueeze(1)
+                scores = scores.masked_fill(~attention_mask, float("-1e9"))
+
+            attention_weights = F.softmax(scores, dim=-1)
+            attention_weights = self.dropout(attention_weights)
+            attention_output = torch.matmul(
+                attention_weights, V
+            )  # [batch_size, num_heads, seq_len, head_dim]
+
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
         attention_output = attention_output.view(
-            batch_size, …
+            batch_size, seq_len, self.embedding_dim
         )
+
+        # Output projection
+        output = self.W_O(attention_output)
+
         # Residual connection
         if self.use_residual:
-            output = …
-…
-…
+            output = output + self.W_Res(x)
+
+        # Layer normalization
+        if self.layer_norm is not None:
+            output = self.layer_norm(output)
+
         output = F.relu(output)
         return output
 
@@ -653,3 +714,21 @@ class AttentionPoolingLayer(nn.Module):
         # Weighted sum over keys: (B, L, 1) * (B, L, D) -> (B, D)
         output = torch.sum(attention_weights * keys, dim=1)
         return output
+
+
+class RMSNorm(torch.nn.Module):
+    """
+    Root Mean Square Layer Normalization.
+    Reference: https://arxiv.org/abs/1910.07467
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # RMS(x) = sqrt(mean(x^2) + eps)
+        variance = torch.mean(x**2, dim=-1, keepdim=True)
+        x_normalized = x * torch.rsqrt(variance + self.eps)
+        return self.weight * x_normalized
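A short usage sketch for the two additions in this file: MultiHeadSelfAttention (now taking a boolean attention mask, an optional layer norm, and a scaled_dot_product_attention fast path) and RMSNorm. The tensor shapes and hyperparameters below are illustrative, not defaults from the package.

import torch

from nextrec.basic.layers import MultiHeadSelfAttention, RMSNorm

batch_size, seq_len, dim = 4, 20, 64
x = torch.randn(batch_size, seq_len, dim)

# Boolean padding mask: True marks valid positions, as forward() expects.
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
mask[:, 15:] = False   # pretend the last five positions are padding

attn = MultiHeadSelfAttention(embedding_dim=dim, num_heads=4, dropout=0.1, use_layer_norm=True)
out = attn(x, attention_mask=mask)   # [4, 20, 64]

norm = RMSNorm(hidden_size=dim)
out = norm(out)                      # same shape, RMS-normalized over the last dimension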
|