nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,496 @@
1
+ """
2
+ Configuration utilities for NextRec
3
+
4
+ This module provides utilities for loading and processing configuration files,
5
+ including feature configuration, model configuration, and training configuration.
6
+
7
+ Date: create on 06/12/2025
8
+ Author: Yang Zhou, zyaztec@gmail.com
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import importlib
14
+ import importlib.util
15
+ import inspect
16
+ from copy import deepcopy
17
+ from pathlib import Path
18
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
19
+
20
+ from nextrec.utils.feature import normalize_to_list
21
+
22
+ if TYPE_CHECKING:
23
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
24
+ from nextrec.data.preprocessor import DataProcessor
25
+
26
+
27
def resolve_path(path_str: str | Path, base_dir: Path) -> Path:
    """Resolve a possibly-relative path, preferring the current working directory.

    Absolute paths are returned unchanged (after ``~`` expansion). A relative
    path is anchored at the CWD when the path or its parent already exists
    there; otherwise the config file's directory (*base_dir*) is tried; if
    neither location exists yet, the CWD-relative resolution is returned.
    """
    candidate = Path(path_str).expanduser()
    if candidate.is_absolute():
        return candidate

    # CWD takes priority so users can run from anywhere and still point at
    # local data; "parent exists" accepts not-yet-created output paths.
    from_cwd = (Path.cwd() / candidate).resolve()
    if from_cwd.exists() or from_cwd.parent.exists():
        return from_cwd

    from_base = (base_dir / candidate).resolve()
    if from_base.exists() or from_base.parent.exists():
        return from_base

    # Nothing exists in either location; default to the CWD interpretation.
    return from_cwd
40
+
41
+
42
def select_features(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[List[str], List[str], List[str]]:
    """Split configured feature names into (dense, sparse, sequence) lists.

    Only names present in *df_columns* are kept; absent names are reported
    via a console warning and dropped. Config order is preserved.
    """
    available = set(df_columns)

    def _existing(group: str) -> List[str]:
        # `or {}` guards against an explicit null entry in the config file.
        group_cfg = feature_cfg.get(group, {}) or {}
        kept: List[str] = []
        missing: List[str] = []
        for feature_name in group_cfg:
            (kept if feature_name in available else missing).append(feature_name)
        if missing:
            print(f"[feature_config] skipped missing {group} columns: {missing}")
        return kept

    return _existing("dense"), _existing("sparse"), _existing("sequence")
59
+
60
+
61
def register_processor_features(
    processor: DataProcessor,
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> None:
    """
    Register features to DataProcessor based on feature configuration.

    Args:
        processor: DataProcessor instance
        feature_cfg: Feature configuration dictionary
        dense_names: List of dense feature names
        sparse_names: List of sparse feature names
        sequence_names: List of sequence feature names
    """
    dense_cfg = feature_cfg.get("dense", {}) or {}
    sparse_cfg = feature_cfg.get("sparse", {}) or {}
    sequence_cfg = feature_cfg.get("sequence", {}) or {}

    def _proc_cfg(section: Dict[str, Any], name: str) -> Dict[str, Any]:
        # Fix: guard the feature entry itself with `or {}` — previously an
        # explicit null entry (e.g. `my_feature:` in YAML) raised
        # AttributeError here, while build_feature_objects handled it.
        entry = section.get(name, {}) or {}
        return entry.get("processor_config", {}) or {}

    for name in dense_names:
        proc_cfg = _proc_cfg(dense_cfg, name)
        processor.add_numeric_feature(
            name,
            scaler=proc_cfg.get("scaler", "standard"),
            fill_na=proc_cfg.get("fill_na"),
        )

    for name in sparse_names:
        proc_cfg = _proc_cfg(sparse_cfg, name)
        processor.add_sparse_feature(
            name,
            encode_method=proc_cfg.get("encode_method", "hash"),
            # `hash_size` wins; `vocab_size` is accepted as a config alias.
            hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
            fill_na=proc_cfg.get("fill_na", "<UNK>"),
        )

    for name in sequence_names:
        proc_cfg = _proc_cfg(sequence_cfg, name)
        processor.add_sequence_feature(
            name,
            encode_method=proc_cfg.get("encode_method", "hash"),
            hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
            max_len=proc_cfg.get("max_len", 50),
            pad_value=proc_cfg.get("pad_value", 0),
            truncate=proc_cfg.get("truncate", "post"),
            separator=proc_cfg.get("separator", ","),
        )
110
+
111
+
112
def build_feature_objects(
    processor: "DataProcessor",
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> Tuple[List["DenseFeature"], List["SparseFeature"], List["SequenceFeature"]]:
    """
    Build feature objects from processor and feature configuration.

    Args:
        processor: Fitted DataProcessor instance
        feature_cfg: Feature configuration dictionary
        dense_names: List of dense feature names
        sparse_names: List of sparse feature names
        sequence_names: List of sequence feature names

    Returns:
        Tuple of (dense_features, sparse_features, sequence_features).
    """
    from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature

    dense_cfg = feature_cfg.get("dense", {}) or {}
    sparse_cfg = feature_cfg.get("sparse", {}) or {}
    sequence_cfg = feature_cfg.get("sequence", {}) or {}
    vocab_sizes = processor.get_vocab_sizes()

    def _entry_cfgs(
        section: Dict[str, Any], name: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # Fix: the dense branch previously did `dense_cfg.get(name, {}).get(...)`
        # without guarding the entry itself, crashing on an explicit null entry;
        # sparse/sequence already guarded. Unified here.
        entry = section.get(name, {}) or {}
        proc_cfg = entry.get("processor_config", {}) or {}
        embed_cfg = entry.get("embedding_config", {}) or {}
        return proc_cfg, embed_cfg

    def _vocab_size(
        embed_cfg: Dict[str, Any], proc_cfg: Dict[str, Any], name: str
    ) -> int:
        # Priority: explicit embedding vocab > processor hash size >
        # fitted vocabulary size > 1 (never zero, keeps embeddings valid).
        return int(
            embed_cfg.get("vocab_size")
            or proc_cfg.get("hash_size")
            or vocab_sizes.get(name, 0)
            or 1
        )

    dense_features: List[DenseFeature] = []
    for name in dense_names:
        _, embed_cfg = _entry_cfgs(dense_cfg, name)
        dense_features.append(
            DenseFeature(
                name=name,
                embedding_dim=embed_cfg.get("embedding_dim"),
                input_dim=embed_cfg.get("input_dim", 1),
                use_embedding=embed_cfg.get("use_embedding", False),
            )
        )

    sparse_features: List[SparseFeature] = []
    for name in sparse_names:
        proc_cfg, embed_cfg = _entry_cfgs(sparse_cfg, name)
        sparse_features.append(
            SparseFeature(
                name=name,
                vocab_size=_vocab_size(embed_cfg, proc_cfg, name),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    sequence_features: List[SequenceFeature] = []
    for name in sequence_names:
        proc_cfg, embed_cfg = _entry_cfgs(sequence_cfg, name)
        sequence_features.append(
            SequenceFeature(
                name=name,
                vocab_size=_vocab_size(embed_cfg, proc_cfg, name),
                # Embedding-side max_len overrides the processor's setting.
                max_len=embed_cfg.get("max_len") or proc_cfg.get("max_len", 50),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                combiner=embed_cfg.get("combiner", "mean"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    return dense_features, sparse_features, sequence_features
197
+
198
+
199
def extract_feature_groups(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[Dict[str, List[str]], List[str]]:
    """
    Extract and validate feature groups from feature configuration.

    Args:
        feature_cfg: Feature configuration dictionary
        df_columns: Available dataframe columns

    Returns:
        A mapping of group name to the ordered, de-duplicated feature names
        that exist in *df_columns*, plus a flat list of all resolved names
        (duplicates across groups preserved).
    """
    feature_groups = feature_cfg.get("feature_groups") or {}
    if not feature_groups:
        return {}, []

    # Names declared anywhere in dense/sparse/sequence; used only for warnings.
    defined = (
        set((feature_cfg.get("dense") or {}).keys())
        | set((feature_cfg.get("sparse") or {}).keys())
        | set((feature_cfg.get("sequence") or {}).keys())
    )
    available_cols = set(df_columns)
    resolved: Dict[str, List[str]] = {}
    collected: List[str] = []

    for group_name, names in feature_groups.items():
        name_list = normalize_to_list(names)
        missing_defined = [n for n in name_list if n not in defined]
        missing_cols = [n for n in name_list if n not in available_cols]

        if missing_defined:
            print(
                f"[feature_config] feature_groups.{group_name} contains features not defined in dense/sparse/sequence: {missing_defined}"
            )

        # Keep the first occurrence of each name present in the data.
        # (The previous version also re-appended missing names to missing_cols
        # inside this loop — unreachable, since the comprehension above already
        # collected every missing name. Dead branch removed.)
        filtered: List[str] = []
        for n in name_list:
            if n in available_cols and n not in filtered:
                filtered.append(n)

        if missing_cols:
            print(
                f"[feature_config] feature_groups.{group_name} missing data columns: {missing_cols}"
            )

        resolved[group_name] = filtered
        collected.extend(filtered)

    return resolved, collected
250
+
251
+
252
def load_model_class(model_cfg: Dict[str, Any], base_dir: Path) -> type:
    """
    Load model class from configuration.

    Resolution order (first match wins):
      1. ``module_path`` — import a custom Python file from disk.
      2. ``model``/``name`` (without ``module``) — search the builtin
         ``nextrec.models`` packages by lowercased short name.
      3. ``module`` (+ ``class_name``) — import an explicit dotted module.

    Args:
        model_cfg: Model configuration dictionary
        base_dir: Base directory for resolving relative paths

    Raises:
        FileNotFoundError: ``module_path`` does not exist on disk.
        ImportError: the custom file cannot be loaded, or no builtin
            candidate module yields a usable class.
        AttributeError: an explicit ``class_name`` is missing, or a custom
            file contains no BaseModel subclass.
        ValueError: none of the three configuration forms was provided.
    """

    def camelize(name: str) -> str:
        """Convert snake_case or kebab-case to CamelCase."""
        # NOTE(review): "deepfm" camelizes to "Deepfm", not "DeepFM"; the
        # BaseModel-subclass fallback below covers such mismatches.
        return "".join(
            part.capitalize()
            for part in name.replace("_", " ").replace("-", " ").split()
        )

    module_path = model_cfg.get("module_path")
    name = model_cfg.get("model") or model_cfg.get("name")
    module_name = model_cfg.get("module") or model_cfg.get("module_name")
    class_name = model_cfg.get("class_name")

    # Case 1: Custom file path
    if module_path:
        resolved = resolve_path(module_path, base_dir)
        if not resolved.exists():
            raise FileNotFoundError(f"Custom model file not found: {resolved}")

        # Import the file as an ad-hoc module named after its stem.
        spec = importlib.util.spec_from_file_location(resolved.stem, resolved)
        if spec is None or spec.loader is None:
            raise ImportError(f"Unable to load custom model file: {resolved}")

        module = importlib.util.module_from_spec(spec)
        # Executes arbitrary code from the configured path — trusted input only.
        spec.loader.exec_module(module)

        if class_name and hasattr(module, class_name):
            return getattr(module, class_name)

        # Auto-pick first BaseModel subclass
        from nextrec.basic.model import BaseModel

        # Dict order = definition order, so "first" means first defined.
        for attr in module.__dict__.values():
            if (
                isinstance(attr, type)
                and issubclass(attr, BaseModel)
                and attr is not BaseModel
            ):
                return attr

        raise AttributeError(
            f"No BaseModel subclass found in {resolved}, please provide class_name"
        )

    # Case 2: Builtin model by short name
    if name and not module_name:
        from nextrec.basic.model import BaseModel

        # Probe each model family package in a fixed priority order.
        candidates = [
            f"nextrec.models.{name.lower()}",
            f"nextrec.models.ranking.{name.lower()}",
            f"nextrec.models.match.{name.lower()}",
            f"nextrec.models.multi_task.{name.lower()}",
            f"nextrec.models.generative.{name.lower()}",
        ]
        errors = []

        for mod in candidates:
            try:
                module = importlib.import_module(mod)
                cls_name = class_name or camelize(name)

                if hasattr(module, cls_name):
                    return getattr(module, cls_name)

                # Fallback: first BaseModel subclass
                for attr in module.__dict__.values():
                    if (
                        isinstance(attr, type)
                        and issubclass(attr, BaseModel)
                        and attr is not BaseModel
                    ):
                        return attr

                errors.append(f"{mod} missing class {cls_name}")
            except Exception as exc:
                # Collect per-candidate failures; reported together below.
                errors.append(f"{mod}: {exc}")

        raise ImportError(f"Unable to find model for model='{name}'. Tried: {errors}")

    # Case 3: Explicit module + class
    if module_name and class_name:
        module = importlib.import_module(module_name)
        if not hasattr(module, class_name):
            raise AttributeError(f"Class {class_name} not found in {module_name}")
        return getattr(module, class_name)

    raise ValueError(
        "model configuration must provide 'model' (builtin name), 'module_path' (custom path), or 'module'+'class_name'"
    )
350
+
351
+
352
def build_model_instance(
    model_cfg: Dict[str, Any],
    model_cfg_path: Path,
    dense_features: List[DenseFeature],
    sparse_features: List[SparseFeature],
    sequence_features: List[SequenceFeature],
    target: List[str],
    device: str,
) -> Any:
    """
    Build model instance from configuration and feature objects.

    Keyword arguments for the model constructor are assembled in priority
    order: explicit ``params`` > ``feature_bindings`` > feature_groups whose
    key matches a constructor parameter (exactly, then by normalized name) >
    default dense/sparse/sequence wiring > target/device/session_id.
    Earlier sources win; later stages use ``setdefault`` or skip filled keys.

    Args:
        model_cfg: Model configuration dictionary
        model_cfg_path: Path to model config file (for resolving relative paths)
        dense_features: List of dense feature objects
        sparse_features: List of sparse feature objects
        sequence_features: List of sequence feature objects
        target: List of target column names
        device: Device string (e.g., 'cpu', 'cuda:0')
    """
    # Later maps win on name collisions: sequence > sparse > dense.
    dense_map = {f.name: f for f in dense_features}
    sparse_map = {f.name: f for f in sparse_features}
    sequence_map = {f.name: f for f in sequence_features}
    feature_pool: Dict[str, Any] = {**dense_map, **sparse_map, **sequence_map}

    model_cls = load_model_class(model_cfg, model_cfg_path.parent)
    # deepcopy so the pops below never mutate the caller's config dict.
    params_cfg = deepcopy(model_cfg.get("params") or {})
    feature_groups = params_cfg.pop("feature_groups", {}) or {}
    # Bindings may live at the top level or (legacy) inside params.
    feature_bindings_cfg = (
        model_cfg.get("feature_bindings")
        or params_cfg.pop("feature_bindings", {})
        or {}
    )
    sig_params = inspect.signature(model_cls.__init__).parameters

    def _select(names: List[str] | None, pool: Dict[str, Any], desc: str) -> List[Any]:
        """Select features from pool by names."""
        # None means "take everything in the given pool".
        if names is None:
            return list(pool.values())
        missing = [n for n in names if n not in feature_pool]
        if missing:
            raise ValueError(
                f"feature_groups.{desc} contains unknown features: {missing}"
            )
        return [feature_pool[n] for n in names]

    def accepts(name: str) -> bool:
        """Check if parameter name is accepted by model __init__."""
        return name in sig_params

    # True when __init__ has a **kwargs catch-all, so any name is accepted.
    accepts_var_kwargs = any(
        param.kind == inspect.Parameter.VAR_KEYWORD for param in sig_params.values()
    )

    init_kwargs: Dict[str, Any] = dict(params_cfg)

    # Explicit bindings (model_config.feature_bindings) take priority
    for param_name, binding in feature_bindings_cfg.items():
        # Never overwrite a value given directly in params.
        if param_name in init_kwargs:
            continue

        # Binding forms: list of feature names, dict with features/group,
        # or a bare string naming a group.
        if isinstance(binding, (list, tuple, set)):
            if accepts(param_name) or accepts_var_kwargs:
                init_kwargs[param_name] = _select(
                    list(binding), feature_pool, f"feature_bindings.{param_name}"
                )
            continue

        if isinstance(binding, dict):
            direct_features = binding.get("features") or binding.get("feature_names")
            if direct_features and (accepts(param_name) or accepts_var_kwargs):
                init_kwargs[param_name] = _select(
                    normalize_to_list(direct_features),
                    feature_pool,
                    f"feature_bindings.{param_name}",
                )
                continue
            group_key = binding.get("group") or binding.get("group_key")
        else:
            group_key = binding

        if group_key not in feature_groups:
            # Unknown group is a soft error: warn and keep going.
            print(
                f"[feature_config] feature_bindings refers to unknown group '{group_key}', skipped"
            )
            continue

        if accepts(param_name) or accepts_var_kwargs:
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Dynamic feature groups: any key in feature_groups that matches __init__ will be filled
    for group_key, names in feature_groups.items():
        if accepts(str(group_key)):
            init_kwargs.setdefault(
                str(group_key), _select(names, feature_pool, str(group_key))
            )

    # Generalized mapping: match params to feature_groups by normalized names
    def _normalize_group_key(key: str) -> str:
        """Normalize group key by removing common suffixes."""
        key = key.lower()
        # Strips at most one suffix per entry; order matters for overlaps
        # ("_features" before "_feature" would otherwise leave a trailing "s").
        for suffix in ("_features", "_feature", "_feats", "_feat", "_list", "_group"):
            if key.endswith(suffix):
                key = key[: -len(suffix)]
        return key

    # First group to claim a normalized key wins (setdefault).
    normalized_groups = {}
    for gk in feature_groups:
        norm = _normalize_group_key(gk)
        normalized_groups.setdefault(norm, gk)

    for param_name in sig_params:
        if param_name in ("self",) or param_name in init_kwargs:
            continue
        norm_param = _normalize_group_key(param_name)
        if norm_param in normalized_groups and (
            accepts(param_name) or accepts_var_kwargs
        ):
            group_key = normalized_groups[norm_param]
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Feature wiring: prefer explicit groups when provided
    if accepts("dense_features"):
        init_kwargs.setdefault("dense_features", dense_features)
    if accepts("sparse_features"):
        init_kwargs.setdefault("sparse_features", sparse_features)
    if accepts("sequence_features"):
        init_kwargs.setdefault("sequence_features", sequence_features)

    if accepts("target"):
        init_kwargs.setdefault("target", target)
    if accepts("device"):
        init_kwargs.setdefault("device", device)

    # Pass session_id if model accepts it
    if "session_id" not in init_kwargs and model_cfg.get("session_id") is not None:
        if accepts("session_id") or accepts_var_kwargs:
            init_kwargs["session_id"] = model_cfg.get("session_id")

    return model_cls(**init_kwargs)
nextrec/utils/device.py CHANGED
@@ -2,13 +2,13 @@
2
2
  Device management utilities for NextRec
3
3
 
4
4
  Date: create on 03/12/2025
5
+ Checkpoint: edit on 06/12/2025
5
6
  Author: Yang Zhou, zyaztec@gmail.com
6
7
  """
