nextrec-0.3.6-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/utils/__init__.py CHANGED
@@ -10,59 +10,88 @@ This package provides various utility functions organized by category:
  - file_utils: File I/O operations
  - model_utils: Model-related utilities
  - feature_utils: Feature processing utilities
+ - config_utils: Configuration loading and processing utilities

  Date: create on 13/11/2025
- Last update: 03/12/2025 (refactored)
+ Last update: 06/12/2025

  Author: Yang Zhou, zyaztec@gmail.com
  """

+ from . import optimizer, initializer, embedding
  from .optimizer import get_optimizer, get_scheduler
  from .initializer import get_initializer
  from .embedding import get_auto_embedding_dim
  from .device import resolve_device, get_device_info
  from .tensor import to_tensor, stack_tensors, concat_tensors, pad_sequence_tensors
- from .file import resolve_file_paths, read_table, load_dataframes, iter_file_chunks, default_output_dir
+ from .file import (
+     resolve_file_paths,
+     read_table,
+     load_dataframes,
+     iter_file_chunks,
+     default_output_dir,
+     read_yaml,
+ )
  from .model import merge_features, get_mlp_output_dim
  from .feature import normalize_to_list
- from . import optimizer, initializer, embedding
+ from .synthetic_data import (
+     generate_match_data,
+     generate_ranking_data,
+     generate_multitask_data,
+     generate_distributed_ranking_data,
+ )
+ from .config import (
+     resolve_path,
+     select_features,
+     register_processor_features,
+     build_feature_objects,
+     extract_feature_groups,
+     load_model_class,
+     build_model_instance,
+ )

  __all__ = [
      # Optimizer & Scheduler
-     'get_optimizer',
-     'get_scheduler',
-
+     "get_optimizer",
+     "get_scheduler",
      # Initializer
-     'get_initializer',
-
+     "get_initializer",
      # Embedding
-     'get_auto_embedding_dim',
-
+     "get_auto_embedding_dim",
      # Device utilities
-     'resolve_device',
-     'get_device_info',
-
+     "resolve_device",
+     "get_device_info",
      # Tensor utilities
-     'to_tensor',
-     'stack_tensors',
-     'concat_tensors',
-     'pad_sequence_tensors',
-
+     "to_tensor",
+     "stack_tensors",
+     "concat_tensors",
+     "pad_sequence_tensors",
      # File utilities
-     'resolve_file_paths',
-     'read_table',
-     'load_dataframes',
-     'iter_file_chunks',
-     'default_output_dir',
-
+     "resolve_file_paths",
+     "read_table",
+     "read_yaml",
+     "load_dataframes",
+     "iter_file_chunks",
+     "default_output_dir",
      # Model utilities
-     'merge_features',
-     'get_mlp_output_dim',
-
+     "merge_features",
+     "get_mlp_output_dim",
      # Feature utilities
-     'normalize_to_list',
-
+     "normalize_to_list",
+     # Config utilities
+     "resolve_path",
+     "select_features",
+     "register_processor_features",
+     "build_feature_objects",
+     "extract_feature_groups",
+     "load_model_class",
+     "build_model_instance",
+     # Synthetic data utilities
+     "generate_ranking_data",
+     "generate_match_data",
+     "generate_multitask_data",
+     "generate_distributed_ranking_data",
      # Module exports
-     'optimizer',
-     'initializer',
-     'embedding',
+     "optimizer",
+     "initializer",
+     "embedding",
  ]
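Note: with these exports, the new config helpers, read_yaml, and the synthetic-data generators become importable from the package root. A minimal import sketch (names taken from __all__ above; generator signatures are not shown in this diff):

from nextrec.utils import read_yaml, load_model_class, build_model_instance
from nextrec.utils import generate_ranking_data, generate_match_data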
nextrec/utils/config.py ADDED
@@ -0,0 +1,490 @@
+ """
+ Configuration utilities for NextRec
+
+ This module provides utilities for loading and processing configuration files,
+ including feature configuration, model configuration, and training configuration.
+
+ Date: create on 06/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ from __future__ import annotations
+
+ import importlib
+ import importlib.util
+ import inspect
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+
+ from nextrec.utils.feature import normalize_to_list
+
+ if TYPE_CHECKING:
+     from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+     from nextrec.data.preprocessor import DataProcessor
+
+
+ def resolve_path(path_str: str | Path, base_dir: Path) -> Path:
+     path = Path(path_str).expanduser()
+     if path.is_absolute():
+         return path
+     if path.exists():
+         return path.resolve()
+     return (base_dir / path).resolve()
+
+
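Note: resolve_path returns an absolute path as-is (after "~" expansion), prefers an existing path relative to the current working directory, and otherwise anchors the path to base_dir. A quick sketch with hypothetical paths:

from pathlib import Path

resolve_path("/tmp/train.csv", Path("/proj"))    # absolute -> returned unchanged
resolve_path("~/train.csv", Path("/proj"))       # "~" expands to an absolute path -> returned
resolve_path("conf/train.yaml", Path("/proj"))   # -> /proj/conf/train.yaml, unless ./conf/train.yaml exists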
+ def select_features(
+     feature_cfg: Dict[str, Any], df_columns: List[str]
+ ) -> Tuple[List[str], List[str], List[str]]:
+     columns = set(df_columns)
+
+     def pick(group: str) -> List[str]:
+         cfg = feature_cfg.get(group, {}) or {}
+         names = [name for name in cfg.keys() if name in columns]
+         missing = [name for name in cfg.keys() if name not in columns]
+         if missing:
+             print(f"[feature_config] skipped missing {group} columns: {missing}")
+         return names
+
+     dense_names = pick("dense")
+     sparse_names = pick("sparse")
+     sequence_names = pick("sequence")
+     return dense_names, sparse_names, sequence_names
+
+
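Note: select_features keeps only the configured names that actually exist as dataframe columns, per group, and warns about the rest. A short sketch with hypothetical feature names:

feature_cfg = {
    "dense": {"age": {}, "income": {}},
    "sparse": {"user_id": {}, "item_id": {}},
    "sequence": {"click_history": {}},
}
# "income" is configured but not a column, so it is skipped with a printed warning
dense, sparse, seq = select_features(
    feature_cfg, ["age", "user_id", "item_id", "click_history"]
)
# dense == ["age"]; sparse == ["user_id", "item_id"]; seq == ["click_history"]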
+ def register_processor_features(
+     processor: DataProcessor,
+     feature_cfg: Dict[str, Any],
+     dense_names: List[str],
+     sparse_names: List[str],
+     sequence_names: List[str],
+ ) -> None:
+     """
+     Register features to DataProcessor based on feature configuration.
+
+     Args:
+         processor: DataProcessor instance
+         feature_cfg: Feature configuration dictionary
+         dense_names: List of dense feature names
+         sparse_names: List of sparse feature names
+         sequence_names: List of sequence feature names
+     """
+     dense_cfg = feature_cfg.get("dense", {}) or {}
+     sparse_cfg = feature_cfg.get("sparse", {}) or {}
+     sequence_cfg = feature_cfg.get("sequence", {}) or {}
+
+     for name in dense_names:
+         proc_cfg = dense_cfg.get(name, {}).get("processor_config", {}) or {}
+         processor.add_numeric_feature(
+             name,
+             scaler=proc_cfg.get("scaler", "standard"),
+             fill_na=proc_cfg.get("fill_na"),
+         )
+
+     for name in sparse_names:
+         proc_cfg = sparse_cfg.get(name, {}).get("processor_config", {}) or {}
+         processor.add_sparse_feature(
+             name,
+             encode_method=proc_cfg.get("encode_method", "hash"),
+             hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
+             fill_na=proc_cfg.get("fill_na", "<UNK>"),
+         )
+
+     for name in sequence_names:
+         proc_cfg = sequence_cfg.get(name, {}).get("processor_config", {}) or {}
+         processor.add_sequence_feature(
+             name,
+             encode_method=proc_cfg.get("encode_method", "hash"),
+             hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
+             max_len=proc_cfg.get("max_len", 50),
+             pad_value=proc_cfg.get("pad_value", 0),
+             truncate=proc_cfg.get("truncate", "post"),
+             separator=proc_cfg.get("separator", ","),
+         )
+
+
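Note: the processor_config keys read above imply a per-feature config shape like the following (a sketch; feature names are hypothetical, defaults are the ones in the code):

feature_cfg = {
    "dense": {
        "age": {"processor_config": {"scaler": "standard", "fill_na": 0.0}},
    },
    "sparse": {
        "user_id": {"processor_config": {"encode_method": "hash", "hash_size": 100_000}},
    },
    "sequence": {
        "click_history": {
            "processor_config": {
                "encode_method": "hash",   # default "hash"
                "hash_size": 100_000,      # falls back to vocab_size if unset
                "max_len": 50,             # default 50
                "pad_value": 0,            # default 0
                "truncate": "post",        # default "post"
                "separator": ",",          # default ","
            }
        },
    },
}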
+ def build_feature_objects(
+     processor: "DataProcessor",
+     feature_cfg: Dict[str, Any],
+     dense_names: List[str],
+     sparse_names: List[str],
+     sequence_names: List[str],
+ ) -> Tuple[List["DenseFeature"], List["SparseFeature"], List["SequenceFeature"]]:
+     """
+     Build feature objects from processor and feature configuration.
+
+     Args:
+         processor: Fitted DataProcessor instance
+         feature_cfg: Feature configuration dictionary
+         dense_names: List of dense feature names
+         sparse_names: List of sparse feature names
+         sequence_names: List of sequence feature names
+     """
+     from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+
+     dense_cfg = feature_cfg.get("dense", {}) or {}
+     sparse_cfg = feature_cfg.get("sparse", {}) or {}
+     sequence_cfg = feature_cfg.get("sequence", {}) or {}
+     vocab_sizes = processor.get_vocab_sizes()
+
+     dense_features: List[DenseFeature] = []
+     for name in dense_names:
+         embed_cfg = dense_cfg.get(name, {}).get("embedding_config", {}) or {}
+         dense_features.append(
+             DenseFeature(
+                 name=name,
+                 embedding_dim=embed_cfg.get("embedding_dim"),
+                 input_dim=embed_cfg.get("input_dim", 1),
+                 use_embedding=embed_cfg.get("use_embedding", False),
+             )
+         )
+
+     sparse_features: List[SparseFeature] = []
+     for name in sparse_names:
+         entry = sparse_cfg.get(name, {}) or {}
+         proc_cfg = entry.get("processor_config", {}) or {}
+         embed_cfg = entry.get("embedding_config", {}) or {}
+         vocab_size = (
+             embed_cfg.get("vocab_size")
+             or proc_cfg.get("hash_size")
+             or vocab_sizes.get(name, 0)
+             or 1
+         )
+         sparse_features.append(
+             SparseFeature(
+                 name=name,
+                 vocab_size=int(vocab_size),
+                 embedding_dim=embed_cfg.get("embedding_dim"),
+                 padding_idx=embed_cfg.get("padding_idx"),
+                 l1_reg=embed_cfg.get("l1_reg", 0.0),
+                 l2_reg=embed_cfg.get("l2_reg", 1e-5),
+                 trainable=embed_cfg.get("trainable", True),
+             )
+         )
+
+     sequence_features: List[SequenceFeature] = []
+     for name in sequence_names:
+         entry = sequence_cfg.get(name, {}) or {}
+         proc_cfg = entry.get("processor_config", {}) or {}
+         embed_cfg = entry.get("embedding_config", {}) or {}
+         vocab_size = (
+             embed_cfg.get("vocab_size")
+             or proc_cfg.get("hash_size")
+             or vocab_sizes.get(name, 0)
+             or 1
+         )
+         sequence_features.append(
+             SequenceFeature(
+                 name=name,
+                 vocab_size=int(vocab_size),
+                 max_len=embed_cfg.get("max_len") or proc_cfg.get("max_len", 50),
+                 embedding_dim=embed_cfg.get("embedding_dim"),
+                 padding_idx=embed_cfg.get("padding_idx"),
+                 combiner=embed_cfg.get("combiner", "mean"),
+                 l1_reg=embed_cfg.get("l1_reg", 0.0),
+                 l2_reg=embed_cfg.get("l2_reg", 1e-5),
+                 trainable=embed_cfg.get("trainable", True),
+             )
+         )
+
+     return dense_features, sparse_features, sequence_features
+
+
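Note: embedding_config is the second per-feature block consumed here, and vocab_size resolves by the precedence explicit vocab_size > processor hash_size > fitted vocabulary > floor of 1. A hypothetical sparse entry:

feature_cfg["sparse"]["user_id"] = {
    "processor_config": {"hash_size": 100_000},
    "embedding_config": {
        "embedding_dim": 16,   # omitted -> None is passed through
        "padding_idx": 0,
        "l2_reg": 1e-5,        # default 1e-5
        "trainable": True,     # default True
    },
}
# vocab_size resolves to 100_000: no explicit vocab_size, so hash_size wins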
+ def extract_feature_groups(
+     feature_cfg: Dict[str, Any], df_columns: List[str]
+ ) -> Tuple[Dict[str, List[str]], List[str]]:
+     """
+     Extract and validate feature groups from feature configuration.
+
+     Args:
+         feature_cfg: Feature configuration dictionary
+         df_columns: Available dataframe columns
+     """
+     feature_groups = feature_cfg.get("feature_groups") or {}
+     if not feature_groups:
+         return {}, []
+
+     defined = (
+         set((feature_cfg.get("dense") or {}).keys())
+         | set((feature_cfg.get("sparse") or {}).keys())
+         | set((feature_cfg.get("sequence") or {}).keys())
+     )
+     available_cols = set(df_columns)
+     resolved: Dict[str, List[str]] = {}
+     collected: List[str] = []
+
+     for group_name, names in feature_groups.items():
+         name_list = normalize_to_list(names)
+         filtered = []
+         missing_defined = [n for n in name_list if n not in defined]
+         missing_cols = [n for n in name_list if n not in available_cols]
+
+         if missing_defined:
+             print(
+                 f"[feature_config] feature_groups.{group_name} contains features not defined in dense/sparse/sequence: {missing_defined}"
+             )
+
+         for n in name_list:
+             if n in available_cols:
+                 if n not in filtered:
+                     filtered.append(n)
+             else:
+                 if n not in missing_cols:
+                     missing_cols.append(n)
+
+         if missing_cols:
+             print(
+                 f"[feature_config] feature_groups.{group_name} missing data columns: {missing_cols}"
+             )
+
+         resolved[group_name] = filtered
+         collected.extend(filtered)
+
+     return resolved, collected
+
+
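Note: a short sketch of group resolution (hypothetical names; collected concatenates the resolved groups in order, so a feature appearing in several groups is repeated):

feature_cfg = {
    "sparse": {"user_id": {}, "item_id": {}},
    "dense": {"age": {}},
    "feature_groups": {
        "user_features": ["user_id", "age"],
        "item_features": ["item_id", "brand"],  # "brand" is neither defined nor a column
    },
}
resolved, collected = extract_feature_groups(feature_cfg, ["user_id", "item_id", "age"])
# resolved == {"user_features": ["user_id", "age"], "item_features": ["item_id"]}
# collected == ["user_id", "age", "item_id"]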
+ def load_model_class(model_cfg: Dict[str, Any], base_dir: Path) -> type:
+     """
+     Load model class from configuration.
+
+     Args:
+         model_cfg: Model configuration dictionary
+         base_dir: Base directory for resolving relative paths
+     """
+
+     def camelize(name: str) -> str:
+         """Convert snake_case or kebab-case to CamelCase."""
+         return "".join(
+             part.capitalize()
+             for part in name.replace("_", " ").replace("-", " ").split()
+         )
+
+     module_path = model_cfg.get("module_path")
+     name = model_cfg.get("model") or model_cfg.get("name")
+     module_name = model_cfg.get("module") or model_cfg.get("module_name")
+     class_name = model_cfg.get("class_name")
+
+     # Case 1: Custom file path
+     if module_path:
+         resolved = resolve_path(module_path, base_dir)
+         if not resolved.exists():
+             raise FileNotFoundError(f"Custom model file not found: {resolved}")
+
+         spec = importlib.util.spec_from_file_location(resolved.stem, resolved)
+         if spec is None or spec.loader is None:
+             raise ImportError(f"Unable to load custom model file: {resolved}")
+
+         module = importlib.util.module_from_spec(spec)
+         spec.loader.exec_module(module)
+
+         if class_name and hasattr(module, class_name):
+             return getattr(module, class_name)
+
+         # Auto-pick first BaseModel subclass
+         from nextrec.basic.model import BaseModel
+
+         for attr in module.__dict__.values():
+             if (
+                 isinstance(attr, type)
+                 and issubclass(attr, BaseModel)
+                 and attr is not BaseModel
+             ):
+                 return attr
+
+         raise AttributeError(
+             f"No BaseModel subclass found in {resolved}, please provide class_name"
+         )
+
+     # Case 2: Builtin model by short name
+     if name and not module_name:
+         from nextrec.basic.model import BaseModel
+
+         candidates = [
+             f"nextrec.models.{name.lower()}",
+             f"nextrec.models.ranking.{name.lower()}",
+             f"nextrec.models.match.{name.lower()}",
+             f"nextrec.models.multi_task.{name.lower()}",
+             f"nextrec.models.generative.{name.lower()}",
+         ]
+         errors = []
+
+         for mod in candidates:
+             try:
+                 module = importlib.import_module(mod)
+                 cls_name = class_name or camelize(name)
+
+                 if hasattr(module, cls_name):
+                     return getattr(module, cls_name)
+
+                 # Fallback: first BaseModel subclass
+                 for attr in module.__dict__.values():
+                     if (
+                         isinstance(attr, type)
+                         and issubclass(attr, BaseModel)
+                         and attr is not BaseModel
+                     ):
+                         return attr
+
+                 errors.append(f"{mod} missing class {cls_name}")
+             except Exception as exc:
+                 errors.append(f"{mod}: {exc}")
+
+         raise ImportError(f"Unable to find model for model='{name}'. Tried: {errors}")
+
+     # Case 3: Explicit module + class
+     if module_name and class_name:
+         module = importlib.import_module(module_name)
+         if not hasattr(module, class_name):
+             raise AttributeError(f"Class {class_name} not found in {module_name}")
+         return getattr(module, class_name)
+
+     raise ValueError(
+         "model configuration must provide 'model' (builtin name), 'module_path' (custom path), or 'module'+'class_name'"
+     )
+
+
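Note: the three accepted model-config shapes, sketched as dictionaries (values hypothetical):

{"module_path": "models/my_model.py", "class_name": "MyModel"}       # Case 1: custom file
{"model": "deepfm"}                                                  # Case 2: builtin short name
{"module": "nextrec.models.ranking.deepfm", "class_name": "DeepFM"}  # Case 3: explicit

# In Case 2, camelize("deepfm") yields "Deepfm"; if no attribute with that exact
# name exists, the loader falls back to the first BaseModel subclass in the module.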
+ def build_model_instance(
+     model_cfg: Dict[str, Any],
+     model_cfg_path: Path,
+     dense_features: List[DenseFeature],
+     sparse_features: List[SparseFeature],
+     sequence_features: List[SequenceFeature],
+     target: List[str],
+     device: str,
+ ) -> Any:
+     """
+     Build model instance from configuration and feature objects.
+
+     Args:
+         model_cfg: Model configuration dictionary
+         model_cfg_path: Path to model config file (for resolving relative paths)
+         dense_features: List of dense feature objects
+         sparse_features: List of sparse feature objects
+         sequence_features: List of sequence feature objects
+         target: List of target column names
+         device: Device string (e.g., 'cpu', 'cuda:0')
+     """
+     dense_map = {f.name: f for f in dense_features}
+     sparse_map = {f.name: f for f in sparse_features}
+     sequence_map = {f.name: f for f in sequence_features}
+     feature_pool: Dict[str, Any] = {**dense_map, **sparse_map, **sequence_map}
+
+     model_cls = load_model_class(model_cfg, model_cfg_path.parent)
+     params_cfg = deepcopy(model_cfg.get("params") or {})
+     feature_groups = params_cfg.pop("feature_groups", {}) or {}
+     feature_bindings_cfg = (
+         model_cfg.get("feature_bindings")
+         or params_cfg.pop("feature_bindings", {})
+         or {}
+     )
+     sig_params = inspect.signature(model_cls.__init__).parameters
+
+     def _select(names: List[str] | None, pool: Dict[str, Any], desc: str) -> List[Any]:
+         """Select features from pool by names."""
+         if names is None:
+             return list(pool.values())
+         missing = [n for n in names if n not in feature_pool]
+         if missing:
+             raise ValueError(
+                 f"feature_groups.{desc} contains unknown features: {missing}"
+             )
+         return [feature_pool[n] for n in names]
+
+     def accepts(name: str) -> bool:
+         """Check if parameter name is accepted by model __init__."""
+         return name in sig_params
+
+     accepts_var_kwargs = any(
+         param.kind == inspect.Parameter.VAR_KEYWORD for param in sig_params.values()
+     )
+
+     init_kwargs: Dict[str, Any] = dict(params_cfg)
+
+     # Explicit bindings (model_config.feature_bindings) take priority
+     for param_name, binding in feature_bindings_cfg.items():
+         if param_name in init_kwargs:
+             continue
+
+         if isinstance(binding, (list, tuple, set)):
+             if accepts(param_name) or accepts_var_kwargs:
+                 init_kwargs[param_name] = _select(
+                     list(binding), feature_pool, f"feature_bindings.{param_name}"
+                 )
+             continue
+
+         if isinstance(binding, dict):
+             direct_features = binding.get("features") or binding.get("feature_names")
+             if direct_features and (accepts(param_name) or accepts_var_kwargs):
+                 init_kwargs[param_name] = _select(
+                     normalize_to_list(direct_features),
+                     feature_pool,
+                     f"feature_bindings.{param_name}",
+                 )
+                 continue
+             group_key = binding.get("group") or binding.get("group_key")
+         else:
+             group_key = binding
+
+         if group_key not in feature_groups:
+             print(
+                 f"[feature_config] feature_bindings refers to unknown group '{group_key}', skipped"
+             )
+             continue
+
+         if accepts(param_name) or accepts_var_kwargs:
+             init_kwargs[param_name] = _select(
+                 feature_groups[group_key], feature_pool, str(group_key)
+             )
+
+     # Dynamic feature groups: any key in feature_groups that matches __init__ will be filled
+     for group_key, names in feature_groups.items():
+         if accepts(str(group_key)):
+             init_kwargs.setdefault(
+                 str(group_key), _select(names, feature_pool, str(group_key))
+             )
+
+     # Generalized mapping: match params to feature_groups by normalized names
+     def _normalize_group_key(key: str) -> str:
+         """Normalize group key by removing common suffixes."""
+         key = key.lower()
+         for suffix in ("_features", "_feature", "_feats", "_feat", "_list", "_group"):
+             if key.endswith(suffix):
+                 key = key[: -len(suffix)]
+         return key
+
+     normalized_groups = {}
+     for gk in feature_groups:
+         norm = _normalize_group_key(gk)
+         normalized_groups.setdefault(norm, gk)
+
+     for param_name in sig_params:
+         if param_name in ("self",) or param_name in init_kwargs:
+             continue
+         norm_param = _normalize_group_key(param_name)
+         if norm_param in normalized_groups and (
+             accepts(param_name) or accepts_var_kwargs
+         ):
+             group_key = normalized_groups[norm_param]
+             init_kwargs[param_name] = _select(
+                 feature_groups[group_key], feature_pool, str(group_key)
+             )
+
+     # Feature wiring: prefer explicit groups when provided
+     if accepts("dense_features"):
+         init_kwargs.setdefault("dense_features", dense_features)
+     if accepts("sparse_features"):
+         init_kwargs.setdefault("sparse_features", sparse_features)
+     if accepts("sequence_features"):
+         init_kwargs.setdefault("sequence_features", sequence_features)
+
+     if accepts("target"):
+         init_kwargs.setdefault("target", target)
+     if accepts("device"):
+         init_kwargs.setdefault("device", device)
+
+     # Pass session_id if model accepts it
+     if "session_id" not in init_kwargs and model_cfg.get("session_id") is not None:
+         if accepts("session_id") or accepts_var_kwargs:
+             init_kwargs["session_id"] = model_cfg.get("session_id")
+
+     return model_cls(**init_kwargs)
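Note: an end-to-end sketch of how these helpers might compose (file names, df, and the DataProcessor setup are hypothetical; read_yaml is assumed to take a file path, and the processor-fitting step is not shown in this diff):

from pathlib import Path
from nextrec.utils import (
    read_yaml,
    select_features,
    register_processor_features,
    build_feature_objects,
    build_model_instance,
)

feature_cfg = read_yaml("configs/feature_config.yaml")   # hypothetical path
model_cfg = read_yaml("configs/model_config.yaml")       # hypothetical path

dense_n, sparse_n, seq_n = select_features(feature_cfg, list(df.columns))
register_processor_features(processor, feature_cfg, dense_n, sparse_n, seq_n)
# ... fit the processor on df, then materialize feature objects ...
dense_f, sparse_f, seq_f = build_feature_objects(
    processor, feature_cfg, dense_n, sparse_n, seq_n
)

model = build_model_instance(
    model_cfg,
    Path("configs/model_config.yaml"),
    dense_f,
    sparse_f,
    seq_f,
    target=["label"],
    device="cpu",
)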
nextrec/utils/device.py CHANGED
@@ -2,12 +2,13 @@
  Device management utilities for NextRec

  Date: create on 03/12/2025
+ Checkpoint: edit on 06/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
- import os
+
  import torch
  import platform
- import multiprocessing
+ import logging


  def resolve_device() -> str:
@@ -16,23 +17,62 @@ def resolve_device() -> str:
      if torch.backends.mps.is_available():
          mac_ver = platform.mac_ver()[0]
          try:
-             major, minor = (int(x) for x in mac_ver.split(".")[:2])
+             major, _ = (int(x) for x in mac_ver.split(".")[:2])
          except Exception:
-             major, minor = 0, 0
+             major, _ = 0, 0
          if major >= 14:
              return "mps"
      return "cpu"

+
  def get_device_info() -> dict:
      info = {
-         'cuda_available': torch.cuda.is_available(),
-         'cuda_device_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
-         'mps_available': torch.backends.mps.is_available(),
-         'current_device': resolve_device(),
+         "cuda_available": torch.cuda.is_available(),
+         "cuda_device_count": (
+             torch.cuda.device_count() if torch.cuda.is_available() else 0
+         ),
+         "mps_available": torch.backends.mps.is_available(),
+         "current_device": resolve_device(),
      }
-
+
      if torch.cuda.is_available():
-         info['cuda_device_name'] = torch.cuda.get_device_name(0)
-         info['cuda_capability'] = torch.cuda.get_device_capability(0)
-
+         info["cuda_device_name"] = torch.cuda.get_device_name(0)
+         info["cuda_capability"] = torch.cuda.get_device_capability(0)
+
      return info
+
+
+ def configure_device(
+     distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
+ ) -> torch.device:
+     try:
+         device = torch.device(base_device)
+     except Exception:
+         logging.warning(
+             "[configure_device Warning] Invalid base_device, falling back to CPU."
+         )
+         return torch.device("cpu")
+
+     if distributed:
+         if device.type == "cuda":
+             if not torch.cuda.is_available():
+                 logging.warning(
+                     "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
+                 )
+                 return torch.device("cpu")
+             if not (0 <= local_rank < torch.cuda.device_count()):
+                 logging.warning(
+                     f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
+                 )
+                 return torch.device("cpu")
+             try:
+                 torch.cuda.set_device(local_rank)
+                 return torch.device(f"cuda:{local_rank}")
+             except Exception as exc:
+                 logging.warning(
+                     f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
+                 )
+                 return torch.device("cpu")
+         else:
+             return torch.device("cpu")
+     return device
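Note: a usage sketch of configure_device under a torchrun-style launch (reading LOCAL_RANK from the environment is the usual convention, not something shown in this diff):

import os
import torch

# each distributed worker reads its local rank from the environment
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = configure_device(distributed=True, local_rank=local_rank, base_device="cuda")
# falls back to CPU if CUDA is unavailable, the rank is out of range,
# or the device cannot be set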