nextrec 0.3.4-py3-none-any.whl → 0.3.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__init__.py CHANGED
@@ -1,33 +1,3 @@
- """
- NextRec - A Unified Deep Learning Framework for Recommender Systems
- ===================================================================
-
- NextRec provides a comprehensive suite of recommendation models including:
- - Ranking models (CTR prediction)
- - Matching models (retrieval)
- - Multi-task learning models
- - Generative recommendation models
-
- Quick Start
- -----------
- >>> from nextrec.basic.features import DenseFeature, SparseFeature
- >>> from nextrec.models.ranking.deepfm import DeepFM
- >>>
- >>> # Define features
- >>> dense_features = [DenseFeature('age')]
- >>> sparse_features = [SparseFeature('category', vocab_size=100, embedding_dim=16)]
- >>>
- >>> # Build model
- >>> model = DeepFM(
- ... dense_features=dense_features,
- ... sparse_features=sparse_features,
- ... targets=['label']
- ... )
- >>>
- >>> # Train model
- >>> model.fit(train_data=df_train, valid_data=df_valid)
- """
-
  from nextrec.__version__ import __version__

  __all__ = [
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.3.4"
+ __version__ = "0.3.6"
nextrec/basic/features.py CHANGED
@@ -7,7 +7,7 @@ Author: Yang Zhou, zyaztec@gmail.com
  """
  import torch
  from nextrec.utils.embedding import get_auto_embedding_dim
- from nextrec.utils.common import normalize_to_list
+ from nextrec.utils.feature import normalize_to_list

  class BaseFeature(object):
  def __repr__(self):
nextrec/basic/loggers.py CHANGED
@@ -99,7 +99,7 @@ def setup_logger(session_id: str | os.PathLike | None = None):
  session = create_session(str(session_id) if session_id is not None else None)
  log_dir = session.logs_dir
  log_dir.mkdir(parents=True, exist_ok=True)
- log_file = log_dir / f"{session.experiment_id}.log"
+ log_file = log_dir / f"{session.log_basename}.log"

  console_format = '%(message)s'
  file_format = '%(asctime)s - %(levelname)s - %(message)s'
nextrec/basic/model.py CHANGED
@@ -31,10 +31,12 @@ from nextrec.basic.session import resolve_save_path, create_session
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

  from nextrec.data.dataloader import build_tensors_from_data
- from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids
+ from nextrec.data.data_processing import get_column_data, get_user_ids
+ from nextrec.data.batch_utils import collate_fn, batch_to_dict

  from nextrec.loss import get_loss_fn, get_loss_kwargs
- from nextrec.utils import get_optimizer, get_scheduler, to_tensor
+ from nextrec.utils import get_optimizer, get_scheduler
+ from nextrec.utils.tensor import to_tensor

  from nextrec import __version__

@@ -153,7 +155,7 @@ class BaseModel(FeatureSet, nn.Module):
  raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
  return X_input, y

- def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+ def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,) -> tuple[DataLoader, dict | pd.DataFrame]:
  """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
  if not (0 < validation_split < 1):
  raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
@@ -182,7 +184,7 @@ class BaseModel(FeatureSet, nn.Module):
  arr = np.asarray(value)
  train_split[key] = arr[train_indices]
  valid_split[key] = arr[valid_indices]
- train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+ train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
  logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
  return train_loader, valid_split

@@ -263,14 +265,14 @@ class BaseModel(FeatureSet, nn.Module):
  task_losses.append(task_loss)
  return torch.stack(task_losses).sum()

- def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
+ def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0,) -> DataLoader:
  if isinstance(data, DataLoader):
  return data
  tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
  if tensors is None:
  raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
  dataset = TensorDictDataset(tensors)
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers)

  def fit(self,
  train_data: dict | pd.DataFrame | DataLoader,
@@ -279,6 +281,7 @@ class BaseModel(FeatureSet, nn.Module):
  epochs:int=1, shuffle:bool=True, batch_size:int=32,
  user_id_column: str | None = None,
  validation_split: float | None = None,
+ num_workers: int = 0,
  tensorboard: bool = True,):
  self.to(self.device)
  if not self.logger_initialized:
@@ -295,11 +298,11 @@ class BaseModel(FeatureSet, nn.Module):
  self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')

  if validation_split is not None and valid_data is None:
- train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,) # type: ignore
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
  else:
- train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
+ train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers))

- valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
+ valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers)
  try:
  self.steps_per_epoch = len(train_loader)
  is_streaming = False
@@ -386,7 +389,7 @@ class BaseModel(FeatureSet, nn.Module):
  self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
  if valid_loader is not None:
  # pass user_ids only if needed for GAUC metric
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
  if self.nums_task == 1:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
  logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
@@ -511,12 +514,12 @@ class BaseModel(FeatureSet, nn.Module):
  return avg_loss, metrics_dict
  return avg_loss

- def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
+ def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0,) -> tuple[DataLoader | None, np.ndarray | None]:
  if valid_data is None:
  return None, None
  if isinstance(valid_data, DataLoader):
  return valid_data, None
- valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+ valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  valid_user_ids = None
  if needs_user_ids:
  if user_id_column is None:
@@ -529,7 +532,8 @@ class BaseModel(FeatureSet, nn.Module):
  metrics: list[str] | dict[str, list[str]] | None = None,
  batch_size: int = 32,
  user_ids: np.ndarray | None = None,
- user_id_column: str = 'user_id') -> dict:
+ user_id_column: str = 'user_id',
+ num_workers: int = 0,) -> dict:
  self.eval()
  eval_metrics = metrics if metrics is not None else self.metrics
  if eval_metrics is None:
@@ -541,7 +545,7 @@ class BaseModel(FeatureSet, nn.Module):
  else:
  if user_ids is None and needs_user_ids:
  user_ids = get_user_ids(data=data, id_columns=user_id_column)
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  y_true_list = []
  y_pred_list = []
  collected_user_ids = []
@@ -601,6 +605,7 @@ class BaseModel(FeatureSet, nn.Module):
  include_ids: bool | None = None,
  return_dataframe: bool = True,
  streaming_chunk_size: int = 10000,
+ num_workers: int = 0,
  ) -> pd.DataFrame | np.ndarray:
  self.eval()
  if include_ids is None:
@@ -613,7 +618,7 @@ class BaseModel(FeatureSet, nn.Module):
  rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
  elif not isinstance(data, DataLoader):
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  else:
  data_loader = data

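The net effect of the nextrec/basic/model.py changes is a new num_workers argument on the training, evaluation, and prediction entry points, forwarded into every DataLoader the model builds. A minimal usage sketch, reusing the DeepFM example from the removed 0.3.4 docstring (df_train and df_valid are placeholder DataFrames, and 4 workers is an arbitrary choice):

from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.models.ranking.deepfm import DeepFM

dense_features = [DenseFeature('age')]
sparse_features = [SparseFeature('category', vocab_size=100, embedding_dim=16)]
model = DeepFM(dense_features=dense_features, sparse_features=sparse_features, targets=['label'])

# num_workers is passed through to the training, validation-split, and evaluation DataLoaders.
model.fit(train_data=df_train, valid_data=df_valid, batch_size=32, num_workers=4)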
nextrec/basic/session.py CHANGED
@@ -1,14 +1,5 @@
  """Session and experiment utilities.

- This module centralizes session/experiment management so the rest of the
- framework writes all artifacts to a consistent location:: <pwd>/log/<experiment_id>/
-
- Within that folder we keep model parameters, checkpoints, training metrics,
- evaluation metrics, and consolidated log output. When users do not provide an
- ``experiment_id`` a timestamp-based identifier is generated once per process to
- avoid scattering files across multiple directories. Test runs are redirected to
- temporary folders so local trees are not polluted.
-
  Date: create on 23/11/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
@@ -16,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
  import os
  import tempfile
  from dataclasses import dataclass
- from datetime import datetime
+ from datetime import datetime, timezone
  from pathlib import Path

  __all__ = [
@@ -31,6 +22,7 @@ class Session:

  experiment_id: str
  root: Path
+ log_basename: str # The base name for log files, without path separators

  @property
  def logs_dir(self) -> Path:
@@ -69,13 +61,15 @@ class Session:
  return path

  def create_session(experiment_id: str | Path | None = None) -> Session:
- """Create a :class:`Session` instance with prepared directories."""

  if experiment_id is not None and str(experiment_id).strip():
  exp_id = str(experiment_id).strip()
  else:
+ # Use local time for session naming
  exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")

+ log_basename = Path(exp_id).name if exp_id else exp_id
+
  if (
  os.getenv("PYTEST_CURRENT_TEST")
  or os.getenv("PYTEST_RUNNING")
@@ -90,7 +84,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
  session_path.mkdir(parents=True, exist_ok=True)
  root = session_path.resolve()

- return Session(experiment_id=exp_id, root=root)
+ return Session(experiment_id=exp_id, root=root, log_basename=log_basename)

  def resolve_save_path(
  path: str | os.PathLike | Path | None,
@@ -111,6 +105,7 @@ def resolve_save_path(
  timestamp.
  - Parent directories are created.
  """
+ # Use local time for file timestamps
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if add_timestamp else None

  normalized_suffix = suffix if suffix.startswith(".") else f".{suffix}"
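The new Session.log_basename field exists so that an experiment_id containing path separators does not leak into the log file name built in nextrec/basic/loggers.py (see the log_file change above). A minimal sketch, with a made-up experiment id:

from nextrec.basic.session import create_session

session = create_session("exp/run_01")   # hypothetical id containing a path separator
print(session.experiment_id)             # 'exp/run_01' (kept as-is for the session itself)
print(session.log_basename)              # 'run_01'     (setup_logger writes <logs_dir>/run_01.log)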
nextrec/data/__init__.py CHANGED
@@ -1,48 +1,61 @@
- """
- Data utilities package for NextRec
-
- This package provides data processing and manipulation utilities.
-
- Date: create on 13/11/2025
- Author: Yang Zhou, zyaztec@gmail.com
- """
-
- from nextrec.data.data_utils import (
- collate_fn,
+ from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
+ from nextrec.data.data_processing import (
  get_column_data,
- default_output_dir,
  split_dict_random,
  build_eval_candidates,
+ get_user_ids,
+ )
+
+ from nextrec.utils.file import (
  resolve_file_paths,
  iter_file_chunks,
  read_table,
  load_dataframes,
+ default_output_dir,
  )
- from nextrec.basic.features import FeatureSet
- from nextrec.data import data_utils
+
  from nextrec.data.dataloader import (
  TensorDictDataset,
  FileDataset,
  RecDataLoader,
  build_tensors_from_data,
  )
+
  from nextrec.data.preprocessor import DataProcessor
+ from nextrec.basic.features import FeatureSet
+ from nextrec.data import data_utils

  __all__ = [
+ # Batch utilities
  'collate_fn',
+ 'batch_to_dict',
+ 'stack_section',
+
+ # Data processing
  'get_column_data',
- 'default_output_dir',
  'split_dict_random',
  'build_eval_candidates',
+ 'get_user_ids',
+
+ # File utilities
  'resolve_file_paths',
  'iter_file_chunks',
  'read_table',
  'load_dataframes',
- 'FeatureSet',
- 'data_utils',
+ 'default_output_dir',
+
+ # DataLoader
  'TensorDictDataset',
  'FileDataset',
  'RecDataLoader',
  'build_tensors_from_data',
+
+ # Preprocessor
  'DataProcessor',
+
+ # Features
  'FeatureSet',
+
+ # Legacy module
  'data_utils',
  ]
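Despite the reshuffle across batch_utils, data_processing, and nextrec.utils.file, every name listed in __all__ above is still importable directly from nextrec.data; a quick sanity-check sketch:

from nextrec.data import collate_fn, batch_to_dict, get_user_ids, split_dict_random, default_output_dir, DataProcessor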
nextrec/data/batch_utils.py ADDED
@@ -0,0 +1,80 @@
+ """
+ Batch collation utilities for NextRec
+
+ Date: create on 03/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ import torch
+ import numpy as np
+ from typing import Any, Mapping
+
+ def stack_section(batch: list[dict], section: str):
+ entries = [item.get(section) for item in batch if item.get(section) is not None]
+ if not entries:
+ return None
+ merged: dict = {}
+ for name in entries[0]: # type: ignore
+ tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
+ merged[name] = torch.stack(tensors, dim=0)
+ return merged
+
+ def collate_fn(batch):
+ """
+ Collate a list of sample dicts into the unified batch format:
+ {
+ "features": {name: Tensor(B, ...)},
+ "labels": {target: Tensor(B, ...)} or None,
+ "ids": {id_name: Tensor(B, ...)} or None,
+ }
+ Args: batch: List of samples from DataLoader
+
+ Returns: dict: Batched data in unified format
+ """
+ if not batch:
+ return {"features": {}, "labels": None, "ids": None}
+
+ first = batch[0]
+ if isinstance(first, dict) and "features" in first:
+ # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
+ if first.get("_already_batched") and len(batch) == 1:
+ return {
+ "features": first.get("features", {}),
+ "labels": first.get("labels"),
+ "ids": first.get("ids"),
+ }
+ return {
+ "features": stack_section(batch, "features") or {},
+ "labels": stack_section(batch, "labels"),
+ "ids": stack_section(batch, "ids"),
+ }
+
+ # Fallback: stack tuples/lists of tensors
+ num_tensors = len(first)
+ result = []
+ for i in range(num_tensors):
+ tensor_list = [item[i] for item in batch]
+ first_item = tensor_list[0]
+ if isinstance(first_item, torch.Tensor):
+ stacked = torch.cat(tensor_list, dim=0)
+ elif isinstance(first_item, np.ndarray):
+ stacked = np.concatenate(tensor_list, axis=0)
+ elif isinstance(first_item, list):
+ combined = []
+ for entry in tensor_list:
+ combined.extend(entry)
+ stacked = combined
+ else:
+ stacked = tensor_list
+ result.append(stacked)
+ return tuple(result)
+
+
+ def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
+ if not (isinstance(batch_data, Mapping) and "features" in batch_data):
+ raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
+ return {
+ "features": batch_data.get("features", {}),
+ "labels": batch_data.get("labels"),
+ "ids": batch_data.get("ids") if include_ids else None,
+ }
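For reference, a small sketch of the unified batch format collate_fn produces (toy tensors, importing through the nextrec.data re-export shown earlier):

import torch
from nextrec.data import collate_fn

samples = [
    {"features": {"age": torch.tensor(0.3)}, "labels": {"label": torch.tensor(1.0)}, "ids": None},
    {"features": {"age": torch.tensor(0.7)}, "labels": {"label": torch.tensor(0.0)}, "ids": None},
]
batch = collate_fn(samples)
# batch["features"]["age"] and batch["labels"]["label"] are stacked tensors of shape (2,); batch["ids"] is None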
nextrec/data/data_processing.py ADDED
@@ -0,0 +1,152 @@
+ """
+ Data processing utilities for NextRec
+
+ Date: create on 03/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ import torch
+ import numpy as np
+ import pandas as pd
+ from typing import Any, Mapping
+
+
+ def get_column_data(data: dict | pd.DataFrame, name: str):
+ if isinstance(data, dict):
+ return data[name] if name in data else None
+ elif isinstance(data, pd.DataFrame):
+ if name not in data.columns:
+ return None
+ return data[name].values
+ else:
+ if hasattr(data, name):
+ return getattr(data, name)
+ raise KeyError(f"Unsupported data type for extracting column {name}")
+
+ def split_dict_random(
+ data_dict: dict,
+ test_size: float = 0.2,
+ random_state: int | None = None
+ ):
+ lengths = [len(v) for v in data_dict.values()]
+ if len(set(lengths)) != 1:
+ raise ValueError(f"Length mismatch: {lengths}")
+
+ n = lengths[0]
+ rng = np.random.default_rng(random_state)
+ perm = rng.permutation(n)
+ cut = int(round(n * (1 - test_size)))
+ train_idx, test_idx = perm[:cut], perm[cut:]
+
+ def take(v, idx):
+ if isinstance(v, np.ndarray):
+ return v[idx]
+ elif isinstance(v, pd.Series):
+ return v.iloc[idx].to_numpy()
+ else:
+ v_arr = np.asarray(v, dtype=object)
+ return v_arr[idx]
+
+ train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
+ test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
+ return train_dict, test_dict
+
+
+ def build_eval_candidates(
+ df_all: pd.DataFrame,
+ user_col: str,
+ item_col: str,
+ label_col: str,
+ user_features: pd.DataFrame,
+ item_features: pd.DataFrame,
+ num_pos_per_user: int = 5,
+ num_neg_per_pos: int = 50,
+ random_seed: int = 2025,
+ ) -> pd.DataFrame:
+ """
+ Build evaluation candidates with positive and negative samples for each user.
+
+ Args:
+ df_all: Full interaction DataFrame
+ user_col: Name of the user ID column
+ item_col: Name of the item ID column
+ label_col: Name of the label column
+ user_features: DataFrame containing user features
+ item_features: DataFrame containing item features
+ num_pos_per_user: Number of positive samples per user (default: 5)
+ num_neg_per_pos: Number of negative samples per positive (default: 50)
+ random_seed: Random seed for reproducibility (default: 2025)
+
+ Returns:
+ pd.DataFrame: Evaluation candidates with features
+ """
+ rng = np.random.default_rng(random_seed)
+
+ users = df_all[user_col].unique()
+ all_items = item_features[item_col].unique()
+ rows = []
+ user_hist_items = {u: df_all[df_all[user_col] == u][item_col].unique() for u in users}
+
+ for u in users:
+ df_user = df_all[df_all[user_col] == u]
+ pos_items = df_user[df_user[label_col] == 1][item_col].unique()
+ if len(pos_items) == 0:
+ continue
+ pos_items = pos_items[:num_pos_per_user]
+ seen_items = set(user_hist_items[u])
+ neg_pool = np.setdiff1d(all_items, np.fromiter(seen_items, dtype=all_items.dtype))
+ if len(neg_pool) == 0:
+ continue
+ for pos in pos_items:
+ if len(neg_pool) <= num_neg_per_pos:
+ neg_items = neg_pool
+ else:
+ neg_items = rng.choice(neg_pool, size=num_neg_per_pos, replace=False)
+ rows.append((u, pos, 1))
+ for ni in neg_items:
+ rows.append((u, ni, 0))
+
+ eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
+ eval_df = eval_df.merge(user_features, on=user_col, how='left')
+ eval_df = eval_df.merge(item_features, on=item_col, how='left')
+ return eval_df
+
+
+ def get_user_ids(
+ data: Any,
+ id_columns: list[str] | str | None = None
+ ) -> np.ndarray | None:
+ """
+ Extract user IDs from various data structures.
+
+ Args:
+ data: Data source (DataFrame, dict, or batch dict)
+ id_columns: List or single ID column name(s) (default: None)
+
+ Returns:
+ np.ndarray | None: User IDs as numpy array, or None if not found
+ """
+ id_columns = (
+ id_columns if isinstance(id_columns, list)
+ else [id_columns] if isinstance(id_columns, str)
+ else []
+ )
+ if not id_columns:
+ return None
+
+ main_id = id_columns[0]
+ if isinstance(data, pd.DataFrame) and main_id in data.columns:
+ arr = np.asarray(data[main_id].values)
+ return arr.reshape(arr.shape[0])
+
+ if isinstance(data, dict):
+ ids_container = data.get("ids")
+ if isinstance(ids_container, dict) and main_id in ids_container:
+ val = ids_container[main_id]
+ val = val.detach().cpu().numpy() if isinstance(val, torch.Tensor) else np.asarray(val)
+ return val.reshape(val.shape[0])
+ if main_id in data:
+ arr = np.asarray(data[main_id])
+ return arr.reshape(arr.shape[0])
+
+ return None
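A short sketch of split_dict_random on toy arrays (the values are invented for illustration):

import numpy as np
from nextrec.data import split_dict_random

data = {"user_id": np.arange(10), "label": np.ones(10)}
train, test = split_dict_random(data, test_size=0.2, random_state=42)
# len(train["user_id"]) == 8 and len(test["user_id"]) == 2; both dicts keep the original keys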