7
- import os
8
+
8
9
  import torch
9
10
  import platform
10
11
  import logging
11
- import multiprocessing
12
12
 
13
13
 
14
14
  def resolve_device() -> str:
@@ -17,52 +17,62 @@ def resolve_device() -> str:
17
17
  if torch.backends.mps.is_available():
18
18
  mac_ver = platform.mac_ver()[0]
19
19
  try:
20
- major, minor = (int(x) for x in mac_ver.split(".")[:2])
20
+ major, _ = (int(x) for x in mac_ver.split(".")[:2])
21
21
  except Exception:
22
- major, minor = 0, 0
22
+ major, _ = 0, 0
23
23
  if major >= 14:
24
24
  return "mps"
25
25
  return "cpu"
26
26
 
27
+
27
28
def get_device_info() -> dict:
    """Return a snapshot of accelerator availability for logging/diagnostics.

    Keys: ``cuda_available``, ``cuda_device_count``, ``mps_available``,
    ``current_device`` (from :func:`resolve_device`); plus
    ``cuda_device_name`` and ``cuda_capability`` when CUDA is present.
    """
    info = {
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": (
            torch.cuda.device_count() if torch.cuda.is_available() else 0
        ),
        "mps_available": torch.backends.mps.is_available(),
        "current_device": resolve_device(),
    }

    if torch.cuda.is_available():
        # Only device 0 is inspected; multi-GPU details are not reported here.
        info["cuda_device_name"] = torch.cuda.get_device_name(0)
        info["cuda_capability"] = torch.cuda.get_device_capability(0)

    return info
