nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +220 -106
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1082 -400
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +272 -95
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +53 -37
  29. nextrec/models/multi_task/mmoe.py +64 -45
  30. nextrec/models/multi_task/ple.py +101 -48
  31. nextrec/models/multi_task/poso.py +113 -36
  32. nextrec/models/multi_task/share_bottom.py +48 -35
  33. nextrec/models/ranking/afm.py +72 -37
  34. nextrec/models/ranking/autoint.py +72 -55
  35. nextrec/models/ranking/dcn.py +55 -35
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +32 -22
  38. nextrec/models/ranking/dien.py +155 -99
  39. nextrec/models/ranking/din.py +85 -57
  40. nextrec/models/ranking/fibinet.py +52 -32
  41. nextrec/models/ranking/fm.py +29 -23
  42. nextrec/models/ranking/masknet.py +91 -29
  43. nextrec/models/ranking/pnn.py +31 -28
  44. nextrec/models/ranking/widedeep.py +34 -26
  45. nextrec/models/ranking/xdeepfm.py +60 -38
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +283 -165
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.4.1.dist-info/RECORD +0 -66
  61. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,490 @@
1
+ """
2
+ Configuration utilities for NextRec
3
+
4
+ This module provides utilities for loading and processing configuration files,
5
+ including feature configuration, model configuration, and training configuration.
6
+
7
+ Date: create on 06/12/2025
8
+ Author: Yang Zhou, zyaztec@gmail.com
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import importlib
14
+ import importlib.util
15
+ import inspect
16
+ from copy import deepcopy
17
+ from pathlib import Path
18
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
19
+
20
+ from nextrec.utils.feature import normalize_to_list
21
+
22
+ if TYPE_CHECKING:
23
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
24
+ from nextrec.data.preprocessor import DataProcessor
25
+
26
+
27
+ def resolve_path(path_str: str | Path, base_dir: Path) -> Path:
28
+ path = Path(path_str).expanduser()
29
+ if path.is_absolute():
30
+ return path
31
+ if path.exists():
32
+ return path.resolve()
33
+ return (base_dir / path).resolve()
34
+
35
+
36
def select_features(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[List[str], List[str], List[str]]:
    """Split configured feature names into (dense, sparse, sequence) groups,
    keeping only the names that actually exist in *df_columns*.

    Names configured but absent from the data are reported via print and
    dropped; config order is preserved for the kept names.
    """
    available = set(df_columns)

    def _present(group: str) -> List[str]:
        # A missing or null section behaves like an empty mapping.
        entries = list((feature_cfg.get(group, {}) or {}).keys())
        kept = [name for name in entries if name in available]
        missing = [name for name in entries if name not in available]
        if missing:
            print(f"[feature_config] skipped missing {group} columns: {missing}")
        return kept

    return _present("dense"), _present("sparse"), _present("sequence")
53
+
54
+
55
def register_processor_features(
    processor: "DataProcessor",
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> None:
    """
    Register features on a DataProcessor according to the feature configuration.

    Args:
        processor: DataProcessor instance to mutate.
        feature_cfg: Feature configuration dictionary.
        dense_names: Dense feature names to register.
        sparse_names: Sparse feature names to register.
        sequence_names: Sequence feature names to register.
    """

    def _proc_cfg(section: str, name: str) -> Dict[str, Any]:
        # Null/absent sections or entries degrade to empty dicts so .get()
        # defaults below always apply.
        section_cfg = feature_cfg.get(section, {}) or {}
        return section_cfg.get(name, {}).get("processor_config", {}) or {}

    for name in dense_names:
        cfg = _proc_cfg("dense", name)
        processor.add_numeric_feature(
            name,
            scaler=cfg.get("scaler", "standard"),
            fill_na=cfg.get("fill_na"),
        )

    for name in sparse_names:
        cfg = _proc_cfg("sparse", name)
        processor.add_sparse_feature(
            name,
            encode_method=cfg.get("encode_method", "hash"),
            # hash_size wins over the legacy vocab_size spelling.
            hash_size=cfg.get("hash_size") or cfg.get("vocab_size"),
            fill_na=cfg.get("fill_na", "<UNK>"),
        )

    for name in sequence_names:
        cfg = _proc_cfg("sequence", name)
        processor.add_sequence_feature(
            name,
            encode_method=cfg.get("encode_method", "hash"),
            hash_size=cfg.get("hash_size") or cfg.get("vocab_size"),
            max_len=cfg.get("max_len", 50),
            pad_value=cfg.get("pad_value", 0),
            truncate=cfg.get("truncate", "post"),
            separator=cfg.get("separator", ","),
        )
104
+
105
+
106
def build_feature_objects(
    processor: "DataProcessor",
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> Tuple[List["DenseFeature"], List["SparseFeature"], List["SequenceFeature"]]:
    """
    Materialize feature objects from a fitted processor and the feature config.

    Args:
        processor: Fitted DataProcessor instance (supplies learned vocab sizes).
        feature_cfg: Feature configuration dictionary.
        dense_names: List of dense feature names.
        sparse_names: List of sparse feature names.
        sequence_names: List of sequence feature names.
    """
    from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature

    dense_cfg = feature_cfg.get("dense", {}) or {}
    sparse_cfg = feature_cfg.get("sparse", {}) or {}
    sequence_cfg = feature_cfg.get("sequence", {}) or {}
    vocab_sizes = processor.get_vocab_sizes()

    def _entry(cfg: Dict[str, Any], name: str) -> Dict[str, Any]:
        return cfg.get(name, {}) or {}

    def _vocab_size(entry: Dict[str, Any], name: str) -> int:
        # Priority: explicit embedding vocab_size > processor hash_size >
        # vocab learned during fitting > 1 as a last-resort floor.
        proc_cfg = entry.get("processor_config", {}) or {}
        embed_cfg = entry.get("embedding_config", {}) or {}
        return int(
            embed_cfg.get("vocab_size")
            or proc_cfg.get("hash_size")
            or vocab_sizes.get(name, 0)
            or 1
        )

    dense_features: List[DenseFeature] = []
    for name in dense_names:
        embed_cfg = _entry(dense_cfg, name).get("embedding_config", {}) or {}
        dense_features.append(
            DenseFeature(
                name=name,
                embedding_dim=embed_cfg.get("embedding_dim"),
                input_dim=embed_cfg.get("input_dim", 1),
                use_embedding=embed_cfg.get("use_embedding", False),
            )
        )

    sparse_features: List[SparseFeature] = []
    for name in sparse_names:
        entry = _entry(sparse_cfg, name)
        embed_cfg = entry.get("embedding_config", {}) or {}
        sparse_features.append(
            SparseFeature(
                name=name,
                vocab_size=_vocab_size(entry, name),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    sequence_features: List[SequenceFeature] = []
    for name in sequence_names:
        entry = _entry(sequence_cfg, name)
        proc_cfg = entry.get("processor_config", {}) or {}
        embed_cfg = entry.get("embedding_config", {}) or {}
        sequence_features.append(
            SequenceFeature(
                name=name,
                vocab_size=_vocab_size(entry, name),
                # max_len may live in either section; embedding config wins.
                max_len=embed_cfg.get("max_len") or proc_cfg.get("max_len", 50),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                combiner=embed_cfg.get("combiner", "mean"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    return dense_features, sparse_features, sequence_features
191
+
192
+
193
def extract_feature_groups(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[Dict[str, List[str]], List[str]]:
    """
    Resolve `feature_groups` from the config against defined features and
    available dataframe columns.

    Args:
        feature_cfg: Feature configuration dictionary.
        df_columns: Available dataframe columns.

    Returns:
        A (group_name -> resolved column list) mapping plus a flat list of
        every resolved column across all groups (duplicates across groups kept).
    """
    groups_cfg = feature_cfg.get("feature_groups") or {}
    if not groups_cfg:
        return {}, []

    defined: set = set()
    for section in ("dense", "sparse", "sequence"):
        defined |= set((feature_cfg.get(section) or {}).keys())
    available_cols = set(df_columns)

    resolved: Dict[str, List[str]] = {}
    collected: List[str] = []

    for group_name, names in groups_cfg.items():
        name_list = normalize_to_list(names)
        missing_defined = [n for n in name_list if n not in defined]
        missing_cols = [n for n in name_list if n not in available_cols]

        if missing_defined:
            print(
                f"[feature_config] feature_groups.{group_name} contains features not defined in dense/sparse/sequence: {missing_defined}"
            )

        # Keep the first occurrence of each name that exists in the data.
        filtered: List[str] = []
        for n in name_list:
            if n in available_cols and n not in filtered:
                filtered.append(n)

        if missing_cols:
            print(
                f"[feature_config] feature_groups.{group_name} missing data columns: {missing_cols}"
            )

        resolved[group_name] = filtered
        collected.extend(filtered)

    return resolved, collected
244
+
245
+
246
def load_model_class(model_cfg: Dict[str, Any], base_dir: Path) -> type:
    """
    Load a model class from configuration.

    Three resolution strategies, tried in order:
      1. ``module_path`` — load a custom .py file from disk.
      2. ``model``/``name`` (without ``module``) — search the builtin
         ``nextrec.models.*`` packages by short name.
      3. ``module`` (or ``module_name``) + ``class_name`` — import explicitly.

    Args:
        model_cfg: Model configuration dictionary
        base_dir: Base directory for resolving relative paths

    Returns:
        The resolved model class object.

    Raises:
        FileNotFoundError: if ``module_path`` does not exist.
        ImportError: if the file or builtin module cannot be loaded.
        AttributeError: if the requested class is not found in the module.
        ValueError: if the config provides none of the three strategies.
    """

    def camelize(name: str) -> str:
        """Convert snake_case or kebab-case to CamelCase."""
        # NOTE(review): capitalize() lowercases the rest of each part, so
        # "deepfm" -> "Deepfm" (not "DeepFM"); the BaseModel-subclass
        # fallback below covers such mismatches.
        return "".join(
            part.capitalize()
            for part in name.replace("_", " ").replace("-", " ").split()
        )

    module_path = model_cfg.get("module_path")
    name = model_cfg.get("model") or model_cfg.get("name")
    module_name = model_cfg.get("module") or model_cfg.get("module_name")
    class_name = model_cfg.get("class_name")

    # Case 1: Custom file path
    if module_path:
        resolved = resolve_path(module_path, base_dir)
        if not resolved.exists():
            raise FileNotFoundError(f"Custom model file not found: {resolved}")

        spec = importlib.util.spec_from_file_location(resolved.stem, resolved)
        if spec is None or spec.loader is None:
            raise ImportError(f"Unable to load custom model file: {resolved}")

        # Executes the user's file top-to-bottom (arbitrary code runs here).
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        if class_name and hasattr(module, class_name):
            return getattr(module, class_name)

        # Auto-pick first BaseModel subclass
        from nextrec.basic.model import BaseModel

        # "First" is module definition order (dict preserves insertion order).
        for attr in module.__dict__.values():
            if (
                isinstance(attr, type)
                and issubclass(attr, BaseModel)
                and attr is not BaseModel
            ):
                return attr

        raise AttributeError(
            f"No BaseModel subclass found in {resolved}, please provide class_name"
        )

    # Case 2: Builtin model by short name
    if name and not module_name:
        from nextrec.basic.model import BaseModel

        # Candidate modules searched in priority order.
        candidates = [
            f"nextrec.models.{name.lower()}",
            f"nextrec.models.ranking.{name.lower()}",
            f"nextrec.models.match.{name.lower()}",
            f"nextrec.models.multi_task.{name.lower()}",
            f"nextrec.models.generative.{name.lower()}",
        ]
        errors = []

        for mod in candidates:
            try:
                module = importlib.import_module(mod)
                cls_name = class_name or camelize(name)

                if hasattr(module, cls_name):
                    return getattr(module, cls_name)

                # Fallback: first BaseModel subclass
                for attr in module.__dict__.values():
                    if (
                        isinstance(attr, type)
                        and issubclass(attr, BaseModel)
                        and attr is not BaseModel
                    ):
                        return attr

                errors.append(f"{mod} missing class {cls_name}")
            except Exception as exc:
                # Collect every failure so the final error shows all attempts.
                errors.append(f"{mod}: {exc}")

        raise ImportError(f"Unable to find model for model='{name}'. Tried: {errors}")

    # Case 3: Explicit module + class
    if module_name and class_name:
        module = importlib.import_module(module_name)
        if not hasattr(module, class_name):
            raise AttributeError(f"Class {class_name} not found in {module_name}")
        return getattr(module, class_name)

    raise ValueError(
        "model configuration must provide 'model' (builtin name), 'module_path' (custom path), or 'module'+'class_name'"
    )
344
+
345
+
346
def build_model_instance(
    model_cfg: Dict[str, Any],
    model_cfg_path: Path,
    dense_features: List[DenseFeature],
    sparse_features: List[SparseFeature],
    sequence_features: List[SequenceFeature],
    target: List[str],
    device: str,
) -> Any:
    """
    Build a model instance from configuration and feature objects.

    Constructor kwargs are assembled in priority order: explicit ``params``
    from the config, then ``feature_bindings``, then feature_groups whose key
    matches an ``__init__`` parameter (exactly, then by normalized name), and
    finally default wiring of dense/sparse/sequence features, target, device,
    and session_id — earlier sources are never overwritten by later ones.

    Args:
        model_cfg: Model configuration dictionary
        model_cfg_path: Path to model config file (for resolving relative paths)
        dense_features: List of dense feature objects
        sparse_features: List of sparse feature objects
        sequence_features: List of sequence feature objects
        target: List of target column names
        device: Device string (e.g., 'cpu', 'cuda:0')

    Returns:
        The instantiated model (``model_cls(**init_kwargs)``).

    Raises:
        ValueError: if a binding/group references a feature name not in the pool.
    """
    # Later maps win on name collisions: sequence > sparse > dense.
    dense_map = {f.name: f for f in dense_features}
    sparse_map = {f.name: f for f in sparse_features}
    sequence_map = {f.name: f for f in sequence_features}
    feature_pool: Dict[str, Any] = {**dense_map, **sparse_map, **sequence_map}

    model_cls = load_model_class(model_cfg, model_cfg_path.parent)
    # deepcopy so the pops below never mutate the caller's config.
    params_cfg = deepcopy(model_cfg.get("params") or {})
    feature_groups = params_cfg.pop("feature_groups", {}) or {}
    feature_bindings_cfg = (
        model_cfg.get("feature_bindings")
        or params_cfg.pop("feature_bindings", {})
        or {}
    )
    sig_params = inspect.signature(model_cls.__init__).parameters

    def _select(names: List[str] | None, pool: Dict[str, Any], desc: str) -> List[Any]:
        """Select features from pool by names."""
        # NOTE(review): lookups below use the closed-over feature_pool, not
        # the `pool` parameter; `pool` only matters for the names=None branch.
        # Every current call site passes feature_pool, so behavior matches.
        if names is None:
            return list(pool.values())
        missing = [n for n in names if n not in feature_pool]
        if missing:
            raise ValueError(
                f"feature_groups.{desc} contains unknown features: {missing}"
            )
        return [feature_pool[n] for n in names]

    def accepts(name: str) -> bool:
        """Check if parameter name is accepted by model __init__."""
        return name in sig_params

    # With **kwargs present, any binding name can be forwarded.
    accepts_var_kwargs = any(
        param.kind == inspect.Parameter.VAR_KEYWORD for param in sig_params.values()
    )

    init_kwargs: Dict[str, Any] = dict(params_cfg)

    # Explicit bindings (model_config.feature_bindings) take priority
    for param_name, binding in feature_bindings_cfg.items():
        # params config always beats a binding for the same name.
        if param_name in init_kwargs:
            continue

        # Binding forms: list/tuple/set of feature names, a dict with
        # "features"/"feature_names" or "group"/"group_key", or a bare
        # group-name string.
        if isinstance(binding, (list, tuple, set)):
            if accepts(param_name) or accepts_var_kwargs:
                init_kwargs[param_name] = _select(
                    list(binding), feature_pool, f"feature_bindings.{param_name}"
                )
            continue

        if isinstance(binding, dict):
            direct_features = binding.get("features") or binding.get("feature_names")
            if direct_features and (accepts(param_name) or accepts_var_kwargs):
                init_kwargs[param_name] = _select(
                    normalize_to_list(direct_features),
                    feature_pool,
                    f"feature_bindings.{param_name}",
                )
                continue
            group_key = binding.get("group") or binding.get("group_key")
        else:
            group_key = binding

        if group_key not in feature_groups:
            print(
                f"[feature_config] feature_bindings refers to unknown group '{group_key}', skipped"
            )
            continue

        if accepts(param_name) or accepts_var_kwargs:
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Dynamic feature groups: any key in feature_groups that matches __init__ will be filled
    for group_key, names in feature_groups.items():
        if accepts(str(group_key)):
            init_kwargs.setdefault(
                str(group_key), _select(names, feature_pool, str(group_key))
            )

    # Generalized mapping: match params to feature_groups by normalized names
    def _normalize_group_key(key: str) -> str:
        """Normalize group key by removing common suffixes."""
        key = key.lower()
        # Only one suffix is stripped: the loop mutates `key` but each test
        # runs against the already-stripped value, in this fixed order.
        for suffix in ("_features", "_feature", "_feats", "_feat", "_list", "_group"):
            if key.endswith(suffix):
                key = key[: -len(suffix)]
        return key

    # First group to claim a normalized key wins (setdefault).
    normalized_groups = {}
    for gk in feature_groups:
        norm = _normalize_group_key(gk)
        normalized_groups.setdefault(norm, gk)

    for param_name in sig_params:
        if param_name in ("self",) or param_name in init_kwargs:
            continue
        norm_param = _normalize_group_key(param_name)
        if norm_param in normalized_groups and (
            accepts(param_name) or accepts_var_kwargs
        ):
            group_key = normalized_groups[norm_param]
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Feature wiring: prefer explicit groups when provided
    if accepts("dense_features"):
        init_kwargs.setdefault("dense_features", dense_features)
    if accepts("sparse_features"):
        init_kwargs.setdefault("sparse_features", sparse_features)
    if accepts("sequence_features"):
        init_kwargs.setdefault("sequence_features", sequence_features)

    if accepts("target"):
        init_kwargs.setdefault("target", target)
    if accepts("device"):
        init_kwargs.setdefault("device", device)

    # Pass session_id if model accepts it
    if "session_id" not in init_kwargs and model_cfg.get("session_id") is not None:
        if accepts("session_id") or accepts_var_kwargs:
            init_kwargs["session_id"] = model_cfg.get("session_id")

    return model_cls(**init_kwargs)
nextrec/utils/device.py CHANGED
@@ -2,13 +2,13 @@
2
2
  Device management utilities for NextRec
3
3
 
4
4
  Date: create on 03/12/2025
5
+ Checkpoint: edit on 06/12/2025
5
6
  Author: Yang Zhou, zyaztec@gmail.com
6
7
  """
