nextrec-0.4.7-py3-none-any.whl → nextrec-0.4.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +337 -328
  8. nextrec/cli.py +25 -4
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +12 -13
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +2 -2
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/ffm.py +0 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +0 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +12 -10
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/METADATA +8 -7
  54. nextrec-0.4.9.dist-info/RECORD +70 -0
  55. nextrec/utils/device.py +0 -78
  56. nextrec/utils/distributed.py +0 -141
  57. nextrec/utils/file.py +0 -92
  58. nextrec/utils/initializer.py +0 -79
  59. nextrec/utils/optimizer.py +0 -75
  60. nextrec/utils/tensor.py +0 -72
  61. nextrec-0.4.7.dist-info/RECORD +0 -70
  62. nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
  63. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
  64. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
  65. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.4.7"
1
+ __version__ = "0.4.9"
nextrec/basic/callback.py CHANGED
@@ -2,17 +2,20 @@
2
2
  Callback System for Training Process
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 17/12/2025
5
+ Checkpoint: edit on 19/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
9
  import copy
10
10
  import logging
11
- from typing import Optional
11
+ import pickle
12
12
  from pathlib import Path
13
+ from typing import Optional
14
+
13
15
  import torch
14
- import pickle
16
+
15
17
  from nextrec import __version__
18
+ from nextrec.basic.loggers import colorize, format_kv
16
19
 
17
20
 
18
21
  class Callback:
@@ -209,8 +212,13 @@ class EarlyStopper(Callback):
209
212
  if self.restore_best_weights and self.best_weights is not None:
210
213
  if self.verbose > 0:
211
214
  logging.info(
212
- f"Restoring model weights from epoch {self.best_epoch + 1} "
213
- f"with best {self.monitor}: {self.best_value:.6f}"
215
+ colorize(
216
+ format_kv(
217
+ "Restoring model weights from epoch",
218
+ f"{self.best_epoch + 1} with best {self.monitor}: {self.best_value:.6f}",
219
+ ),
220
+ color="bright_blue",
221
+ )
214
222
  )
215
223
  self.model.load_state_dict(self.best_weights)
216
224
 
@@ -229,7 +237,8 @@ class CheckpointSaver(Callback):
229
237
 
230
238
  def __init__(
231
239
  self,
232
- save_path: str | Path,
240
+ best_path: str | Path,
241
+ checkpoint_path: str | Path,
233
242
  monitor: str = "val_auc",
234
243
  mode: str = "max",
235
244
  save_best_only: bool = False,
@@ -239,7 +248,8 @@ class CheckpointSaver(Callback):
239
248
  ):
240
249
  super().__init__()
241
250
  self.run_on_main_process_only = run_on_main_process_only
242
- self.save_path = Path(save_path)
251
+ self.best_path = Path(best_path)
252
+ self.checkpoint_path = Path(checkpoint_path)
243
253
  self.monitor = monitor
244
254
  self.mode = mode
245
255
  self.save_best_only = save_best_only
@@ -260,14 +270,13 @@ class CheckpointSaver(Callback):
260
270
  self.best_value = float("inf")
261
271
  else:
262
272
  self.best_value = float("-inf")
263
-
264
- # Create directory if it doesn't exist
265
- self.save_path.parent.mkdir(parents=True, exist_ok=True)
273
+ self.best_path.parent.mkdir(parents=True, exist_ok=True)
274
+ self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
266
275
 
267
276
  def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
277
+ logging.info("")
268
278
  logs = logs or {}
269
279
 
270
- # Check if we should save this epoch
271
280
  should_save = False
272
281
  if self.save_freq == "epoch":
273
282
  should_save = True
@@ -289,17 +298,23 @@ class CheckpointSaver(Callback):
289
298
  if should_save:
290
299
  if not self.save_best_only or is_best:
291
300
  checkpoint_path = (
292
- self.save_path.parent
293
- / f"{self.save_path.stem}_epoch_{epoch + 1}{self.save_path.suffix}"
301
+ self.checkpoint_path.parent
302
+ / f"{self.checkpoint_path.stem}{self.checkpoint_path.suffix}"
294
303
  )
295
304
  self.save_checkpoint(checkpoint_path, epoch, logs)
296
305
 
297
306
  if is_best:
298
307
  # Use save_path directly without adding _best suffix since it may already contain it
