nextrec-0.4.6-py3-none-any.whl → nextrec-0.4.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/model.py +289 -173
- nextrec/cli.py +27 -1
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- nextrec/utils/cli_utils.py +58 -0
- nextrec/utils/config.py +5 -4
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/METADATA +32 -26
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/RECORD +17 -16
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/WHEEL +0 -0
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.6"
+__version__ = "0.4.8"
nextrec/basic/callback.py
CHANGED
@@ -1,35 +1,413 @@
 """
-
+Callback System for Training Process

 Date: create on 27/10/2025
+Checkpoint: edit on 17/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """

 import copy
+import logging
+from typing import Optional
+from pathlib import Path
+import torch
+import pickle
+from nextrec import __version__


-class
-
+class Callback:
+    """
+    Base callback.
+
+    Notes (DDP):
+    - In distributed training, the training loop runs on every rank.
+    - For callbacks with side effects (saving, logging, etc.), set
+      ``run_on_main_process_only=True`` to avoid multi-rank duplication.
+    """
+
+    run_on_main_process_only: bool = False
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        pass
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        pass
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_validation_begin(self, logs: Optional[dict] = None):
+        pass
+
+    def on_validation_end(self, logs: Optional[dict] = None):
+        pass
+
+    def set_model(self, model):
+        self.model = model
+
+    def set_params(self, params: dict):
+        self.params = params
+
+    def should_run(self) -> bool:
+        if not getattr(self, "run_on_main_process_only", False):
+            return True
+        model = getattr(self, "model", None)
+        if model is None:
+            return True
+        return bool(getattr(model, "is_main_process", True))
+
+
+class CallbackList:
+    """Container for managing multiple callbacks."""
+
+    def __init__(self, callbacks: Optional[list[Callback]] = None):
+        self.callbacks = callbacks or []
+
+    def append(self, callback: Callback):
+        self.callbacks.append(callback)
+
+    def set_model(self, model):
+        for callback in self.callbacks:
+            callback.set_model(model)
+
+    def set_params(self, params: dict):
+        for callback in self.callbacks:
+            callback.set_params(params)
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_train_begin(logs)
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_train_end(logs)
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_epoch_begin(epoch, logs)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_epoch_end(epoch, logs)
+
+    def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_batch_begin(batch, logs)
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_batch_end(batch, logs)
+
+    def on_validation_begin(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_validation_begin(logs)
+
+    def on_validation_end(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_validation_end(logs)
+
+
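The hook surface above is Keras-style, but this diff does not show how nextrec's trainer drives it. The sketch below wires a hypothetical callback through CallbackList by hand; TimedEpochs and the SimpleNamespace stand-in for the model are illustrative, not part of the package.

import logging
import time
from types import SimpleNamespace
from typing import Optional

from nextrec.basic.callback import Callback, CallbackList


class TimedEpochs(Callback):
    # Side-effecting callback, so restrict it to rank 0 under DDP
    run_on_main_process_only = True

    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
        self.start = time.monotonic()

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
        logging.info(f"epoch {epoch + 1} took {time.monotonic() - self.start:.1f}s")


callbacks = CallbackList([TimedEpochs()])
# should_run() consults model.is_main_process, so the stub carries that flag
callbacks.set_model(SimpleNamespace(is_main_process=True))
callbacks.on_epoch_begin(0)
callbacks.on_epoch_end(0, logs={"loss": 0.42})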
+class EarlyStopper(Callback):
+
+    def __init__(
+        self,
+        monitor: str = "val_auc",
+        patience: int = 20,
+        mode: str = "max",
+        min_delta: float = 0.0,
+        restore_best_weights: bool = True,
+        verbose: int = 1,
+    ):
+        super().__init__()
+        self.monitor = monitor
         self.patience = patience
-        self.
-        self.
+        self.mode = mode
+        self.min_delta = abs(min_delta)
+        self.restore_best_weights = restore_best_weights
+        self.verbose = verbose
+
+        self.wait = 0
+        self.stopped_epoch = 0
         self.best_weights = None
+        self.best_epoch = 0
+
+        if mode == "min":
+            self.best_value = float("inf")
+            self.monitor_op = lambda current, best: current < (best - self.min_delta)
+        elif mode == "max":
+            self.best_value = float("-inf")
+            self.monitor_op = lambda current, best: current > (best + self.min_delta)
+        else:
+            raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.best_weights = None
+        self.best_epoch = 0
+        if self.mode == "min":
+            self.best_value = float("inf")
+        else:
+            self.best_value = float("-inf")
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        logs = logs or {}
+        current = logs.get(self.monitor)
+
+        if current is None:
+            if self.verbose > 0:
+                logging.warning(
+                    f"Early stopping conditioned on metric `{self.monitor}` "
+                    f"which is not available. Available metrics are: {','.join(list(logs.keys()))}"
+                )
+            return
+
+        if self.monitor_op(current, self.best_value):
+            self.best_value = current
+            self.best_epoch = epoch
+            self.wait = 0
+            if self.restore_best_weights:
+                self.best_weights = copy.deepcopy(self.model.state_dict())
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                self.stopped_epoch = epoch
+                if hasattr(self.model, "stop_training"):
+                    self.model.stop_training = True
+                if self.verbose > 0:
+                    logging.info(
+                        f"Early stopping triggered at epoch {epoch + 1}. "
+                        f"Best {self.monitor}: {self.best_value:.6f} at epoch {self.best_epoch + 1}"
+                    )
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        if self.restore_best_weights and self.best_weights is not None:
+            if self.verbose > 0:
+                logging.info(
+                    f"Restoring model weights from epoch {self.best_epoch + 1} "
+                    f"with best {self.monitor}: {self.best_value:.6f}"
+                )
+            self.model.load_state_dict(self.best_weights)
+
+
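EarlyStopper's contract is easiest to see by invoking its hooks manually. A minimal walk-through with a toy torch module and made-up metric values (in real use the training loop calls these hooks and polls stop_training):

import torch
from nextrec.basic.callback import EarlyStopper

model = torch.nn.Linear(4, 1)
model.stop_training = False  # flag the training loop is expected to poll

stopper = EarlyStopper(monitor="val_auc", patience=2, mode="max", min_delta=1e-4)
stopper.set_model(model)
stopper.on_train_begin()

for epoch, auc in enumerate([0.70, 0.75, 0.75, 0.75]):  # improves once, then plateaus
    stopper.on_epoch_end(epoch, logs={"val_auc": auc})
    if model.stop_training:  # set once wait >= patience
        break

stopper.on_train_end()  # restores the weights snapshotted at the best epoch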
+class CheckpointSaver(Callback):
+    """Callback to save model checkpoints during training.
+
+    Args:
+        save_path: Path to save checkpoints.
+        monitor: Metric name to monitor for saving best model.
+        mode: One of {'min', 'max'}.
+        save_best_only: If True, only save when the model is considered the "best".
+        save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
+        verbose: Verbosity mode.
+    """
+
+    def __init__(
+        self,
+        save_path: str | Path,
+        monitor: str = "val_auc",
+        mode: str = "max",
+        save_best_only: bool = False,
+        save_freq: str | int = "epoch",
+        verbose: int = 1,
+        run_on_main_process_only: bool = True,
+    ):
+        super().__init__()
+        self.run_on_main_process_only = run_on_main_process_only
+        self.save_path = Path(save_path)
+        self.monitor = monitor
         self.mode = mode
+        self.save_best_only = save_best_only
+        self.save_freq = save_freq
+        self.verbose = verbose

-
-
-
-
-
-
-        elif self.mode == "min":
-            if val_metrics < self.best_metrics:
-                self.best_metrics = val_metrics
-                self.trial_counter = 0
-                self.best_weights = copy.deepcopy(weights)
-                return False
-            elif self.trial_counter + 1 < self.patience:
-                self.trial_counter += 1
-                return False
+        if mode == "min":
+            self.best_value = float("inf")
+            self.monitor_op = lambda current, best: current < best
+        elif mode == "max":
+            self.best_value = float("-inf")
+            self.monitor_op = lambda current, best: current > best
         else:
-
+            raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        if self.mode == "min":
+            self.best_value = float("inf")
+        else:
+            self.best_value = float("-inf")
+
+        # Create directory if it doesn't exist
+        self.save_path.parent.mkdir(parents=True, exist_ok=True)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        logs = logs or {}
+
+        # Check if we should save this epoch
+        should_save = False
+        if self.save_freq == "epoch":
+            should_save = True
+        elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
+            should_save = True
+
+        if not should_save and self.save_best_only:
+            should_save = False
+
+        # Check if this is the best model
+        current = logs.get(self.monitor)
+        is_best = False
+
+        if current is not None and self.monitor_op(current, self.best_value):
+            self.best_value = current
+            is_best = True
+            should_save = True
+
+        if should_save:
+            if not self.save_best_only or is_best:
+                checkpoint_path = (
+                    self.save_path.parent
+                    / f"{self.save_path.stem}_epoch_{epoch + 1}{self.save_path.suffix}"
+                )
+                self.save_checkpoint(checkpoint_path, epoch, logs)
+
+            if is_best:
+                # Use save_path directly without adding _best suffix since it may already contain it
+                self.save_checkpoint(self.save_path, epoch, logs)
+                if self.verbose > 0:
+                    logging.info(
+                        f"Saved best model to {self.save_path} with {self.monitor}: {current:.6f}"
+                    )
+
+    def save_checkpoint(self, path: Path, epoch: int, logs: dict):
+
+        # Get the actual model (unwrap DDP if needed)
+        model_to_save = (
+            self.model.ddp_model.module
+            if getattr(self.model, "ddp_model", None) is not None
+            else self.model
+        )
+
+        # Save only state_dict to match BaseModel.save_model() format
+        torch.save(model_to_save.state_dict(), path)
+
+        # Also save features_config.pkl if it doesn't exist
+        config_path = path.parent / "features_config.pkl"
+        if not config_path.exists():
+            features_config = {
+                "all_features": self.model.all_features,
+                "target": self.model.target_columns,
+                "id_columns": self.model.id_columns,
+                "version": __version__,
+            }
+            with open(config_path, "wb") as f:
+                pickle.dump(features_config, f)
+
+        if self.verbose > 1:
+            logging.info(f"Saved checkpoint to {path}")
+
+
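Driven the same way, CheckpointSaver behaves as follows. The attribute stubs exist only because save_checkpoint() writes feature metadata that a real nextrec model carries; paths and metric values are made up:

import tempfile
from pathlib import Path

import torch
from nextrec.basic.callback import CheckpointSaver

model = torch.nn.Linear(4, 1)
# save_checkpoint() writes these into features_config.pkl; stubbed for the demo
model.all_features, model.target_columns, model.id_columns = [], ["label"], []

save_dir = Path(tempfile.mkdtemp())
saver = CheckpointSaver(save_dir / "best.pt", monitor="val_auc", save_best_only=True)
saver.set_model(model)
saver.on_train_begin()
saver.on_epoch_end(0, logs={"val_auc": 0.71})  # new best: writes best_epoch_1.pt and best.pt
saver.on_epoch_end(1, logs={"val_auc": 0.69})  # worse: nothing is written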
+class LearningRateScheduler(Callback):
+    """Callback for learning rate scheduling.
+
+    Args:
+        scheduler: Learning rate scheduler instance or name.
+        verbose: Verbosity mode.
+    """
+
+    def __init__(self, scheduler=None, verbose: int = 0):
+        super().__init__()
+        self.scheduler = scheduler
+        self.verbose = verbose
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        if self.scheduler is None and hasattr(self.model, "scheduler_fn"):
+            self.scheduler = self.model.scheduler_fn
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        if self.scheduler is not None:
+            # Get current lr before step
+            if hasattr(self.model, "optimizer_fn"):
+                old_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+
+            # Step the scheduler
+            if hasattr(self.scheduler, "step"):
+                # Some schedulers need metrics
+                if "val_loss" in (logs or {}) and hasattr(self.scheduler, "mode"):
+                    self.scheduler.step(logs["val_loss"])
+                else:
+                    self.scheduler.step()
+
+            # Log new lr
+            if self.verbose > 0 and hasattr(self.model, "optimizer_fn"):
+                if getattr(self.model, "is_main_process", True):
+                    new_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+                    if new_lr != old_lr:
+                        logging.info(
+                            f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
+                        )
+
+
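LearningRateScheduler touches only scheduler.step() plus optimizer_fn and is_main_process on the model, so a stub model is enough to watch it drive a standard torch scheduler. Note that the hasattr(scheduler, "mode") check above routes metric-driven schedulers such as ReduceLROnPlateau (which has a mode attribute) to step(val_loss):

import torch
from types import SimpleNamespace
from nextrec.basic.callback import LearningRateScheduler

net = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

cb = LearningRateScheduler(scheduler=scheduler, verbose=1)
cb.set_model(SimpleNamespace(optimizer_fn=optimizer, is_main_process=True))
cb.on_epoch_end(0)  # no val_loss in logs -> plain step(); lr 1.0e-01 -> 5.0e-02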
+class MetricsLogger(Callback):
+    """Callback for logging training metrics.
+
+    Args:
+        log_freq: Frequency of logging ('epoch', 'batch', or integer for every N epochs/batches).
+        verbose: Verbosity mode.
+    """
+
+    def __init__(self, log_freq: str | int = "epoch", verbose: int = 1):
+        super().__init__()
+        self.run_on_main_process_only = True
+        self.log_freq = log_freq
+        self.verbose = verbose
+        self.batch_count = 0
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        if self.verbose > 0 and (
+            self.log_freq == "epoch"
+            or (isinstance(self.log_freq, int) and (epoch + 1) % self.log_freq == 0)
+        ):
+            logs = logs or {}
+            metrics_str = " - ".join(
+                [
+                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                    for k, v in logs.items()
+                ]
+            )
+            logging.info(f"Epoch {epoch + 1}: {metrics_str}")
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        self.batch_count += 1
+        if self.verbose > 1 and self.log_freq == "batch":
+            logs = logs or {}
+            metrics_str = " - ".join(
+                [
+                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                    for k, v in logs.items()
+                ]
+            )
+            logging.info(f"Batch {batch}: {metrics_str}")