40
43
 
44
+
41
45
def configure_device(
    distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
) -> torch.device:
    """Validate *base_device* and return the torch.device to use.

    Non-distributed runs get *base_device* back as-is. Distributed runs only
    support CUDA: the local_rank must name an available GPU, which is then
    bound via ``torch.cuda.set_device``. Every failure path logs a warning
    and degrades to CPU instead of raising.
    """
    try:
        device = torch.device(base_device)
    except Exception:
        logging.warning(
            "[configure_device Warning] Invalid base_device, falling back to CPU."
        )
        return torch.device("cpu")

    # Single-process training: honor the requested device unchanged.
    if not distributed:
        return device

    # Distributed training on anything but CUDA falls back to CPU.
    if device.type != "cuda":
        return torch.device("cpu")

    if not torch.cuda.is_available():
        logging.warning(
            "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
        )
        return torch.device("cpu")

    if not (0 <= local_rank < torch.cuda.device_count()):
        logging.warning(
            f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
        )
        return torch.device("cpu")

    try:
        # Bind this process to its GPU before returning the device handle.
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    except Exception as exc:
        logging.warning(
            f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
        )
        return torch.device("cpu")
@@ -15,10 +15,13 @@ from torch.utils.data import DataLoader, IterableDataset
15
15
  from torch.utils.data.distributed import DistributedSampler