299
- self.save_checkpoint(self.save_path, epoch, logs)
308
+ self.save_checkpoint(self.best_path, epoch, logs)
300
309
  if self.verbose > 0:
301
310
  logging.info(
302
- f"Saved best model to {self.save_path} with {self.monitor}: {current:.6f}"
311
+ colorize(
312
+ format_kv(
313
+ "Saved best model to",
314
+ f"{self.best_path} with {self.monitor}: {current:.6f}",
315
+ ),
316
+ color="bright_blue",
317
+ )
303
318
  )
304
319
 
305
320
  def save_checkpoint(self, path: Path, epoch: int, logs: dict):
nextrec/basic/features.py CHANGED
@@ -7,6 +7,7 @@ Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
9
  import torch
10
+
10
11
  from nextrec.utils.embedding import get_auto_embedding_dim
11
12
  from nextrec.utils.feature import normalize_to_list
12
13
 
nextrec/basic/layers.py CHANGED
@@ -2,22 +2,22 @@
2
2
  Layer implementations used across NextRec models.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 19/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
9
  from __future__ import annotations
10
10
 
11
+ from collections import OrderedDict
12
+ from itertools import combinations
13
+
11
14
  import torch
12
15
  import torch.nn as nn
13
16
  import torch.nn.functional as F
14
17
 
15
- from itertools import combinations
16
- from collections import OrderedDict
17
-
18
- from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
19
- from nextrec.utils.initializer import get_initializer
20
18
  from nextrec.basic.activation import activation_layer
19
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
20
+ from nextrec.utils.torch_utils import get_initializer
21
21
 
22
22
 
23
23
  class PredictionLayer(nn.Module):
@@ -81,8 +81,6 @@ class PredictionLayer(nn.Module):
81
81
  outputs.append(torch.sigmoid(task_logits))
82
82
  elif task == "regression":
83
83
  outputs.append(task_logits)
84
- elif task == "multiclass":
85
- outputs.append(torch.softmax(task_logits, dim=-1))
86
84
  else:
