nextrec 0.3.3-py3-none-any.whl → 0.3.5-py3-none-any.whl
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +1 -1
- nextrec/basic/loggers.py +71 -8
- nextrec/basic/model.py +45 -11
- nextrec/basic/session.py +3 -10
- nextrec/data/__init__.py +47 -9
- nextrec/data/batch_utils.py +80 -0
- nextrec/data/data_processing.py +152 -0
- nextrec/data/data_utils.py +35 -268
- nextrec/data/dataloader.py +6 -4
- nextrec/data/preprocessor.py +39 -85
- nextrec/models/multi_task/poso.py +1 -1
- nextrec/utils/__init__.py +53 -3
- nextrec/utils/device.py +37 -0
- nextrec/utils/feature.py +13 -0
- nextrec/utils/file.py +70 -0
- nextrec/utils/initializer.py +0 -8
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +0 -19
- nextrec/utils/tensor.py +61 -0
- {nextrec-0.3.3.dist-info → nextrec-0.3.5.dist-info}/METADATA +3 -3
- {nextrec-0.3.3.dist-info → nextrec-0.3.5.dist-info}/RECORD +24 -18
- nextrec/utils/common.py +0 -60
- {nextrec-0.3.3.dist-info → nextrec-0.3.5.dist-info}/WHEEL +0 -0
- {nextrec-0.3.3.dist-info → nextrec-0.3.5.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.3.3"
+__version__ = "0.3.5"
nextrec/basic/features.py
CHANGED
@@ -7,7 +7,7 @@ Author: Yang Zhou, zyaztec@gmail.com
 """
 import torch
 from nextrec.utils.embedding import get_auto_embedding_dim
-from nextrec.utils.
+from nextrec.utils.feature import normalize_to_list
 
 class BaseFeature(object):
     def __repr__(self):
nextrec/basic/loggers.py
CHANGED
@@ -2,17 +2,19 @@
 NextRec Basic Loggers
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 03/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-
 import os
 import re
 import sys
+import json
 import copy
 import logging
-
+import numbers
+from typing import Mapping, Any
+from nextrec.basic.session import create_session, Session
 
 ANSI_CODES = {
     'black': '\033[30m',
@@ -77,17 +79,12 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
     """Apply ANSI color and bold formatting to the given text."""
     if not color and not bold:
         return text
-
     result = ""
-
     if bold:
         result += ANSI_BOLD
-
     if color and color in ANSI_CODES:
         result += ANSI_CODES[color]
-
     result += text + ANSI_RESET
-
     return result
 
 def setup_logger(session_id: str | os.PathLike | None = None):
@@ -126,3 +123,69 @@ def setup_logger(session_id: str | os.PathLike | None = None):
     logger.addHandler(console_handler)
 
     return logger
+
+class TrainingLogger:
+    def __init__(
+        self,
+        session: Session,
+        enable_tensorboard: bool,
+        log_name: str = "training_metrics.jsonl",
+    ) -> None:
+        self.session = session
+        self.enable_tensorboard = enable_tensorboard
+        self.log_path = session.metrics_dir / log_name
+        self.log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        self.tb_writer = None
+        self.tb_dir = None
+
+        if self.enable_tensorboard:
+            self._init_tensorboard()
+
+    def _init_tensorboard(self) -> None:
+        try:
+            from torch.utils.tensorboard import SummaryWriter  # type: ignore
+        except ImportError:
+            logging.warning("[TrainingLogger] tensorboard not installed, disable tensorboard logging.")
+            self.enable_tensorboard = False
+            return
+        tb_dir = self.session.logs_dir / "tensorboard"
+        tb_dir.mkdir(parents=True, exist_ok=True)
+        self.tb_dir = tb_dir
+        self.tb_writer = SummaryWriter(log_dir=str(tb_dir))
+
+    @property
+    def tensorboard_logdir(self):
+        return self.tb_dir
+
+    def format_metrics(self, metrics: Mapping[str, Any], split: str) -> dict[str, float]:
+        formatted: dict[str, float] = {}
+        for key, value in metrics.items():
+            if isinstance(value, numbers.Number):
+                formatted[f"{split}/{key}"] = float(value)
+            elif hasattr(value, "item"):
+                try:
+                    formatted[f"{split}/{key}"] = float(value.item())
+                except Exception:
+                    continue
+        return formatted
+
+    def log_metrics(self, metrics: Mapping[str, Any], step: int, split: str = "train") -> None:
+        payload = self.format_metrics(metrics, split)
+        payload["step"] = int(step)
+        with self.log_path.open("a", encoding="utf-8") as f:
+            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
+
+        if not self.tb_writer:
+            return
+        step = int(payload.get("step", 0))
+        for key, value in payload.items():
+            if key == "step":
+                continue
+            self.tb_writer.add_scalar(key, value, global_step=step)
+
+    def close(self) -> None:
+        if self.tb_writer:
+            self.tb_writer.flush()
+            self.tb_writer.close()
+            self.tb_writer = None
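A minimal usage sketch of the new TrainingLogger (an illustration, not part of the diff; the experiment id below is hypothetical). format_metrics flattens metrics to split/name keys, log_metrics appends one JSON line per call to training_metrics.jsonl under the session's metrics directory, and mirrors the same scalars to TensorBoard when the writer is available:

from nextrec.basic.session import create_session
from nextrec.basic.loggers import TrainingLogger

session = create_session("demo_experiment")  # hypothetical experiment id
tlogger = TrainingLogger(session=session, enable_tensorboard=True)
# Writes one JSONL line: {"train/loss": 0.42, "train/auc": 0.71, "step": 1}
tlogger.log_metrics({"loss": 0.42, "auc": 0.71}, step=1, split="train")
tlogger.log_metrics({"auc": 0.69}, step=1, split="valid")
tlogger.close()  # flushes and closes the TensorBoard writer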
nextrec/basic/model.py
CHANGED
@@ -10,6 +10,8 @@ import os
 import tqdm
 import pickle
 import logging
+import getpass
+import socket
 import numpy as np
 import pandas as pd
 import torch
@@ -24,15 +26,17 @@ from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
 from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
 
-from nextrec.basic.loggers import setup_logger, colorize
+from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
 from nextrec.basic.session import resolve_save_path, create_session
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
 
 from nextrec.data.dataloader import build_tensors_from_data
-from nextrec.data.
+from nextrec.data.data_processing import get_column_data, get_user_ids
+from nextrec.data.batch_utils import collate_fn, batch_to_dict
 
 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.utils import get_optimizer, get_scheduler
+from nextrec.utils import get_optimizer, get_scheduler
+from nextrec.utils.tensor import to_tensor
 
 from nextrec import __version__
 
@@ -88,6 +92,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.early_stop_patience = early_stop_patience
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
+        self.training_logger: TrainingLogger | None = None
 
     def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
         exclude_modules = exclude_modules or []
@@ -275,11 +280,13 @@ class BaseModel(FeatureSet, nn.Module):
             metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
             epochs:int=1, shuffle:bool=True, batch_size:int=32,
             user_id_column: str | None = None,
-            validation_split: float | None = None
+            validation_split: float | None = None,
+            tensorboard: bool = True,):
         self.to(self.device)
         if not self.logger_initialized:
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
+            self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
 
         self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
         self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
@@ -303,6 +310,20 @@ class BaseModel(FeatureSet, nn.Module):
             is_streaming = True
 
         self.summary()
+        logging.info("")
+        if self.training_logger and self.training_logger.enable_tensorboard:
+            tb_dir = self.training_logger.tensorboard_logdir
+            if tb_dir:
+                user = getpass.getuser()
+                host = socket.gethostname()
+                tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+                ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+                logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+                logging.info(colorize("To view logs, run:", color="cyan"))
+                logging.info(colorize(f"  {tb_cmd}", color="cyan"))
+                logging.info(colorize("Then SSH port forward:", color="cyan"))
+                logging.info(colorize(f"  {ssh_hint}", color="cyan"))
+
         logging.info("")
         logging.info(colorize("=" * 80, bold=True))
         if is_streaming:
@@ -312,7 +333,7 @@ class BaseModel(FeatureSet, nn.Module):
         logging.info(colorize("=" * 80, bold=True))
         logging.info("")
         logging.info(colorize(f"Model device: {self.device}", bold=True))
-
+
         for epoch in range(epochs):
             self.epoch_index = epoch
             if is_streaming:
@@ -326,7 +347,8 @@ class BaseModel(FeatureSet, nn.Module):
             else:
                 train_loss = train_result
                 train_metrics = None
-
+
+            train_log_payload: dict[str, float] = {}
             # handle logging for single-task and multi-task
             if self.nums_task == 1:
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
@@ -334,6 +356,9 @@ class BaseModel(FeatureSet, nn.Module):
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
                     log_str += f", {metrics_str}"
                 logging.info(colorize(log_str))
+                train_log_payload["loss"] = float(train_loss)
+                if train_metrics:
+                    train_log_payload.update(train_metrics)
             else:
                 total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
@@ -356,12 +381,17 @@ class BaseModel(FeatureSet, nn.Module):
                     task_metric_strs.append(f"{target_name}[{metrics_str}]")
                 log_str += ", " + ", ".join(task_metric_strs)
                 logging.info(colorize(log_str))
+                train_log_payload["loss"] = float(total_loss_val)
+                if train_metrics:
+                    train_log_payload.update(train_metrics)
+            if self.training_logger:
+                self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
             if valid_loader is not None:
                 # pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
                 if self.nums_task == 1:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
-                    logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+                    logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
                 else:
                     # multi task metrics
                     task_metrics = {}
@@ -378,7 +408,9 @@ class BaseModel(FeatureSet, nn.Module):
                         if target_name in task_metrics:
                             metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                             task_metric_strs.append(f"{target_name}[{metrics_str}]")
-                    logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+                    logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+                if val_metrics and self.training_logger:
+                    self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
                 # Handle empty validation metrics
                 if not val_metrics:
                     self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
@@ -401,6 +433,7 @@ class BaseModel(FeatureSet, nn.Module):
                     self.best_metric = primary_metric
                     improved = True
                     self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                logging.info(" ")
                 if improved:
                     logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
                     self.save_model(self.best_path, add_timestamp=False, verbose=False)
@@ -431,6 +464,8 @@ class BaseModel(FeatureSet, nn.Module):
         if valid_loader is not None:
             logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
             self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
+        if self.training_logger:
+            self.training_logger.close()
         return self
 
     def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
@@ -527,6 +562,7 @@ class BaseModel(FeatureSet, nn.Module):
             batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
             if batch_user_id is not None:
                 collected_user_ids.append(batch_user_id)
+        logging.info(" ")
         logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
         if len(y_true_list) > 0:
             y_true_all = np.concatenate(y_true_list, axis=0)
@@ -956,9 +992,7 @@ class BaseModel(FeatureSet, nn.Module):
         logger.info(f" Session ID: {self.session_id}")
         logger.info(f" Features Config Path: {self.features_config_path}")
         logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-
-        logger.info("")
-        logger.info("")
+
 
 
 class BaseMatchModel(BaseModel):
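Since fit() now routes per-epoch metrics through TrainingLogger with step=epoch + 1, the JSONL history can be read back for plotting. A hedged sketch; the on-disk path is an assumption based on the session layout, and only the training_metrics.jsonl file name comes from the diff:

import json
from pathlib import Path

# Assumed layout: <pwd>/log/<experiment_id>/metrics/training_metrics.jsonl
metrics_path = Path("log/demo_experiment/metrics/training_metrics.jsonl")
with metrics_path.open(encoding="utf-8") as f:
    history = [json.loads(line) for line in f]
# Keys are flattened as "<split>/<metric>" by TrainingLogger.format_metrics
train_loss = {rec["step"]: rec["train/loss"] for rec in history if "train/loss" in rec}
print(train_loss)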
nextrec/basic/session.py
CHANGED
@@ -1,14 +1,5 @@
 """Session and experiment utilities.
 
-This module centralizes session/experiment management so the rest of the
-framework writes all artifacts to a consistent location:: <pwd>/log/<experiment_id>/
-
-Within that folder we keep model parameters, checkpoints, training metrics,
-evaluation metrics, and consolidated log output. When users do not provide an
-``experiment_id`` a timestamp-based identifier is generated once per process to
-avoid scattering files across multiple directories. Test runs are redirected to
-temporary folders so local trees are not polluted.
-
 Date: create on 23/11/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
@@ -16,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
 import os
 import tempfile
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 
 __all__ = [
@@ -74,6 +65,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
     if experiment_id is not None and str(experiment_id).strip():
         exp_id = str(experiment_id).strip()
     else:
+        # Use local time for session naming
         exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")
 
     if (
@@ -111,6 +103,7 @@ def resolve_save_path(
     timestamp.
     - Parent directories are created.
     """
+    # Use local time for file timestamps
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if add_timestamp else None
 
     normalized_suffix = suffix if suffix.startswith(".") else f".{suffix}"
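A short sketch of the session helpers touched above; the metrics_dir and logs_dir fields are assumptions inferred from how TrainingLogger consumes a Session:

from nextrec.basic.session import create_session

default_session = create_session()             # id defaults to "nextrec_session_<YYYYMMDD>" (local time)
named_session = create_session("ablation_01")  # an explicit experiment id is kept verbatim
print(named_session.metrics_dir, named_session.logs_dir)  # assumed Session fields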
nextrec/data/__init__.py
CHANGED
@@ -1,48 +1,86 @@
 """
 Data utilities package for NextRec
 
-This package provides data processing and manipulation utilities
+This package provides data processing and manipulation utilities organized by category:
+- batch_utils: Batch collation and processing
+- data_processing: Data manipulation and user ID extraction
+- data_utils: Legacy module (re-exports from specialized modules)
+- dataloader: Dataset and DataLoader implementations
+- preprocessor: Data preprocessing pipeline
 
 Date: create on 13/11/2025
+Last update: 03/12/2025 (refactored)
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-
-
+# Batch utilities
+from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
+
+# Data processing utilities
+from nextrec.data.data_processing import (
     get_column_data,
-    default_output_dir,
     split_dict_random,
     build_eval_candidates,
+    get_user_ids,
+)
+
+# File utilities (from utils package)
+from nextrec.utils.file import (
     resolve_file_paths,
     iter_file_chunks,
     read_table,
     load_dataframes,
+    default_output_dir,
 )
-
-
+
+# DataLoader components
 from nextrec.data.dataloader import (
     TensorDictDataset,
     FileDataset,
     RecDataLoader,
     build_tensors_from_data,
 )
+
+# Preprocessor
 from nextrec.data.preprocessor import DataProcessor
 
+# Feature definitions
+from nextrec.basic.features import FeatureSet
+
+# Legacy module (for backward compatibility)
+from nextrec.data import data_utils
+
 __all__ = [
+    # Batch utilities
     'collate_fn',
+    'batch_to_dict',
+    'stack_section',
+
+    # Data processing
     'get_column_data',
-    'default_output_dir',
     'split_dict_random',
     'build_eval_candidates',
+    'get_user_ids',
+
+    # File utilities
     'resolve_file_paths',
     'iter_file_chunks',
     'read_table',
     'load_dataframes',
-    '
-
+    'default_output_dir',
+
+    # DataLoader
     'TensorDictDataset',
     'FileDataset',
     'RecDataLoader',
     'build_tensors_from_data',
+
+    # Preprocessor
     'DataProcessor',
+
+    # Features
+    'FeatureSet',
+
+    # Legacy module
+    'data_utils',
 ]
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Batch collation utilities for NextRec
|
|
3
|
+
|
|
4
|
+
Date: create on 03/12/2025
|
|
5
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import numpy as np
|
|
10
|
+
from typing import Any, Mapping
|
|
11
|
+
|
|
12
|
+
def stack_section(batch: list[dict], section: str):
|
|
13
|
+
entries = [item.get(section) for item in batch if item.get(section) is not None]
|
|
14
|
+
if not entries:
|
|
15
|
+
return None
|
|
16
|
+
merged: dict = {}
|
|
17
|
+
for name in entries[0]: # type: ignore
|
|
18
|
+
tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
|
|
19
|
+
merged[name] = torch.stack(tensors, dim=0)
|
|
20
|
+
return merged
|
|
21
|
+
|
|
22
|
+
def collate_fn(batch):
|
|
23
|
+
"""
|
|
24
|
+
Collate a list of sample dicts into the unified batch format:
|
|
25
|
+
{
|
|
26
|
+
"features": {name: Tensor(B, ...)},
|
|
27
|
+
"labels": {target: Tensor(B, ...)} or None,
|
|
28
|
+
"ids": {id_name: Tensor(B, ...)} or None,
|
|
29
|
+
}
|
|
30
|
+
Args: batch: List of samples from DataLoader
|
|
31
|
+
|
|
32
|
+
Returns: dict: Batched data in unified format
|
|
33
|
+
"""
|
|
34
|
+
if not batch:
|
|
35
|
+
return {"features": {}, "labels": None, "ids": None}
|
|
36
|
+
|
|
37
|
+
first = batch[0]
|
|
38
|
+
if isinstance(first, dict) and "features" in first:
|
|
39
|
+
# Streaming dataset yields already-batched chunks; avoid adding an extra dim.
|
|
40
|
+
if first.get("_already_batched") and len(batch) == 1:
|
|
41
|
+
return {
|
|
42
|
+
"features": first.get("features", {}),
|
|
43
|
+
"labels": first.get("labels"),
|
|
44
|
+
"ids": first.get("ids"),
|
|
45
|
+
}
|
|
46
|
+
return {
|
|
47
|
+
"features": stack_section(batch, "features") or {},
|
|
48
|
+
"labels": stack_section(batch, "labels"),
|
|
49
|
+
"ids": stack_section(batch, "ids"),
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
# Fallback: stack tuples/lists of tensors
|
|
53
|
+
num_tensors = len(first)
|
|
54
|
+
result = []
|
|
55
|
+
for i in range(num_tensors):
|
|
56
|
+
tensor_list = [item[i] for item in batch]
|
|
57
|
+
first_item = tensor_list[0]
|
|
58
|
+
if isinstance(first_item, torch.Tensor):
|
|
59
|
+
stacked = torch.cat(tensor_list, dim=0)
|
|
60
|
+
elif isinstance(first_item, np.ndarray):
|
|
61
|
+
stacked = np.concatenate(tensor_list, axis=0)
|
|
62
|
+
elif isinstance(first_item, list):
|
|
63
|
+
combined = []
|
|
64
|
+
for entry in tensor_list:
|
|
65
|
+
combined.extend(entry)
|
|
66
|
+
stacked = combined
|
|
67
|
+
else:
|
|
68
|
+
stacked = tensor_list
|
|
69
|
+
result.append(stacked)
|
|
70
|
+
return tuple(result)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
|
|
74
|
+
if not (isinstance(batch_data, Mapping) and "features" in batch_data):
|
|
75
|
+
raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
|
|
76
|
+
return {
|
|
77
|
+
"features": batch_data.get("features", {}),
|
|
78
|
+
"labels": batch_data.get("labels"),
|
|
79
|
+
"ids": batch_data.get("ids") if include_ids else None,
|
|
80
|
+
}
|
|
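A runnable sketch of collate_fn on two per-sample dicts in the unified format; the feature and label names are illustrative:

import torch
from nextrec.data.batch_utils import collate_fn

samples = [
    {"features": {"user_id": torch.tensor(1), "age": torch.tensor(23.0)},
     "labels": {"click": torch.tensor(1.0)}, "ids": None},
    {"features": {"user_id": torch.tensor(2), "age": torch.tensor(31.0)},
     "labels": {"click": torch.tensor(0.0)}, "ids": None},
]
batch = collate_fn(samples)
print(batch["features"]["user_id"])  # tensor([1, 2]), stacked along dim 0
print(batch["labels"]["click"])      # tensor([1., 0.])
print(batch["ids"])                  # None: no sample carried ids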
nextrec/data/data_processing.py
ADDED
@@ -0,0 +1,152 @@
+"""
+Data processing utilities for NextRec
+
+Date: create on 03/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
+"""
+
+import torch
+import numpy as np
+import pandas as pd
+from typing import Any, Mapping
+
+
+def get_column_data(data: dict | pd.DataFrame, name: str):
+    if isinstance(data, dict):
+        return data[name] if name in data else None
+    elif isinstance(data, pd.DataFrame):
+        if name not in data.columns:
+            return None
+        return data[name].values
+    else:
+        if hasattr(data, name):
+            return getattr(data, name)
+        raise KeyError(f"Unsupported data type for extracting column {name}")
+
+def split_dict_random(
+    data_dict: dict,
+    test_size: float = 0.2,
+    random_state: int | None = None
+):
+    lengths = [len(v) for v in data_dict.values()]
+    if len(set(lengths)) != 1:
+        raise ValueError(f"Length mismatch: {lengths}")
+
+    n = lengths[0]
+    rng = np.random.default_rng(random_state)
+    perm = rng.permutation(n)
+    cut = int(round(n * (1 - test_size)))
+    train_idx, test_idx = perm[:cut], perm[cut:]
+
+    def take(v, idx):
+        if isinstance(v, np.ndarray):
+            return v[idx]
+        elif isinstance(v, pd.Series):
+            return v.iloc[idx].to_numpy()
+        else:
+            v_arr = np.asarray(v, dtype=object)
+            return v_arr[idx]
+
+    train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
+    test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
+    return train_dict, test_dict
+
+
+def build_eval_candidates(
+    df_all: pd.DataFrame,
+    user_col: str,
+    item_col: str,
+    label_col: str,
+    user_features: pd.DataFrame,
+    item_features: pd.DataFrame,
+    num_pos_per_user: int = 5,
+    num_neg_per_pos: int = 50,
+    random_seed: int = 2025,
+) -> pd.DataFrame:
+    """
+    Build evaluation candidates with positive and negative samples for each user.
+
+    Args:
+        df_all: Full interaction DataFrame
+        user_col: Name of the user ID column
+        item_col: Name of the item ID column
+        label_col: Name of the label column
+        user_features: DataFrame containing user features
+        item_features: DataFrame containing item features
+        num_pos_per_user: Number of positive samples per user (default: 5)
+        num_neg_per_pos: Number of negative samples per positive (default: 50)
+        random_seed: Random seed for reproducibility (default: 2025)
+
+    Returns:
+        pd.DataFrame: Evaluation candidates with features
+    """
+    rng = np.random.default_rng(random_seed)
+
+    users = df_all[user_col].unique()
+    all_items = item_features[item_col].unique()
+    rows = []
+    user_hist_items = {u: df_all[df_all[user_col] == u][item_col].unique() for u in users}
+
+    for u in users:
+        df_user = df_all[df_all[user_col] == u]
+        pos_items = df_user[df_user[label_col] == 1][item_col].unique()
+        if len(pos_items) == 0:
+            continue
+        pos_items = pos_items[:num_pos_per_user]
+        seen_items = set(user_hist_items[u])
+        neg_pool = np.setdiff1d(all_items, np.fromiter(seen_items, dtype=all_items.dtype))
+        if len(neg_pool) == 0:
+            continue
+        for pos in pos_items:
+            if len(neg_pool) <= num_neg_per_pos:
+                neg_items = neg_pool
+            else:
+                neg_items = rng.choice(neg_pool, size=num_neg_per_pos, replace=False)
+            rows.append((u, pos, 1))
+            for ni in neg_items:
+                rows.append((u, ni, 0))
+
+    eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
+    eval_df = eval_df.merge(user_features, on=user_col, how='left')
+    eval_df = eval_df.merge(item_features, on=item_col, how='left')
+    return eval_df
+
+
+def get_user_ids(
+    data: Any,
+    id_columns: list[str] | str | None = None
+) -> np.ndarray | None:
+    """
+    Extract user IDs from various data structures.
+
+    Args:
+        data: Data source (DataFrame, dict, or batch dict)
+        id_columns: List or single ID column name(s) (default: None)
+
+    Returns:
+        np.ndarray | None: User IDs as numpy array, or None if not found
+    """
+    id_columns = (
+        id_columns if isinstance(id_columns, list)
+        else [id_columns] if isinstance(id_columns, str)
+        else []
+    )
+    if not id_columns:
+        return None
+
+    main_id = id_columns[0]
+    if isinstance(data, pd.DataFrame) and main_id in data.columns:
+        arr = np.asarray(data[main_id].values)
+        return arr.reshape(arr.shape[0])
+
+    if isinstance(data, dict):
+        ids_container = data.get("ids")
+        if isinstance(ids_container, dict) and main_id in ids_container:
+            val = ids_container[main_id]
+            val = val.detach().cpu().numpy() if isinstance(val, torch.Tensor) else np.asarray(val)
+            return val.reshape(val.shape[0])
+        if main_id in data:
+            arr = np.asarray(data[main_id])
+            return arr.reshape(arr.shape[0])
+
+    return None