7
- import os
8
+
8
9
  import torch
9
10
  import platform
10
11
  import logging
11
- import multiprocessing
12
12
 
13
13
 
14
14
  def resolve_device() -> str:
@@ -17,52 +17,62 @@ def resolve_device() -> str:
17
17
  if torch.backends.mps.is_available():
18
18
  mac_ver = platform.mac_ver()[0]
19
19
  try:
20
- major, minor = (int(x) for x in mac_ver.split(".")[:2])
20
+ major, _ = (int(x) for x in mac_ver.split(".")[:2])
21
21
  except Exception:
22
- major, minor = 0, 0
22
+ major, _ = 0, 0
23
23
  if major >= 14:
24
24
  return "mps"
25
25
  return "cpu"
26
26
 
27
+
27
28
def get_device_info() -> dict:
    """Report accelerator availability and the device resolve_device() selects.

    Always includes cuda/mps availability flags, the CUDA device count, and
    the resolved current device; adds device name and compute capability for
    GPU 0 when CUDA is present.
    """
    cuda_ok = torch.cuda.is_available()
    info = {
        "cuda_available": cuda_ok,
        "cuda_device_count": torch.cuda.device_count() if cuda_ok else 0,
        "mps_available": torch.backends.mps.is_available(),
        "current_device": resolve_device(),
    }
    # Detail fields only make sense when at least one CUDA device exists.
    if cuda_ok:
        info["cuda_device_name"] = torch.cuda.get_device_name(0)
        info["cuda_capability"] = torch.cuda.get_device_capability(0)
    return info