87
85
  raise ValueError(
88
86
  f"[PredictionLayer Error]: Unsupported task_type '{task_type}'."
nextrec/basic/loggers.py CHANGED
@@ -2,20 +2,20 @@
2
2
  NextRec Basic Loggers
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 03/12/2025
5
+ Checkpoint: edit on 19/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
- import os
10
- import re
11
- import sys
12
- import json
13
9
  import copy
10
+ import json
14
11
  import logging
15
12
  import numbers
13
+ import os
14
+ import re
15
+ import sys
16
+ from typing import Any, Mapping
16
17
 
17
- from typing import Mapping, Any
18
- from nextrec.basic.session import create_session, Session
18
+ from nextrec.basic.session import Session, create_session
19
19
 
20
20
  ANSI_CODES = {
21
21
  "black": "\033[30m",
@@ -91,6 +91,13 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
91
91
  return result
92
92
 
93
93
 
94
+ def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:
95
+ """Format key-value lines with consistent alignment."""
96
+ label_text = label if label.endswith(":") else f"{label}:"
97
+ prefix = " " * indent
98
+ return f"{prefix}{label_text:<{width}} {value}"
99
+
100
+
94
101
  def setup_logger(session_id: str | os.PathLike | None = None):
95
102
  """Set up a logger that logs to both console and a file with ANSI formatting.
96
103
  Only console output has colors; file output is stripped of ANSI codes.
nextrec/basic/metrics.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Metrics computation and configuration for model evaluation.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 02/12/2025
5
+ Checkpoint: edit on 19/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -11,15 +11,15 @@ from typing import Any
11
11
 
12
12
  import numpy as np
13
13
  from sklearn.metrics import (
14
- roc_auc_score,
14
+ accuracy_score,
15
+ f1_score,
15
16
  log_loss,
16
- mean_squared_error,
17
17
  mean_absolute_error,
18
- accuracy_score,
18
+ mean_squared_error,
19
19
  precision_score,
20
- recall_score,
21
- f1_score,
22
20
  r2_score,
21
+ recall_score,
22
+ roc_auc_score,
23
23
  )
24
24
 
25
25
  CLASSIFICATION_METRICS = {
@@ -44,11 +44,6 @@ TASK_DEFAULT_METRICS = {
44
44
  + [f"recall@{k}" for k in (5, 10, 20)]
45
45
  + [f"ndcg@{k}" for k in (5, 10, 20)]
46
46
  + [f"mrr@{k}" for k in (5, 10, 20)],
47
- # generative/multiclass next-item prediction defaults
48
- "multiclass": ["accuracy"]
49
- + [f"hitrate@{k}" for k in (1, 5, 10)]
50
- + [f"recall@{k}" for k in (1, 5, 10)]
51
- + [f"mrr@{k}" for k in (1, 5, 10)],
52
47
  }
53
48
 
54
49
 
@@ -163,51 +158,6 @@ def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarr
163
158
  return groups
164
159
 
165
160
 
166
- def normalize_multiclass_inputs(
167
- y_true: np.ndarray, y_pred: np.ndarray
168
- ) -> tuple[np.ndarray, np.ndarray]:
169
- """
170
- Normalize multiclass inputs to consistent shapes.
171
-
172
- y_true: [N] of class ids
173
- y_pred: [N, C] of logits/probabilities
174
- """
175
- labels = np.asarray(y_true).reshape(-1)
176
- scores = np.asarray(y_pred)
177
- if scores.ndim == 1:
178
- scores = scores.reshape(scores.shape[0], -1)
179
- if scores.shape[0] != labels.shape[0]:
180
- raise ValueError(
181
- f"[Metric Warning] y_true length {labels.shape[0]} != y_pred batch {scores.shape[0]} for multiclass metrics."
182
- )
183
- return labels.astype(int), scores
184
-
185
-
186
- def multiclass_topk_hit_rate(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
187
- labels, scores = normalize_multiclass_inputs(y_true, y_pred)
188
- if scores.shape[1] == 0:
189
- return 0.0
190
- k = min(k, scores.shape[1])
191
- topk_idx = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
192
- hits = (topk_idx == labels[:, None]).any(axis=1)
193
- return float(hits.mean()) if hits.size > 0 else 0.0
194
-
195
-
196
- def multiclass_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
197
- labels, scores = normalize_multiclass_inputs(y_true, y_pred)
198
- if scores.shape[1] == 0:
199
- return 0.0
200
- k = min(k, scores.shape[1])
201
- # full sort for stable ranks
202
- topk_idx = np.argsort(-scores, axis=1)[:, :k]
203
- ranks = np.full(labels.shape, fill_value=k + 1, dtype=np.float32)
204
- for idx in range(k):
205
- match = topk_idx[:, idx] == labels
206
- ranks[match] = idx + 1
207
- reciprocals = np.where(ranks <= k, 1.0 / ranks, 0.0)
208
- return float(reciprocals.mean()) if reciprocals.size > 0 else 0.0
209
-
210
-
211
161
  def compute_precision_at_k(
212
162
  y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
213
163
  ) -> float:
@@ -514,26 +464,6 @@ def compute_single_metric(
514
464
  """Compute a single metric given true and predicted values."""
515
465
  y_p_binary = (y_pred > 0.5).astype(int)
516
466
  metric_lower = metric.lower()
517
- is_multiclass = task_type == "multiclass" and y_pred.ndim >= 2
518
- if is_multiclass:
519
- # Dedicated path for multiclass logits (e.g., next-item prediction)
520
- labels, scores = normalize_multiclass_inputs(y_true, y_pred)
521
- if metric_lower in ("accuracy", "acc"):
522
- preds = scores.argmax(axis=1)
523
- return float((preds == labels).mean())
524
- if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
525
- k_str = metric_lower.split("@")[1]
526
- k = int(k_str)
527
- return multiclass_topk_hit_rate(labels, scores, k)
528
- if metric_lower.startswith("recall@"):
529
- k = int(metric_lower.split("@")[1])
530
- return multiclass_topk_hit_rate(labels, scores, k)
531
- if metric_lower.startswith("mrr@"):
532
- k = int(metric_lower.split("@")[1])
533
- return multiclass_mrr_at_k(labels, scores, k)
534
- # fall back to accuracy if unsupported metric is requested
535
- preds = scores.argmax(axis=1)
536
- return float((preds == labels).mean())
537
467
  try:
538
468
  if metric_lower.startswith("recall@"):
539
469
  k = int(metric_lower.split("@")[1])