nextrec 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.6"
+ __version__ = "0.4.8"
nextrec/basic/callback.py CHANGED
@@ -1,35 +1,413 @@
  """
- EarlyStopper definitions
+ Callback System for Training Process

  Date: create on 27/10/2025
+ Checkpoint: edit on 17/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """

  import copy
+ import logging
+ from typing import Optional
+ from pathlib import Path
+ import torch
+ import pickle
+ from nextrec import __version__


- class EarlyStopper(object):
-     def __init__(self, patience: int = 20, mode: str = "max"):
+ class Callback:
+     """
+     Base callback.
+ 
+     Notes (DDP):
+     - In distributed training, the training loop runs on every rank.
+     - For callbacks with side effects (saving, logging, etc.), set
+       ``run_on_main_process_only=True`` to avoid multi-rank duplication.
+     """
+ 
+     run_on_main_process_only: bool = False
+ 
+     def on_train_begin(self, logs: Optional[dict] = None):
+         pass
+ 
+     def on_train_end(self, logs: Optional[dict] = None):
+         pass
+ 
+     def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+         pass
+ 
+     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+         pass
+ 
+     def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+         pass
+ 
+     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+         pass
+ 
+     def on_validation_begin(self, logs: Optional[dict] = None):
+         pass
+ 
+     def on_validation_end(self, logs: Optional[dict] = None):
+         pass
+ 
+     def set_model(self, model):
+         self.model = model
+ 
+     def set_params(self, params: dict):
+         self.params = params
+ 
+     def should_run(self) -> bool:
+         if not getattr(self, "run_on_main_process_only", False):
+             return True
+         model = getattr(self, "model", None)
+         if model is None:
+             return True
+         return bool(getattr(model, "is_main_process", True))
+ 
+ 
+ class CallbackList:
+     """Container for managing multiple callbacks."""
+ 
+     def __init__(self, callbacks: Optional[list[Callback]] = None):
+         self.callbacks = callbacks or []
+ 
+     def append(self, callback: Callback):
+         self.callbacks.append(callback)
+ 
+     def set_model(self, model):
+         for callback in self.callbacks:
+             callback.set_model(model)
+ 
+     def set_params(self, params: dict):
+         for callback in self.callbacks:
+             callback.set_params(params)
+ 
+     def on_train_begin(self, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_train_begin(logs)
+ 
+     def on_train_end(self, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_train_end(logs)
+ 
+     def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_epoch_begin(epoch, logs)
+ 
+     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_epoch_end(epoch, logs)
+ 
+     def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_batch_begin(batch, logs)
+ 
+     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_batch_end(batch, logs)
+ 
+     def on_validation_begin(self, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_validation_begin(logs)
+ 
+     def on_validation_end(self, logs: Optional[dict] = None):
+         for callback in self.callbacks:
+             if not callback.should_run():
+                 continue
+             callback.on_validation_end(logs)
+ 
+ 
+ class EarlyStopper(Callback):
+ 
+     def __init__(
+         self,
+         monitor: str = "val_auc",
+         patience: int = 20,
+         mode: str = "max",
+         min_delta: float = 0.0,
+         restore_best_weights: bool = True,
+         verbose: int = 1,
+     ):
+         super().__init__()
+         self.monitor = monitor
          self.patience = patience
-         self.trial_counter = 0
-         self.best_metrics = 0
+         self.mode = mode
+         self.min_delta = abs(min_delta)
+         self.restore_best_weights = restore_best_weights
+         self.verbose = verbose
+ 
+         self.wait = 0
+         self.stopped_epoch = 0
          self.best_weights = None
+         self.best_epoch = 0
+ 
+         if mode == "min":
+             self.best_value = float("inf")
+             self.monitor_op = lambda current, best: current < (best - self.min_delta)
+         elif mode == "max":
+             self.best_value = float("-inf")
+             self.monitor_op = lambda current, best: current > (best + self.min_delta)
+         else:
+             raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+ 
+     def on_train_begin(self, logs: Optional[dict] = None):
+         self.wait = 0
+         self.stopped_epoch = 0
+         self.best_weights = None
+         self.best_epoch = 0
+         if self.mode == "min":
+             self.best_value = float("inf")
+         else:
+             self.best_value = float("-inf")
+ 
+     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+         logs = logs or {}
+         current = logs.get(self.monitor)
+ 
+         if current is None:
+             if self.verbose > 0:
+                 logging.warning(
+                     f"Early stopping conditioned on metric `{self.monitor}` "
+                     f"which is not available. Available metrics are: {','.join(list(logs.keys()))}"
+                 )
+             return
+ 
+         if self.monitor_op(current, self.best_value):
+             self.best_value = current
+             self.best_epoch = epoch
+             self.wait = 0
+             if self.restore_best_weights:
+                 self.best_weights = copy.deepcopy(self.model.state_dict())
+         else:
+             self.wait += 1
+             if self.wait >= self.patience:
+                 self.stopped_epoch = epoch
+                 if hasattr(self.model, "stop_training"):
+                     self.model.stop_training = True
+                 if self.verbose > 0:
+                     logging.info(
+                         f"Early stopping triggered at epoch {epoch + 1}. "
+                         f"Best {self.monitor}: {self.best_value:.6f} at epoch {self.best_epoch + 1}"
+                     )
+ 
+     def on_train_end(self, logs: Optional[dict] = None):
+         if self.restore_best_weights and self.best_weights is not None:
+             if self.verbose > 0:
+                 logging.info(
+                     f"Restoring model weights from epoch {self.best_epoch + 1} "
+                     f"with best {self.monitor}: {self.best_value:.6f}"
+                 )
+             self.model.load_state_dict(self.best_weights)
+ 
+ 
+ class CheckpointSaver(Callback):
+     """Callback to save model checkpoints during training.
+ 
+     Args:
+         save_path: Path to save checkpoints.
+         monitor: Metric name to monitor for saving best model.
+         mode: One of {'min', 'max'}.
+         save_best_only: If True, only save when the model is considered the "best".
+         save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
+         verbose: Verbosity mode.
+     """
+ 
+     def __init__(
+         self,
+         save_path: str | Path,
+         monitor: str = "val_auc",
+         mode: str = "max",
+         save_best_only: bool = False,
+         save_freq: str | int = "epoch",
+         verbose: int = 1,
+         run_on_main_process_only: bool = True,
+     ):
+         super().__init__()
+         self.run_on_main_process_only = run_on_main_process_only
+         self.save_path = Path(save_path)
+         self.monitor = monitor
          self.mode = mode
+         self.save_best_only = save_best_only
+         self.save_freq = save_freq
+         self.verbose = verbose

-     def stop_training(self, val_metrics, weights):
-         if self.mode == "max":
-             if val_metrics > self.best_metrics:
-                 self.best_metrics = val_metrics
-                 self.trial_counter = 0
-                 self.best_weights = copy.deepcopy(weights)
-         elif self.mode == "min":
-             if val_metrics < self.best_metrics:
-                 self.best_metrics = val_metrics
-                 self.trial_counter = 0
-                 self.best_weights = copy.deepcopy(weights)
-                 return False
-         elif self.trial_counter + 1 < self.patience:
-             self.trial_counter += 1
-             return False
+         if mode == "min":
+             self.best_value = float("inf")
+             self.monitor_op = lambda current, best: current < best
+         elif mode == "max":
+             self.best_value = float("-inf")
+             self.monitor_op = lambda current, best: current > best
          else:
-             return True
+             raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+ 
+     def on_train_begin(self, logs: Optional[dict] = None):
+         if self.mode == "min":
+             self.best_value = float("inf")
+         else:
+             self.best_value = float("-inf")
+ 
+         # Create directory if it doesn't exist
+         self.save_path.parent.mkdir(parents=True, exist_ok=True)
+ 
+     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+         logs = logs or {}
+ 
+         # Check if we should save this epoch
+         should_save = False
+         if self.save_freq == "epoch":
+             should_save = True
+         elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
+             should_save = True
+ 
+         if not should_save and self.save_best_only:
+             should_save = False
+ 
+         # Check if this is the best model
+         current = logs.get(self.monitor)
+         is_best = False
+ 
+         if current is not None and self.monitor_op(current, self.best_value):
+             self.best_value = current
+             is_best = True
+             should_save = True
+ 
+         if should_save:
+             if not self.save_best_only or is_best:
+                 checkpoint_path = (
+                     self.save_path.parent
+                     / f"{self.save_path.stem}_epoch_{epoch + 1}{self.save_path.suffix}"
+                 )
+                 self.save_checkpoint(checkpoint_path, epoch, logs)
+ 
+             if is_best:
+                 # Use save_path directly without adding _best suffix since it may already contain it
+                 self.save_checkpoint(self.save_path, epoch, logs)
+                 if self.verbose > 0:
+                     logging.info(
+                         f"Saved best model to {self.save_path} with {self.monitor}: {current:.6f}"
+                     )
+ 
+     def save_checkpoint(self, path: Path, epoch: int, logs: dict):
+ 
+         # Get the actual model (unwrap DDP if needed)
+         model_to_save = (
+             self.model.ddp_model.module
+             if getattr(self.model, "ddp_model", None) is not None
+             else self.model
+         )
+ 
+         # Save only state_dict to match BaseModel.save_model() format
+         torch.save(model_to_save.state_dict(), path)
+ 
+         # Also save features_config.pkl if it doesn't exist
+         config_path = path.parent / "features_config.pkl"
+         if not config_path.exists():
+             features_config = {
+                 "all_features": self.model.all_features,
+                 "target": self.model.target_columns,
+                 "id_columns": self.model.id_columns,
+                 "version": __version__,
+             }
+             with open(config_path, "wb") as f:
+                 pickle.dump(features_config, f)
+ 
+         if self.verbose > 1:
+             logging.info(f"Saved checkpoint to {path}")
+ 
+ 
+ class LearningRateScheduler(Callback):
+     """Callback for learning rate scheduling.
+ 
+     Args:
+         scheduler: Learning rate scheduler instance or name.
+         verbose: Verbosity mode.
+     """
+ 
+     def __init__(self, scheduler=None, verbose: int = 0):
+         super().__init__()
+         self.scheduler = scheduler
+         self.verbose = verbose
+ 
+     def on_train_begin(self, logs: Optional[dict] = None):
+         if self.scheduler is None and hasattr(self.model, "scheduler_fn"):
+             self.scheduler = self.model.scheduler_fn
+ 
+     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+         if self.scheduler is not None:
+             # Get current lr before step
+             if hasattr(self.model, "optimizer_fn"):
+                 old_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+ 
+             # Step the scheduler
+             if hasattr(self.scheduler, "step"):
+                 # Some schedulers need metrics
+                 if "val_loss" in (logs or {}) and hasattr(self.scheduler, "mode"):
+                     self.scheduler.step(logs["val_loss"])
+                 else:
+                     self.scheduler.step()
+ 
+             # Log new lr
+             if self.verbose > 0 and hasattr(self.model, "optimizer_fn"):
+                 if getattr(self.model, "is_main_process", True):
+                     new_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+                     if new_lr != old_lr:
+                         logging.info(
+                             f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
+                         )
+ 
+ 
+ class MetricsLogger(Callback):
+     """Callback for logging training metrics.
+ 
+     Args:
+         log_freq: Frequency of logging ('epoch', 'batch', or integer for every N epochs/batches).
+         verbose: Verbosity mode.
+     """
+ 
+     def __init__(self, log_freq: str | int = "epoch", verbose: int = 1):
+         super().__init__()
+         self.run_on_main_process_only = True
+         self.log_freq = log_freq
+         self.verbose = verbose
+         self.batch_count = 0
+ 
+     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+         if self.verbose > 0 and (
+             self.log_freq == "epoch"
+             or (isinstance(self.log_freq, int) and (epoch + 1) % self.log_freq == 0)
+         ):
+             logs = logs or {}
+             metrics_str = " - ".join(
+                 [
+                     f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                     for k, v in logs.items()
+                 ]
+             )
+             logging.info(f"Epoch {epoch + 1}: {metrics_str}")
+ 
+     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+         self.batch_count += 1
+         if self.verbose > 1 and self.log_freq == "batch":
+             logs = logs or {}
+             metrics_str = " - ".join(
+                 [
+                     f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                     for k, v in logs.items()
+                 ]
+             )
+             logging.info(f"Batch {batch}: {metrics_str}")