nextrec 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +0 -30
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +1 -1
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/model.py +20 -15
- nextrec/basic/session.py +7 -12
- nextrec/data/__init__.py +30 -17
- nextrec/data/batch_utils.py +80 -0
- nextrec/data/data_processing.py +152 -0
- nextrec/data/data_utils.py +35 -268
- nextrec/data/dataloader.py +19 -12
- nextrec/data/preprocessor.py +6 -16
- nextrec/models/generative/__init__.py +0 -5
- nextrec/models/match/__init__.py +0 -13
- nextrec/models/multi_task/__init__.py +0 -0
- nextrec/models/multi_task/poso.py +1 -1
- nextrec/models/ranking/__init__.py +0 -27
- nextrec/utils/__init__.py +53 -3
- nextrec/utils/device.py +38 -0
- nextrec/utils/feature.py +13 -0
- nextrec/utils/file.py +70 -0
- nextrec/utils/initializer.py +0 -8
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +0 -19
- nextrec/utils/tensor.py +61 -0
- {nextrec-0.3.4.dist-info → nextrec-0.3.6.dist-info}/METADATA +3 -3
- {nextrec-0.3.4.dist-info → nextrec-0.3.6.dist-info}/RECORD +29 -22
- nextrec/utils/common.py +0 -60
- {nextrec-0.3.4.dist-info → nextrec-0.3.6.dist-info}/WHEEL +0 -0
- {nextrec-0.3.4.dist-info → nextrec-0.3.6.dist-info}/licenses/LICENSE +0 -0
nextrec/__init__.py
CHANGED
@@ -1,33 +1,3 @@
- """
- NextRec - A Unified Deep Learning Framework for Recommender Systems
- ===================================================================
-
- NextRec provides a comprehensive suite of recommendation models including:
- - Ranking models (CTR prediction)
- - Matching models (retrieval)
- - Multi-task learning models
- - Generative recommendation models
-
- Quick Start
- -----------
- >>> from nextrec.basic.features import DenseFeature, SparseFeature
- >>> from nextrec.models.ranking.deepfm import DeepFM
- >>>
- >>> # Define features
- >>> dense_features = [DenseFeature('age')]
- >>> sparse_features = [SparseFeature('category', vocab_size=100, embedding_dim=16)]
- >>>
- >>> # Build model
- >>> model = DeepFM(
- ...     dense_features=dense_features,
- ...     sparse_features=sparse_features,
- ...     targets=['label']
- ... )
- >>>
- >>> # Train model
- >>> model.fit(train_data=df_train, valid_data=df_valid)
- """
-
  from nextrec.__version__ import __version__

  __all__ = [
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
- __version__ = "0.3.4"
+ __version__ = "0.3.6"
nextrec/basic/features.py
CHANGED
@@ -7,7 +7,7 @@ Author: Yang Zhou, zyaztec@gmail.com
  """
  import torch
  from nextrec.utils.embedding import get_auto_embedding_dim
- from nextrec.utils.
+ from nextrec.utils.feature import normalize_to_list

  class BaseFeature(object):
  def __repr__(self):
nextrec/basic/loggers.py
CHANGED
@@ -99,7 +99,7 @@ def setup_logger(session_id: str | os.PathLike | None = None):
  session = create_session(str(session_id) if session_id is not None else None)
  log_dir = session.logs_dir
  log_dir.mkdir(parents=True, exist_ok=True)
- log_file = log_dir / f"{session.
+ log_file = log_dir / f"{session.log_basename}.log"

  console_format = '%(message)s'
  file_format = '%(asctime)s - %(levelname)s - %(message)s'
nextrec/basic/model.py
CHANGED
@@ -31,10 +31,12 @@ from nextrec.basic.session import resolve_save_path, create_session
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

  from nextrec.data.dataloader import build_tensors_from_data
- from nextrec.data.
+ from nextrec.data.data_processing import get_column_data, get_user_ids
+ from nextrec.data.batch_utils import collate_fn, batch_to_dict

  from nextrec.loss import get_loss_fn, get_loss_kwargs
- from nextrec.utils import get_optimizer, get_scheduler
+ from nextrec.utils import get_optimizer, get_scheduler
+ from nextrec.utils.tensor import to_tensor

  from nextrec import __version__

@@ -153,7 +155,7 @@ class BaseModel(FeatureSet, nn.Module):
  raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
  return X_input, y

- def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+ def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,) -> tuple[DataLoader, dict | pd.DataFrame]:
  """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
  if not (0 < validation_split < 1):
  raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
@@ -182,7 +184,7 @@ class BaseModel(FeatureSet, nn.Module):
  arr = np.asarray(value)
  train_split[key] = arr[train_indices]
  valid_split[key] = arr[valid_indices]
- train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+ train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
  logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
  return train_loader, valid_split

@@ -263,14 +265,14 @@ class BaseModel(FeatureSet, nn.Module):
  task_losses.append(task_loss)
  return torch.stack(task_losses).sum()

- def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
+ def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0,) -> DataLoader:
  if isinstance(data, DataLoader):
  return data
  tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
  if tensors is None:
  raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
  dataset = TensorDictDataset(tensors)
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers)

  def fit(self,
  train_data: dict | pd.DataFrame | DataLoader,
@@ -279,6 +281,7 @@ class BaseModel(FeatureSet, nn.Module):
  epochs:int=1, shuffle:bool=True, batch_size:int=32,
  user_id_column: str | None = None,
  validation_split: float | None = None,
+ num_workers: int = 0,
  tensorboard: bool = True,):
  self.to(self.device)
  if not self.logger_initialized:
@@ -295,11 +298,11 @@ class BaseModel(FeatureSet, nn.Module):
  self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')

  if validation_split is not None and valid_data is None:
- train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,) # type: ignore
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
  else:
- train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
+ train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers))

- valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
+ valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers)
  try:
  self.steps_per_epoch = len(train_loader)
  is_streaming = False
@@ -386,7 +389,7 @@ class BaseModel(FeatureSet, nn.Module):
  self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
  if valid_loader is not None:
  # pass user_ids only if needed for GAUC metric
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
  if self.nums_task == 1:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
  logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
@@ -511,12 +514,12 @@ class BaseModel(FeatureSet, nn.Module):
  return avg_loss, metrics_dict
  return avg_loss

- def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
+ def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0,) -> tuple[DataLoader | None, np.ndarray | None]:
  if valid_data is None:
  return None, None
  if isinstance(valid_data, DataLoader):
  return valid_data, None
- valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+ valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  valid_user_ids = None
  if needs_user_ids:
  if user_id_column is None:
@@ -529,7 +532,8 @@ class BaseModel(FeatureSet, nn.Module):
  metrics: list[str] | dict[str, list[str]] | None = None,
  batch_size: int = 32,
  user_ids: np.ndarray | None = None,
- user_id_column: str = 'user_id'
+ user_id_column: str = 'user_id',
+ num_workers: int = 0,) -> dict:
  self.eval()
  eval_metrics = metrics if metrics is not None else self.metrics
  if eval_metrics is None:
@@ -541,7 +545,7 @@ class BaseModel(FeatureSet, nn.Module):
  else:
  if user_ids is None and needs_user_ids:
  user_ids = get_user_ids(data=data, id_columns=user_id_column)
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  y_true_list = []
  y_pred_list = []
  collected_user_ids = []
@@ -601,6 +605,7 @@ class BaseModel(FeatureSet, nn.Module):
  include_ids: bool | None = None,
  return_dataframe: bool = True,
  streaming_chunk_size: int = 10000,
+ num_workers: int = 0,
  ) -> pd.DataFrame | np.ndarray:
  self.eval()
  if include_ids is None:
@@ -613,7 +618,7 @@ class BaseModel(FeatureSet, nn.Module):
  rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
  elif not isinstance(data, DataLoader):
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  else:
  data_loader = data

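Taken together, the model.py changes add a single `num_workers` option that `fit`, `evaluate`, and `predict` forward to every internal `prepare_data_loader` call and hence to torch's `DataLoader`. A minimal usage sketch; the `DeepFM` import and the feature/column names follow the Quick Start docstring removed from `__init__.py` above and are illustrative, not verified against the 0.3.6 wheel:

import pandas as pd
from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.models.ranking.deepfm import DeepFM

# Toy frame; column names are placeholders.
df_train = pd.DataFrame({"age": [23, 31, 45, 18], "category": [4, 7, 2, 9], "label": [1, 0, 1, 0]})

model = DeepFM(
    dense_features=[DenseFeature('age')],
    sparse_features=[SparseFeature('category', vocab_size=100, embedding_dim=16)],
    targets=['label'],
)
# New in 0.3.6: num_workers is passed through to every DataLoader the model builds.
model.fit(train_data=df_train, validation_split=0.25, batch_size=2, num_workers=2)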
nextrec/basic/session.py
CHANGED
@@ -1,14 +1,5 @@
  """Session and experiment utilities.

- This module centralizes session/experiment management so the rest of the
- framework writes all artifacts to a consistent location:: <pwd>/log/<experiment_id>/
-
- Within that folder we keep model parameters, checkpoints, training metrics,
- evaluation metrics, and consolidated log output. When users do not provide an
- ``experiment_id`` a timestamp-based identifier is generated once per process to
- avoid scattering files across multiple directories. Test runs are redirected to
- temporary folders so local trees are not polluted.
-
  Date: create on 23/11/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
@@ -16,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
  import os
  import tempfile
  from dataclasses import dataclass
- from datetime import datetime
+ from datetime import datetime, timezone
  from pathlib import Path

  __all__ = [
@@ -31,6 +22,7 @@ class Session:

  experiment_id: str
  root: Path
+ log_basename: str  # The base name for log files, without path separators

  @property
  def logs_dir(self) -> Path:
@@ -69,13 +61,15 @@
  return path

  def create_session(experiment_id: str | Path | None = None) -> Session:
- """Create a :class:`Session` instance with prepared directories."""

  if experiment_id is not None and str(experiment_id).strip():
  exp_id = str(experiment_id).strip()
  else:
+ # Use local time for session naming
  exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")

+ log_basename = Path(exp_id).name if exp_id else exp_id
+
  if (
  os.getenv("PYTEST_CURRENT_TEST")
  or os.getenv("PYTEST_RUNNING")
@@ -90,7 +84,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
  session_path.mkdir(parents=True, exist_ok=True)
  root = session_path.resolve()

- return Session(experiment_id=exp_id, root=root)
+ return Session(experiment_id=exp_id, root=root, log_basename=log_basename)

  def resolve_save_path(
  path: str | os.PathLike | Path | None,
@@ -111,6 +105,7 @@ def resolve_save_path(
  timestamp.
  - Parent directories are created.
  """
+ # Use local time for file timestamps
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if add_timestamp else None

  normalized_suffix = suffix if suffix.startswith(".") else f".{suffix}"
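The net effect of the session.py change is that `Session` now carries a `log_basename` derived from the experiment id via `Path(exp_id).name`, which `setup_logger` in loggers.py uses for the log file name. A rough sketch of the intended behaviour, assuming a path-like experiment id:

from nextrec.basic.session import create_session

session = create_session("experiments/2025/baseline")
print(session.experiment_id)  # "experiments/2025/baseline"
print(session.log_basename)   # "baseline" -- no path separators end up in the log file name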
nextrec/data/__init__.py
CHANGED
@@ -1,48 +1,61 @@
-
-
-
- This package provides data processing and manipulation utilities.
-
- Date: create on 13/11/2025
- Author: Yang Zhou, zyaztec@gmail.com
- """
-
- from nextrec.data.data_utils import (
- collate_fn,
+ from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
+ from nextrec.data.data_processing import (
  get_column_data,
- default_output_dir,
  split_dict_random,
  build_eval_candidates,
+ get_user_ids,
+ )
+
+ from nextrec.utils.file import (
  resolve_file_paths,
  iter_file_chunks,
  read_table,
  load_dataframes,
+ default_output_dir,
  )
-
- from nextrec.data import data_utils
+
  from nextrec.data.dataloader import (
  TensorDictDataset,
  FileDataset,
  RecDataLoader,
  build_tensors_from_data,
  )
+
  from nextrec.data.preprocessor import DataProcessor
+ from nextrec.basic.features import FeatureSet
+ from nextrec.data import data_utils

  __all__ = [
+ # Batch utilities
  'collate_fn',
+ 'batch_to_dict',
+ 'stack_section',
+
+ # Data processing
  'get_column_data',
- 'default_output_dir',
  'split_dict_random',
  'build_eval_candidates',
+ 'get_user_ids',
+
+ # File utilities
  'resolve_file_paths',
  'iter_file_chunks',
  'read_table',
  'load_dataframes',
- '
-
+ 'default_output_dir',
+
+ # DataLoader
  'TensorDictDataset',
  'FileDataset',
  'RecDataLoader',
  'build_tensors_from_data',
+
+ # Preprocessor
  'DataProcessor',
+
+ # Features
  'FeatureSet',
+
+ # Legacy module
+ 'data_utils',
  ]
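With this reorganisation the helpers that used to live in `data_utils.py` are re-exported from the `nextrec.data` package root, so both import styles below should resolve to the same objects (a sketch based only on the imports and `__all__` shown above):

from nextrec.data import collate_fn, batch_to_dict, get_user_ids, split_dict_random
from nextrec.data.batch_utils import collate_fn as collate_fn_direct

assert collate_fn is collate_fn_direct  # re-export, not a copy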
nextrec/data/batch_utils.py
ADDED
@@ -0,0 +1,80 @@
+ """
+ Batch collation utilities for NextRec
+
+ Date: create on 03/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ import torch
+ import numpy as np
+ from typing import Any, Mapping
+
+ def stack_section(batch: list[dict], section: str):
+     entries = [item.get(section) for item in batch if item.get(section) is not None]
+     if not entries:
+         return None
+     merged: dict = {}
+     for name in entries[0]:  # type: ignore
+         tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
+         merged[name] = torch.stack(tensors, dim=0)
+     return merged
+
+ def collate_fn(batch):
+     """
+     Collate a list of sample dicts into the unified batch format:
+     {
+         "features": {name: Tensor(B, ...)},
+         "labels": {target: Tensor(B, ...)} or None,
+         "ids": {id_name: Tensor(B, ...)} or None,
+     }
+     Args: batch: List of samples from DataLoader
+
+     Returns: dict: Batched data in unified format
+     """
+     if not batch:
+         return {"features": {}, "labels": None, "ids": None}
+
+     first = batch[0]
+     if isinstance(first, dict) and "features" in first:
+         # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
+         if first.get("_already_batched") and len(batch) == 1:
+             return {
+                 "features": first.get("features", {}),
+                 "labels": first.get("labels"),
+                 "ids": first.get("ids"),
+             }
+         return {
+             "features": stack_section(batch, "features") or {},
+             "labels": stack_section(batch, "labels"),
+             "ids": stack_section(batch, "ids"),
+         }
+
+     # Fallback: stack tuples/lists of tensors
+     num_tensors = len(first)
+     result = []
+     for i in range(num_tensors):
+         tensor_list = [item[i] for item in batch]
+         first_item = tensor_list[0]
+         if isinstance(first_item, torch.Tensor):
+             stacked = torch.cat(tensor_list, dim=0)
+         elif isinstance(first_item, np.ndarray):
+             stacked = np.concatenate(tensor_list, axis=0)
+         elif isinstance(first_item, list):
+             combined = []
+             for entry in tensor_list:
+                 combined.extend(entry)
+             stacked = combined
+         else:
+             stacked = tensor_list
+         result.append(stacked)
+     return tuple(result)
+
+
+ def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
+     if not (isinstance(batch_data, Mapping) and "features" in batch_data):
+         raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
+     return {
+         "features": batch_data.get("features", {}),
+         "labels": batch_data.get("labels"),
+         "ids": batch_data.get("ids") if include_ids else None,
+     }
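As the docstring above describes, `collate_fn` expects each sample to be a dict with `features` / `labels` / `ids` sections of per-sample tensors and stacks them along a new batch dimension; already-batched streaming chunks marked `_already_batched` pass through unchanged. A small illustrative call with made-up feature names:

import torch
from nextrec.data.batch_utils import collate_fn

samples = [
    {"features": {"age": torch.tensor(23.0)}, "labels": {"label": torch.tensor(1.0)}, "ids": None},
    {"features": {"age": torch.tensor(31.0)}, "labels": {"label": torch.tensor(0.0)}, "ids": None},
]
batch = collate_fn(samples)
print(batch["features"]["age"].shape)  # torch.Size([2]) -- stacked along a new batch dim
print(batch["ids"])                    # None, since no sample carried ids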
nextrec/data/data_processing.py
ADDED
@@ -0,0 +1,152 @@
+ """
+ Data processing utilities for NextRec
+
+ Date: create on 03/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ import torch
+ import numpy as np
+ import pandas as pd
+ from typing import Any, Mapping
+
+
+ def get_column_data(data: dict | pd.DataFrame, name: str):
+     if isinstance(data, dict):
+         return data[name] if name in data else None
+     elif isinstance(data, pd.DataFrame):
+         if name not in data.columns:
+             return None
+         return data[name].values
+     else:
+         if hasattr(data, name):
+             return getattr(data, name)
+         raise KeyError(f"Unsupported data type for extracting column {name}")
+
+ def split_dict_random(
+     data_dict: dict,
+     test_size: float = 0.2,
+     random_state: int | None = None
+ ):
+     lengths = [len(v) for v in data_dict.values()]
+     if len(set(lengths)) != 1:
+         raise ValueError(f"Length mismatch: {lengths}")
+
+     n = lengths[0]
+     rng = np.random.default_rng(random_state)
+     perm = rng.permutation(n)
+     cut = int(round(n * (1 - test_size)))
+     train_idx, test_idx = perm[:cut], perm[cut:]
+
+     def take(v, idx):
+         if isinstance(v, np.ndarray):
+             return v[idx]
+         elif isinstance(v, pd.Series):
+             return v.iloc[idx].to_numpy()
+         else:
+             v_arr = np.asarray(v, dtype=object)
+             return v_arr[idx]
+
+     train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
+     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
+     return train_dict, test_dict
+
+
+ def build_eval_candidates(
+     df_all: pd.DataFrame,
+     user_col: str,
+     item_col: str,
+     label_col: str,
+     user_features: pd.DataFrame,
+     item_features: pd.DataFrame,
+     num_pos_per_user: int = 5,
+     num_neg_per_pos: int = 50,
+     random_seed: int = 2025,
+ ) -> pd.DataFrame:
+     """
+     Build evaluation candidates with positive and negative samples for each user.
+
+     Args:
+         df_all: Full interaction DataFrame
+         user_col: Name of the user ID column
+         item_col: Name of the item ID column
+         label_col: Name of the label column
+         user_features: DataFrame containing user features
+         item_features: DataFrame containing item features
+         num_pos_per_user: Number of positive samples per user (default: 5)
+         num_neg_per_pos: Number of negative samples per positive (default: 50)
+         random_seed: Random seed for reproducibility (default: 2025)
+
+     Returns:
+         pd.DataFrame: Evaluation candidates with features
+     """
+     rng = np.random.default_rng(random_seed)
+
+     users = df_all[user_col].unique()
+     all_items = item_features[item_col].unique()
+     rows = []
+     user_hist_items = {u: df_all[df_all[user_col] == u][item_col].unique() for u in users}
+
+     for u in users:
+         df_user = df_all[df_all[user_col] == u]
+         pos_items = df_user[df_user[label_col] == 1][item_col].unique()
+         if len(pos_items) == 0:
+             continue
+         pos_items = pos_items[:num_pos_per_user]
+         seen_items = set(user_hist_items[u])
+         neg_pool = np.setdiff1d(all_items, np.fromiter(seen_items, dtype=all_items.dtype))
+         if len(neg_pool) == 0:
+             continue
+         for pos in pos_items:
+             if len(neg_pool) <= num_neg_per_pos:
+                 neg_items = neg_pool
+             else:
+                 neg_items = rng.choice(neg_pool, size=num_neg_per_pos, replace=False)
+             rows.append((u, pos, 1))
+             for ni in neg_items:
+                 rows.append((u, ni, 0))
+
+     eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
+     eval_df = eval_df.merge(user_features, on=user_col, how='left')
+     eval_df = eval_df.merge(item_features, on=item_col, how='left')
+     return eval_df
+
+
+ def get_user_ids(
+     data: Any,
+     id_columns: list[str] | str | None = None
+ ) -> np.ndarray | None:
+     """
+     Extract user IDs from various data structures.
+
+     Args:
+         data: Data source (DataFrame, dict, or batch dict)
+         id_columns: List or single ID column name(s) (default: None)
+
+     Returns:
+         np.ndarray | None: User IDs as numpy array, or None if not found
+     """
+     id_columns = (
+         id_columns if isinstance(id_columns, list)
+         else [id_columns] if isinstance(id_columns, str)
+         else []
+     )
+     if not id_columns:
+         return None
+
+     main_id = id_columns[0]
+     if isinstance(data, pd.DataFrame) and main_id in data.columns:
+         arr = np.asarray(data[main_id].values)
+         return arr.reshape(arr.shape[0])
+
+     if isinstance(data, dict):
+         ids_container = data.get("ids")
+         if isinstance(ids_container, dict) and main_id in ids_container:
+             val = ids_container[main_id]
+             val = val.detach().cpu().numpy() if isinstance(val, torch.Tensor) else np.asarray(val)
+             return val.reshape(val.shape[0])
+         if main_id in data:
+             arr = np.asarray(data[main_id])
+             return arr.reshape(arr.shape[0])
+
+     return None