40
43
 
44
+
41
45
  def configure_device(
42
- distributed: bool,
43
- local_rank: int,
44
- base_device: torch.device | str = "cpu"
46
+ distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
45
47
  ) -> torch.device:
46
48
  try:
47
49
  device = torch.device(base_device)
48
50
  except Exception:
49
- logging.warning("[configure_device Warning] Invalid base_device, falling back to CPU.")
51
+ logging.warning(
52
+ "[configure_device Warning] Invalid base_device, falling back to CPU."
53
+ )
50
54
  return torch.device("cpu")
51
55
 
52
56
  if distributed:
53
57
  if device.type == "cuda":
54
58
  if not torch.cuda.is_available():
55
- logging.warning("[Distributed Warning] CUDA requested but unavailable. Falling back to CPU.")
59
+ logging.warning(
60
+ "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
61
+ )
56
62
  return torch.device("cpu")
57
63
  if not (0 <= local_rank < torch.cuda.device_count()):
58
- logging.warning(f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU.")
64
+ logging.warning(
65
+ f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
66
+ )
59
67
  return torch.device("cpu")
60
68
  try:
61
69
  torch.cuda.set_device(local_rank)
62
70
  return torch.device(f"cuda:{local_rank}")
63
71
  except Exception as exc:
64
- logging.warning(f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU.")
72
+ logging.warning(
73
+ f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
74
+ )
65
75
  return torch.device("cpu")
66
76
  else:
67
77
  return torch.device("cpu")
68
- return device
78
+ return device
@@ -15,10 +15,13 @@ from torch.utils.data import DataLoader, IterableDataset
15
15
  from torch.utils.data.distributed import DistributedSampler
16
16
  from nextrec.basic.loggers import colorize
17
17
 
18
- def init_process_group(distributed: bool, rank: int, world_size: int, device_id: int | None = None) -> None:
18
+
19
+ def init_process_group(
20
+ distributed: bool, rank: int, world_size: int, device_id: int | None = None
21
+ ) -> None:
19
22
  """
20
23
  initialize distributed process group for multi-GPU training.
21
-
24
+
22
25
  Args:
23
26
  distributed: whether to enable distributed training
24
27
  rank: global rank of the current process
@@ -29,7 +32,10 @@ def init_process_group(distributed: bool, rank: int, world_size: int, device_id:
29
32
  backend = "nccl" if device_id is not None else "gloo"
30
33
  if backend == "nccl":
31
34
  torch.cuda.set_device(device_id)
32
- dist.init_process_group(backend=backend, init_method="env://", rank=rank, world_size=world_size)
35
+ dist.init_process_group(
36
+ backend=backend, init_method="env://", rank=rank, world_size=world_size
37
+ )
38
+
33
39
 
34
40
  def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
35
41
  """
@@ -53,6 +59,7 @@ def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
53
59
  return None
54
60
  return np.concatenate(pieces, axis=0)
55
61
 
62
+
56
63
  def add_distributed_sampler(
57
64
  loader: DataLoader,
58
65
  distributed: bool,
@@ -64,7 +71,7 @@ def add_distributed_sampler(
64
71
  is_main_process: bool = False,
65
72
  ) -> tuple[DataLoader, DistributedSampler | None]:
66
73
  """
67
- add distributedsampler to a dataloader, this for distributed training
74
+ add distributedsampler to a dataloader, this for distributed training
68
75
  when each device has its own dataloader
69
76
  """
70
77
  # early return if not distributed
@@ -78,11 +85,24 @@ def add_distributed_sampler(
78
85
  return loader, None
79
86
  if isinstance(dataset, IterableDataset):
80
87
  if is_main_process:
81
- logging.info(colorize("[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.", color="yellow"))
88
+ logging.info(
89
+ colorize(
90
+ "[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.",
91
+ color="yellow",
92
+ )
93
+ )
82
94
  return loader, None
83
- sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last)
95
+ sampler = DistributedSampler(
96
+ dataset,
97
+ num_replicas=world_size,
98
+ rank=rank,
99
+ shuffle=shuffle,
100
+ drop_last=drop_last,
101
+ )
84
102
  loader_kwargs = {
85
- "batch_size": loader.batch_size if loader.batch_size is not None else default_batch_size,
103
+ "batch_size": (
104
+ loader.batch_size if loader.batch_size is not None else default_batch_size
105
+ ),
86
106
  "shuffle": False,
87
107
  "sampler": sampler,
88
108
  "num_workers": loader.num_workers,
@@ -104,11 +124,18 @@ def add_distributed_sampler(
104
124
  if generator is not None:
105
125
  loader_kwargs["generator"] = generator
106
126
  if loader.num_workers > 0:
107
- loader_kwargs["persistent_workers"] = getattr(loader, "persistent_workers", False)
127
+ loader_kwargs["persistent_workers"] = getattr(
128
+ loader, "persistent_workers", False
129
+ )
108
130
  prefetch_factor = getattr(loader, "prefetch_factor", None)
109
131
  if prefetch_factor is not None:
110
132
  loader_kwargs["prefetch_factor"] = prefetch_factor
111
133
  distributed_loader = DataLoader(dataset, **loader_kwargs)
112
134
  if is_main_process:
113
- logging.info(colorize("[Distributed Info] Attached DistributedSampler to provided DataLoader", color="cyan"))
135
+ logging.info(
136
+ colorize(
137
+ "[Distributed Info] Attached DistributedSampler to provided DataLoader",
138
+ color="cyan",
139
+ )
140
+ )
114
141
  return distributed_loader, sampler
@@ -2,6 +2,7 @@
2
2
  Embedding utilities for NextRec
3
3
 
4
4
  Date: create on 13/11/2025
5
+ Checkpoint: edit on 06/12/2025
5
6
  Author: Yang Zhou, zyaztec@gmail.com
6
7
  """
7
8
 
nextrec/utils/feature.py CHANGED
@@ -5,6 +5,7 @@ Date: create on 03/12/2025
5
5
  Author: Yang Zhou, zyaztec@gmail.com
6
6
  """
7
7
 
8
+
8
9
  def normalize_to_list(value: str | list[str] | None) -> list[str]:
9
10
  if value is None:
10
11
  return []