nextrec-0.3.4-py3-none-any.whl → nextrec-0.3.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/data/data_utils.py CHANGED
@@ -1,268 +1,35 @@
- """Data processing utilities for NextRec."""
-
- import torch
- import numpy as np
- import pandas as pd
- import pyarrow.parquet as pq
- from pathlib import Path
- from typing import Any, Mapping, Sequence
-
- def stack_section(batch: list[dict], section: str):
-     """Stack one section of the batch (features/labels/ids)."""
-     entries = [item.get(section) for item in batch if item.get(section) is not None]
-     if not entries:
-         return None
-     merged: dict = {}
-     for name in entries[0]:  # type: ignore
-         tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
-         merged[name] = torch.stack(tensors, dim=0)
-     return merged
-
- def collate_fn(batch):
-     """
-     Collate a list of sample dicts into the unified batch format:
-     {
-         "features": {name: Tensor(B, ...)},
-         "labels": {target: Tensor(B, ...)} or None,
-         "ids": {id_name: Tensor(B, ...)} or None,
-     }
-     """
-     if not batch:
-         return {"features": {}, "labels": None, "ids": None}
-
-     first = batch[0]
-     if isinstance(first, dict) and "features" in first:
-         # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
-         if first.get("_already_batched") and len(batch) == 1:
-             return {
-                 "features": first.get("features", {}),
-                 "labels": first.get("labels"),
-                 "ids": first.get("ids"),
-             }
-         return {
-             "features": stack_section(batch, "features") or {},
-             "labels": stack_section(batch, "labels"),
-             "ids": stack_section(batch, "ids"),
-         }
-
-     # Fallback: stack tuples/lists of tensors
-     num_tensors = len(first)
-     result = []
-     for i in range(num_tensors):
-         tensor_list = [item[i] for item in batch]
-         first_item = tensor_list[0]
-         if isinstance(first_item, torch.Tensor):
-             stacked = torch.cat(tensor_list, dim=0)
-         elif isinstance(first_item, np.ndarray):
-             stacked = np.concatenate(tensor_list, axis=0)
-         elif isinstance(first_item, list):
-             combined = []
-             for entry in tensor_list:
-                 combined.extend(entry)
-             stacked = combined
-         else:
-             stacked = tensor_list
-         result.append(stacked)
-     return tuple(result)
-
- def get_column_data(data: dict | pd.DataFrame, name: str):
-     """Extract column data from various data structures."""
-     if isinstance(data, dict):
-         return data[name] if name in data else None
-     elif isinstance(data, pd.DataFrame):
-         if name not in data.columns:
-             return None
-         return data[name].values
-     else:
-         if hasattr(data, name):
-             return getattr(data, name)
-         raise KeyError(f"Unsupported data type for extracting column {name}")
-
- def resolve_file_paths(path: str) -> tuple[list[str], str]:
-     """Resolve file or directory path into a sorted list of files and file type."""
-     path_obj = Path(path)
-
-     if path_obj.is_file():
-         file_type = path_obj.suffix.lower().lstrip(".")
-         assert file_type in ["csv", "parquet"], f"Unsupported file extension: {file_type}"
-         return [str(path_obj)], file_type
-
-     if path_obj.is_dir():
-         collected_files = [p for p in path_obj.iterdir() if p.is_file()]
-         csv_files = [str(p) for p in collected_files if p.suffix.lower() == ".csv"]
-         parquet_files = [str(p) for p in collected_files if p.suffix.lower() == ".parquet"]
-
-         if csv_files and parquet_files:
-             raise ValueError("Directory contains both CSV and Parquet files. Please keep a single format.")
-         file_paths = csv_files if csv_files else parquet_files
-         if not file_paths:
-             raise ValueError(f"No CSV or Parquet files found in directory: {path}")
-         file_paths.sort()
-         file_type = "csv" if csv_files else "parquet"
-         return file_paths, file_type
-
-     raise ValueError(f"Invalid path: {path}")
-
- def iter_file_chunks(file_path: str, file_type: str, chunk_size: int):
-     """Yield DataFrame chunks for CSV/Parquet without loading the whole file."""
-     if file_type == "csv":
-         yield from pd.read_csv(file_path, chunksize=chunk_size)
-         return
-     parquet_file = pq.ParquetFile(file_path)
-     for batch in parquet_file.iter_batches(batch_size=chunk_size):
-         yield batch.to_pandas()
-
- def read_table(file_path: str, file_type: str) -> pd.DataFrame:
-     """Read a single CSV/Parquet file."""
-     if file_type == "csv":
-         return pd.read_csv(file_path)
-     return pd.read_parquet(file_path)
-
- def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
-     """Load multiple files of the same type into DataFrames."""
-     return [read_table(fp, file_type) for fp in file_paths]
-
- def default_output_dir(path: str) -> Path:
-     """Generate a default output directory path based on the input path."""
-     path_obj = Path(path)
-     if path_obj.is_file():
-         return path_obj.parent / f"{path_obj.stem}_preprocessed"
-     return path_obj.with_name(f"{path_obj.name}_preprocessed")
-
- def split_dict_random(data_dict: dict, test_size: float = 0.2, random_state: int | None = None):
-     """Randomly split a dictionary of data into training and testing sets."""
-     lengths = [len(v) for v in data_dict.values()]
-     if len(set(lengths)) != 1:
-         raise ValueError(f"Length mismatch: {lengths}")
-     n = lengths[0]
-     rng = np.random.default_rng(random_state)
-     perm = rng.permutation(n)
-     cut = int(round(n * (1 - test_size)))
-     train_idx, test_idx = perm[:cut], perm[cut:]
-     def take(v, idx):
-         if isinstance(v, np.ndarray):
-             return v[idx]
-         elif isinstance(v, pd.Series):
-             return v.iloc[idx].to_numpy()
-         else:
-             v_arr = np.asarray(v, dtype=object)
-             return v_arr[idx]
-     train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
-     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
-     return train_dict, test_dict
-
- def build_eval_candidates(
-     df_all: pd.DataFrame,
-     user_col: str,
-     item_col: str,
-     label_col: str,
-     user_features: pd.DataFrame,
-     item_features: pd.DataFrame,
-     num_pos_per_user: int = 5,
-     num_neg_per_pos: int = 50,
-     random_seed: int = 2025,
- ) -> pd.DataFrame:
-     """Build evaluation candidates with positive and negative samples for each user."""
-     rng = np.random.default_rng(random_seed)
-
-     users = df_all[user_col].unique()
-     all_items = item_features[item_col].unique()
-     rows = []
-     user_hist_items = {u: df_all[df_all[user_col] == u][item_col].unique() for u in users}
-     for u in users:
-         df_user = df_all[df_all[user_col] == u]
-         pos_items = df_user[df_user[label_col] == 1][item_col].unique()
-         if len(pos_items) == 0:
-             continue
-         pos_items = pos_items[:num_pos_per_user]
-         seen_items = set(user_hist_items[u])
-         neg_pool = np.setdiff1d(all_items, np.fromiter(seen_items, dtype=all_items.dtype))
-         if len(neg_pool) == 0:
-             continue
-         for pos in pos_items:
-             if len(neg_pool) <= num_neg_per_pos:
-                 neg_items = neg_pool
-             else:
-                 neg_items = rng.choice(neg_pool, size=num_neg_per_pos, replace=False)
-             rows.append((u, pos, 1))
-             for ni in neg_items:
-                 rows.append((u, ni, 0))
-     eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
-     eval_df = eval_df.merge(user_features, on=user_col, how='left')
-     eval_df = eval_df.merge(item_features, on=item_col, how='left')
-     return eval_df
-
- def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
-     """Standardize a dataloader batch into a dict of features, labels, and ids."""
-     if not (isinstance(batch_data, Mapping) and "features" in batch_data):
-         raise TypeError(
-             "[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader."
-         )
-     return {
-         "features": batch_data.get("features", {}),
-         "labels": batch_data.get("labels"),
-         "ids": batch_data.get("ids") if include_ids else None,
-     }
-
-
- # def get_user_ids(
- #     data: dict | pd.DataFrame | None, user_id_column: str = "user_id"
- # ) -> np.ndarray | None:
- #     """Extract user IDs from a dataset dict or DataFrame."""
- #     if data is None:
- #         return None
- #     if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
- #         return np.asarray(data[user_id_column].values)
- #     if isinstance(data, dict) and user_id_column in data:
- #         return np.asarray(data[user_id_column])
- #     return None
-
-
- # def get_user_ids_from_batch(
- #     batch_dict: Mapping[str, Any], id_columns: Sequence[str] | None = None
- # ) -> np.ndarray | None:
- #     """Extract the prioritized user id column from a batch dict."""
- #     ids_container = batch_dict.get("ids") if isinstance(batch_dict, Mapping) else None
- #     if not ids_container:
- #         return None
-
- #     batch_user_id = None
- #     if id_columns:
- #         for id_name in id_columns:
- #             if id_name in ids_container:
- #                 batch_user_id = ids_container[id_name]
- #                 break
- #     if batch_user_id is None:
- #         batch_user_id = next(iter(ids_container.values()), None)
- #     if batch_user_id is None:
- #         return None
-
- #     if isinstance(batch_user_id, torch.Tensor):
- #         ids_np = batch_user_id.detach().cpu().numpy()
- #     else:
- #         ids_np = np.asarray(batch_user_id)
- #     if ids_np.ndim == 0:
- #         ids_np = ids_np.reshape(1)
- #     return ids_np.reshape(ids_np.shape[0])
-
-
- def get_user_ids(data, id_columns: list[str] | str | None = None) -> np.ndarray | None:
-     id_columns = id_columns if isinstance(id_columns, list) else [id_columns] if isinstance(id_columns, str) else []
-     if not id_columns:
-         return None
-
-     main_id = id_columns[0]
-     if isinstance(data, pd.DataFrame) and main_id in data.columns:
-         arr = np.asarray(data[main_id].values)
-         return arr.reshape(arr.shape[0])
-     if isinstance(data, dict):
-         ids_container = data.get("ids")
-         if isinstance(ids_container, dict) and main_id in ids_container:
-             val = ids_container[main_id]
-             val = val.detach().cpu().numpy() if isinstance(val, torch.Tensor) else np.asarray(val)
-             return val.reshape(val.shape[0])
-         if main_id in data:
-             arr = np.asarray(data[main_id])
-             return arr.reshape(arr.shape[0])
-
-     return None
+ """
+ Data processing utilities for NextRec (Refactored)
+
+ This module now re-exports functions from specialized submodules:
+ - batch_utils: collate_fn, batch_to_dict
+ - data_processing: get_column_data, split_dict_random, build_eval_candidates, get_user_ids
+ - nextrec.utils.file_utils: resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+
+ Date: create on 27/10/2025
+ Last update: 03/12/2025 (refactored)
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ # Import from new organized modules
+ from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
+ from nextrec.data.data_processing import get_column_data, split_dict_random, build_eval_candidates, get_user_ids
+ from nextrec.utils.file import resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+
+ __all__ = [
+     # Batch utilities
+     'collate_fn',
+     'batch_to_dict',
+     'stack_section',
+     # Data processing
+     'get_column_data',
+     'split_dict_random',
+     'build_eval_candidates',
+     'get_user_ids',
+     # File utilities
+     'resolve_file_paths',
+     'iter_file_chunks',
+     'read_table',
+     'load_dataframes',
+     'default_output_dir',
+ ]
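
Since the module body above is now pure re-exports, import sites written against the old aggregate module should keep resolving. A minimal smoke test (hypothetical usage; the module path nextrec.data.data_utils is inferred from the import changes later in this diff):

    # Both paths should yield the same object after the refactor (re-export, not a copy).
    from nextrec.data import data_utils
    from nextrec.data.batch_utils import collate_fn

    assert data_utils.collate_fn is collate_fn
    print(data_utils.__all__)  # legacy names kept for backward compatibility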
@@ -20,8 +20,10 @@ from nextrec.data.preprocessor import DataProcessor
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet

  from nextrec.basic.loggers import colorize
- from nextrec.data import get_column_data, collate_fn, resolve_file_paths, read_table
- from nextrec.utils import to_tensor
+ from nextrec.data.data_processing import get_column_data
+ from nextrec.data.batch_utils import collate_fn
+ from nextrec.utils.file import resolve_file_paths, read_table
+ from nextrec.utils.tensor import to_tensor

  class TensorDictDataset(Dataset):
      """Dataset returning sample-level dicts matching the unified batch schema."""
@@ -124,20 +126,22 @@ class RecDataLoader(FeatureSet):
                         batch_size: int = 32,
                         shuffle: bool = True,
                         load_full: bool = True,
-                        chunk_size: int = 10000) -> DataLoader:
+                        chunk_size: int = 10000,
+                        num_workers: int = 0) -> DataLoader:
          if isinstance(data, DataLoader):
              return data
          elif isinstance(data, (str, os.PathLike)):
-             return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size)
+             return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size, num_workers=num_workers)
          elif isinstance(data, (dict, pd.DataFrame)):
-             return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle)
+             return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
          else:
              raise ValueError(f"[RecDataLoader Error] Unsupported data type: {type(data)}")

      def create_from_memory(self,
                             data: dict | pd.DataFrame,
                             batch_size: int,
-                            shuffle: bool) -> DataLoader:
+                            shuffle: bool,
+                            num_workers: int = 0) -> DataLoader:
          raw_data = data

          if self.processor is not None:
@@ -148,14 +152,15 @@ class RecDataLoader(FeatureSet):
          if tensors is None:
              raise ValueError("[RecDataLoader Error] No valid tensors could be built from the provided data.")
          dataset = TensorDictDataset(tensors)
-         return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
+         return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers)

      def create_from_path(self,
                           path: str,
                           batch_size: int,
                           shuffle: bool,
                           load_full: bool,
-                          chunk_size: int = 10000) -> DataLoader:
+                          chunk_size: int = 10000,
+                          num_workers: int = 0) -> DataLoader:
          file_paths, file_type = resolve_file_paths(str(Path(path)))
          # Load full data into memory
          if load_full:
@@ -167,6 +172,7 @@ class RecDataLoader(FeatureSet):
                  except OSError:
                      pass
                  try:
+                     df = read_table(file_path, file_type=file_type)
                      dfs.append(df)
                  except MemoryError as exc:
                      raise MemoryError(f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using load_full=False with streaming.") from exc
@@ -174,22 +180,23 @@ class RecDataLoader(FeatureSet):
                  combined_df = pd.concat(dfs, ignore_index=True)
              except MemoryError as exc:
                  raise MemoryError(f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use load_full=False to stream or reduce chunk_size.") from exc
-             return self.create_from_memory(combined_df, batch_size, shuffle,)
+             return self.create_from_memory(combined_df, batch_size, shuffle, num_workers=num_workers)
          else:
-             return self.load_files_streaming(file_paths, file_type, batch_size, chunk_size, shuffle)
+             return self.load_files_streaming(file_paths, file_type, batch_size, chunk_size, shuffle, num_workers=num_workers)

      def load_files_streaming(self,
                               file_paths: list[str],
                               file_type: str,
                               batch_size: int,
                               chunk_size: int,
-                              shuffle: bool) -> DataLoader:
+                              shuffle: bool,
+                              num_workers: int = 0) -> DataLoader:
          if shuffle:
              logging.info("[RecDataLoader Info] Shuffle is ignored in streaming mode (IterableDataset).")
          if batch_size != 1:
              logging.info("[RecDataLoader Info] Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
          dataset = FileDataset(file_paths=file_paths, dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target_columns=self.target_columns, id_columns=self.id_columns, chunk_size=chunk_size, file_type=file_type, processor=self.processor)
-         return DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
+         return DataLoader(dataset, batch_size=1, collate_fn=collate_fn, num_workers=num_workers)

  def normalize_sequence_column(column, feature: SequenceFeature) -> np.ndarray:
      if isinstance(column, pd.Series):
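
The recurring change in the hunks above is a new num_workers argument threaded from the public entry points down to torch.utils.data.DataLoader. As a standalone illustration of what that parameter controls (plain PyTorch, not the NextRec API):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == "__main__":  # guard required for multiprocessing on spawn platforms
        ds = TensorDataset(torch.arange(1000).float().unsqueeze(1))
        # num_workers=2 prefetches batches in two background worker processes;
        # num_workers=0 (the new default above) loads in the main process.
        loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=2)
        for (batch,) in loader:
            pass  # consume batches; workers are joined when iteration ends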
@@ -16,24 +16,14 @@ import pandas as pd
  import tqdm
  from pathlib import Path
  from typing import Dict, Union, Optional, Literal, Any
- from sklearn.preprocessing import (
-     StandardScaler,
-     MinMaxScaler,
-     RobustScaler,
-     MaxAbsScaler,
-     LabelEncoder
- )
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, MaxAbsScaler, LabelEncoder
+

- from nextrec.basic.loggers import setup_logger, colorize
- from nextrec.data.data_utils import (
-     resolve_file_paths,
-     iter_file_chunks,
-     read_table,
-     load_dataframes,
-     default_output_dir,
- )
- from nextrec.basic.session import resolve_save_path
  from nextrec.basic.features import FeatureSet
+ from nextrec.basic.loggers import colorize
+ from nextrec.basic.session import resolve_save_path
+ from nextrec.utils.file import resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+
  from nextrec.__version__ import __version__


@@ -1,5 +0,0 @@
- from .hstu import HSTU
-
- __all__ = [
-     "HSTU",
- ]
@@ -1,13 +0,0 @@
- from .dssm import DSSM
- from .dssm_v2 import DSSM_v2
- from .youtube_dnn import YoutubeDNN
- from .mind import MIND
- from .sdm import SDM
-
- __all__ = [
-     'DSSM',
-     'DSSM_v2',
-     'YoutubeDNN',
-     'MIND',
-     'SDM',
- ]
@@ -46,7 +46,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
  from nextrec.basic.activation import activation_layer
  from nextrec.basic.model import BaseModel
- from nextrec.utils.common import merge_features
+ from nextrec.utils.model import merge_features


  class POSOGate(nn.Module):
@@ -1,27 +0,0 @@
- from .fm import FM
- from .afm import AFM
- from .masknet import MaskNet
- from .pnn import PNN
- from .deepfm import DeepFM
- from .autoint import AutoInt
- from .widedeep import WideDeep
- from .xdeepfm import xDeepFM
- from .dcn import DCN
- from .fibinet import FiBiNET
- from .din import DIN
- from .dien import DIEN
-
- __all__ = [
-     'DeepFM',
-     'AutoInt',
-     'WideDeep',
-     'xDeepFM',
-     'DCN',
-     'DIN',
-     'DIEN',
-     'FM',
-     'AFM',
-     'MaskNet',
-     'PNN',
-     'FiBiNET',
- ]
nextrec/utils/__init__.py CHANGED
@@ -1,18 +1,68 @@
+ """
+ Utilities package for NextRec
+
+ This package provides various utility functions organized by category:
+ - optimizer: Optimizer and scheduler utilities
+ - initializer: Weight initialization utilities
+ - embedding: Embedding dimension calculation
+ - device_utils: Device management and selection
+ - tensor_utils: Tensor operations and conversions
+ - file_utils: File I/O operations
+ - model_utils: Model-related utilities
+ - feature_utils: Feature processing utilities
+
+ Date: create on 13/11/2025
+ Last update: 03/12/2025 (refactored)
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
  from .optimizer import get_optimizer, get_scheduler
  from .initializer import get_initializer
  from .embedding import get_auto_embedding_dim
- from .common import resolve_device, to_tensor
- from . import optimizer, initializer, embedding, common
+ from .device import resolve_device, get_device_info
+ from .tensor import to_tensor, stack_tensors, concat_tensors, pad_sequence_tensors
+ from .file import resolve_file_paths, read_table, load_dataframes, iter_file_chunks, default_output_dir
+ from .model import merge_features, get_mlp_output_dim
+ from .feature import normalize_to_list
+ from . import optimizer, initializer, embedding

  __all__ = [
+     # Optimizer & Scheduler
      'get_optimizer',
      'get_scheduler',
+
+     # Initializer
      'get_initializer',
+
+     # Embedding
      'get_auto_embedding_dim',
+
+     # Device utilities
      'resolve_device',
+     'get_device_info',
+
+     # Tensor utilities
      'to_tensor',
+     'stack_tensors',
+     'concat_tensors',
+     'pad_sequence_tensors',
+
+     # File utilities
+     'resolve_file_paths',
+     'read_table',
+     'load_dataframes',
+     'iter_file_chunks',
+     'default_output_dir',
+
+     # Model utilities
+     'merge_features',
+     'get_mlp_output_dim',
+
+     # Feature utilities
+     'normalize_to_list',
+
+     # Module exports
      'optimizer',
      'initializer',
      'embedding',
-     'common',
  ]
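
For downstream code, the flat nextrec.utils surface is kept while the implementations move into submodules. A hypothetical consumer snippet, assuming the 0.3.6 layout shown above:

    from nextrec.utils import resolve_device, to_tensor            # unchanged call sites
    from nextrec.utils.device import get_device_info               # new in 0.3.6
    from nextrec.utils.file import resolve_file_paths, read_table  # moved out of nextrec.data.data_utils

Note that `from nextrec.utils import common` is gone; judging from the hunks above, `common` appears to have been split into `device`, `tensor`, `file`, `model`, and `feature`.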
nextrec/utils/device.py ADDED
@@ -0,0 +1,38 @@
+ """
+ Device management utilities for NextRec
+
+ Date: create on 03/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+ import os
+ import torch
+ import platform
+ import multiprocessing
+
+
+ def resolve_device() -> str:
+     if torch.cuda.is_available():
+         return "cuda"
+     if torch.backends.mps.is_available():
+         mac_ver = platform.mac_ver()[0]
+         try:
+             major, minor = (int(x) for x in mac_ver.split(".")[:2])
+         except Exception:
+             major, minor = 0, 0
+         if major >= 14:
+             return "mps"
+     return "cpu"
+
+ def get_device_info() -> dict:
+     info = {
+         'cuda_available': torch.cuda.is_available(),
+         'cuda_device_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
+         'mps_available': torch.backends.mps.is_available(),
+         'current_device': resolve_device(),
+     }
+
+     if torch.cuda.is_available():
+         info['cuda_device_name'] = torch.cuda.get_device_name(0)
+         info['cuda_capability'] = torch.cuda.get_device_capability(0)
+
+     return info
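
A short usage sketch (assumed behavior, per the definitions above): resolve_device prefers CUDA, then MPS (only on macOS 14 or later), and falls back to CPU; get_device_info summarizes the same checks.

    from nextrec.utils.device import resolve_device, get_device_info

    device = resolve_device()  # "cuda", "mps" (macOS 14+), or "cpu"
    info = get_device_info()   # e.g. {'cuda_available': False, 'mps_available': True, ...}
    print(device, info['current_device'])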
nextrec/utils/feature.py ADDED
@@ -0,0 +1,13 @@
+ """
+ Feature processing utilities for NextRec
+
+ Date: create on 03/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ def normalize_to_list(value: str | list[str] | None) -> list[str]:
+     if value is None:
+         return []
+     if isinstance(value, str):
+         return [value]
+     return list(value)
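
Per the definition above, the new helper normalizes the common `str | list[str] | None` argument shape into a plain list:

    from nextrec.utils.feature import normalize_to_list

    normalize_to_list(None)              # -> []
    normalize_to_list("user_id")         # -> ["user_id"]
    normalize_to_list(["uid", "iid"])    # -> ["uid", "iid"]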