16
16
  from nextrec.basic.loggers import colorize
17
17
 
18
- def init_process_group(distributed: bool, rank: int, world_size: int, device_id: int | None = None) -> None:
18
+
19
+ def init_process_group(
20
+ distributed: bool, rank: int, world_size: int, device_id: int | None = None
21
+ ) -> None:
19
22
  """
20
23
  initialize distributed process group for multi-GPU training.
21
-
24
+
22
25
  Args:
23
26
  distributed: whether to enable distributed training
24
27
  rank: global rank of the current process
@@ -29,7 +32,10 @@ def init_process_group(distributed: bool, rank: int, world_size: int, device_id:
29
32
  backend = "nccl" if device_id is not None else "gloo"
30
33
  if backend == "nccl":
31
34
  torch.cuda.set_device(device_id)
32
- dist.init_process_group(backend=backend, init_method="env://", rank=rank, world_size=world_size)
35
+ dist.init_process_group(
36
+ backend=backend, init_method="env://", rank=rank, world_size=world_size
37
+ )
38
+
33
39
 
34
40
  def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
35
41
  """
@@ -53,6 +59,7 @@ def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
53
59
  return None
54
60
  return np.concatenate(pieces, axis=0)
55
61
 
62
+
56
63
  def add_distributed_sampler(
57
64
  loader: DataLoader,
58
65
  distributed: bool,
@@ -64,7 +71,7 @@ def add_distributed_sampler(
64
71
  is_main_process: bool = False,
65
72
  ) -> tuple[DataLoader, DistributedSampler | None]:
66
73
  """
67
- add distributedsampler to a dataloader, this for distributed training
74
+ add distributedsampler to a dataloader, this for distributed training
68
75
  when each device has its own dataloader
69
76
  """
70
77
  # early return if not distributed
@@ -78,11 +85,24 @@ def add_distributed_sampler(
78
85
  return loader, None
79
86
  if isinstance(dataset, IterableDataset):
80
87
  if is_main_process:
81
- logging.info(colorize("[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.", color="yellow"))
88
+ logging.info(
89
+ colorize(
90
+ "[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.",
91
+ color="yellow",
92
+ )
93
+ )
82
94
  return loader, None
83
- sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last)
95
+ sampler = DistributedSampler(
96
+ dataset,
97
+ num_replicas=world_size,
98
+ rank=rank,
99
+ shuffle=shuffle,
100
+ drop_last=drop_last,
101
+ )
84
102
  loader_kwargs = {
85
- "batch_size": loader.batch_size if loader.batch_size is not None else default_batch_size,
103
+ "batch_size": (
104
+ loader.batch_size if loader.batch_size is not None else default_batch_size
105
+ ),
86
106
  "shuffle": False,
87
107
  "sampler": sampler,
88
108
  "num_workers": loader.num_workers,
@@ -104,11 +124,18 @@ def add_distributed_sampler(
104
124
  if generator is not None:
105
125
  loader_kwargs["generator"] = generator
106
126
  if loader.num_workers > 0:
107
- loader_kwargs["persistent_workers"] = getattr(loader, "persistent_workers", False)
127
+ loader_kwargs["persistent_workers"] = getattr(
128
+ loader, "persistent_workers", False
129
+ )
108
130
  prefetch_factor = getattr(loader, "prefetch_factor", None)
109
131
  if prefetch_factor is not None:
110
132
  loader_kwargs["prefetch_factor"] = prefetch_factor
111
133
  distributed_loader = DataLoader(dataset, **loader_kwargs)
112
134
  if is_main_process:
113
- logging.info(colorize("[Distributed Info] Attached DistributedSampler to provided DataLoader", color="cyan"))
135
+ logging.info(
136
+ colorize(
137
+ "[Distributed Info] Attached DistributedSampler to provided DataLoader",
138
+ color="cyan",
139
+ )
140
+ )
114
141
  return distributed_loader, sampler
@@ -2,6 +2,7 @@
2
2
  Embedding utilities for NextRec
3
3
 
4
4
  Date: create on 13/11/2025
5
+ Checkpoint: edit on 06/12/2025
5
6
  Author: Yang Zhou, zyaztec@gmail.com
6
7
  """
7
8