nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/utils/config.py
ADDED
@@ -0,0 +1,496 @@

"""
Configuration utilities for NextRec

This module provides utilities for loading and processing configuration files,
including feature configuration, model configuration, and training configuration.

Date: create on 06/12/2025
Author: Yang Zhou, zyaztec@gmail.com
"""

from __future__ import annotations

import importlib
import importlib.util
import inspect
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

from nextrec.utils.feature import normalize_to_list

if TYPE_CHECKING:
    from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
    from nextrec.data.preprocessor import DataProcessor


def resolve_path(path_str: str | Path, base_dir: Path) -> Path:
    path = Path(path_str).expanduser()
    if path.is_absolute():
        return path
    # Prefer resolving relative to current working directory when the path (or its parent)
    # already exists there; otherwise fall back to the config file's directory.
    cwd_path = (Path.cwd() / path).resolve()
    if cwd_path.exists() or cwd_path.parent.exists():
        return cwd_path
    base_dir_path = (base_dir / path).resolve()
    if base_dir_path.exists() or base_dir_path.parent.exists():
        return base_dir_path
    return cwd_path


def select_features(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[List[str], List[str], List[str]]:
    columns = set(df_columns)

    def pick(group: str) -> List[str]:
        cfg = feature_cfg.get(group, {}) or {}
        names = [name for name in cfg.keys() if name in columns]
        missing = [name for name in cfg.keys() if name not in columns]
        if missing:
            print(f"[feature_config] skipped missing {group} columns: {missing}")
        return names

    dense_names = pick("dense")
    sparse_names = pick("sparse")
    sequence_names = pick("sequence")
    return dense_names, sparse_names, sequence_names


def register_processor_features(
    processor: DataProcessor,
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> None:
    """
    Register features to DataProcessor based on feature configuration.

    Args:
        processor: DataProcessor instance
        feature_cfg: Feature configuration dictionary
        dense_names: List of dense feature names
        sparse_names: List of sparse feature names
        sequence_names: List of sequence feature names
    """
    dense_cfg = feature_cfg.get("dense", {}) or {}
    sparse_cfg = feature_cfg.get("sparse", {}) or {}
    sequence_cfg = feature_cfg.get("sequence", {}) or {}

    for name in dense_names:
        proc_cfg = dense_cfg.get(name, {}).get("processor_config", {}) or {}
        processor.add_numeric_feature(
            name,
            scaler=proc_cfg.get("scaler", "standard"),
            fill_na=proc_cfg.get("fill_na"),
        )

    for name in sparse_names:
        proc_cfg = sparse_cfg.get(name, {}).get("processor_config", {}) or {}
        processor.add_sparse_feature(
            name,
            encode_method=proc_cfg.get("encode_method", "hash"),
            hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
            fill_na=proc_cfg.get("fill_na", "<UNK>"),
        )

    for name in sequence_names:
        proc_cfg = sequence_cfg.get(name, {}).get("processor_config", {}) or {}
        processor.add_sequence_feature(
            name,
            encode_method=proc_cfg.get("encode_method", "hash"),
            hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
            max_len=proc_cfg.get("max_len", 50),
            pad_value=proc_cfg.get("pad_value", 0),
            truncate=proc_cfg.get("truncate", "post"),
            separator=proc_cfg.get("separator", ","),
        )


def build_feature_objects(
    processor: "DataProcessor",
    feature_cfg: Dict[str, Any],
    dense_names: List[str],
    sparse_names: List[str],
    sequence_names: List[str],
) -> Tuple[List["DenseFeature"], List["SparseFeature"], List["SequenceFeature"]]:
    """
    Build feature objects from processor and feature configuration.

    Args:
        processor: Fitted DataProcessor instance
        feature_cfg: Feature configuration dictionary
        dense_names: List of dense feature names
        sparse_names: List of sparse feature names
        sequence_names: List of sequence feature names
    """
    from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature

    dense_cfg = feature_cfg.get("dense", {}) or {}
    sparse_cfg = feature_cfg.get("sparse", {}) or {}
    sequence_cfg = feature_cfg.get("sequence", {}) or {}
    vocab_sizes = processor.get_vocab_sizes()

    dense_features: List[DenseFeature] = []
    for name in dense_names:
        embed_cfg = dense_cfg.get(name, {}).get("embedding_config", {}) or {}
        dense_features.append(
            DenseFeature(
                name=name,
                embedding_dim=embed_cfg.get("embedding_dim"),
                input_dim=embed_cfg.get("input_dim", 1),
                use_embedding=embed_cfg.get("use_embedding", False),
            )
        )

    sparse_features: List[SparseFeature] = []
    for name in sparse_names:
        entry = sparse_cfg.get(name, {}) or {}
        proc_cfg = entry.get("processor_config", {}) or {}
        embed_cfg = entry.get("embedding_config", {}) or {}
        vocab_size = (
            embed_cfg.get("vocab_size")
            or proc_cfg.get("hash_size")
            or vocab_sizes.get(name, 0)
            or 1
        )
        sparse_features.append(
            SparseFeature(
                name=name,
                vocab_size=int(vocab_size),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    sequence_features: List[SequenceFeature] = []
    for name in sequence_names:
        entry = sequence_cfg.get(name, {}) or {}
        proc_cfg = entry.get("processor_config", {}) or {}
        embed_cfg = entry.get("embedding_config", {}) or {}
        vocab_size = (
            embed_cfg.get("vocab_size")
            or proc_cfg.get("hash_size")
            or vocab_sizes.get(name, 0)
            or 1
        )
        sequence_features.append(
            SequenceFeature(
                name=name,
                vocab_size=int(vocab_size),
                max_len=embed_cfg.get("max_len") or proc_cfg.get("max_len", 50),
                embedding_dim=embed_cfg.get("embedding_dim"),
                padding_idx=embed_cfg.get("padding_idx"),
                combiner=embed_cfg.get("combiner", "mean"),
                l1_reg=embed_cfg.get("l1_reg", 0.0),
                l2_reg=embed_cfg.get("l2_reg", 1e-5),
                trainable=embed_cfg.get("trainable", True),
            )
        )

    return dense_features, sparse_features, sequence_features


def extract_feature_groups(
    feature_cfg: Dict[str, Any], df_columns: List[str]
) -> Tuple[Dict[str, List[str]], List[str]]:
    """
    Extract and validate feature groups from feature configuration.

    Args:
        feature_cfg: Feature configuration dictionary
        df_columns: Available dataframe columns
    """
    feature_groups = feature_cfg.get("feature_groups") or {}
    if not feature_groups:
        return {}, []

    defined = (
        set((feature_cfg.get("dense") or {}).keys())
        | set((feature_cfg.get("sparse") or {}).keys())
        | set((feature_cfg.get("sequence") or {}).keys())
    )
    available_cols = set(df_columns)
    resolved: Dict[str, List[str]] = {}
    collected: List[str] = []

    for group_name, names in feature_groups.items():
        name_list = normalize_to_list(names)
        filtered = []
        missing_defined = [n for n in name_list if n not in defined]
        missing_cols = [n for n in name_list if n not in available_cols]

        if missing_defined:
            print(
                f"[feature_config] feature_groups.{group_name} contains features not defined in dense/sparse/sequence: {missing_defined}"
            )

        for n in name_list:
            if n in available_cols:
                if n not in filtered:
                    filtered.append(n)
            else:
                if n not in missing_cols:
                    missing_cols.append(n)

        if missing_cols:
            print(
                f"[feature_config] feature_groups.{group_name} missing data columns: {missing_cols}"
            )

        resolved[group_name] = filtered
        collected.extend(filtered)

    return resolved, collected
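
The feature helpers above all read the same nested configuration shape. As a minimal sketch, the following hypothetical feature_cfg uses only keys the code actually looks up (dense/sparse/sequence entries carrying processor_config and embedding_config, plus feature_groups); the column names, hash sizes, and embedding dims are invented for illustration.

# Hypothetical feature configuration; every key mirrors a lookup in the
# helpers above, but the column names and sizes are made up for this sketch.
feature_cfg = {
    "dense": {
        "age": {"processor_config": {"scaler": "standard"}},
    },
    "sparse": {
        "user_id": {
            "processor_config": {"encode_method": "hash", "hash_size": 100_000},
            "embedding_config": {"embedding_dim": 16},
        },
    },
    "sequence": {
        "hist_item_ids": {
            "processor_config": {"max_len": 50, "separator": ","},
            "embedding_config": {"embedding_dim": 16, "combiner": "mean"},
        },
    },
    "feature_groups": {
        "user_features": ["user_id", "age"],
        "item_features": ["hist_item_ids"],
    },
}

df_columns = ["age", "user_id", "hist_item_ids", "label"]
dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
groups, grouped_columns = extract_feature_groups(feature_cfg, df_columns)

select_features drops configured columns that are missing from the dataframe (printing a notice), while extract_feature_groups additionally validates group members against the declared dense/sparse/sequence features.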

def load_model_class(model_cfg: Dict[str, Any], base_dir: Path) -> type:
    """
    Load model class from configuration.

    Args:
        model_cfg: Model configuration dictionary
        base_dir: Base directory for resolving relative paths
    """

    def camelize(name: str) -> str:
        """Convert snake_case or kebab-case to CamelCase."""
        return "".join(
            part.capitalize()
            for part in name.replace("_", " ").replace("-", " ").split()
        )

    module_path = model_cfg.get("module_path")
    name = model_cfg.get("model") or model_cfg.get("name")
    module_name = model_cfg.get("module") or model_cfg.get("module_name")
    class_name = model_cfg.get("class_name")

    # Case 1: Custom file path
    if module_path:
        resolved = resolve_path(module_path, base_dir)
        if not resolved.exists():
            raise FileNotFoundError(f"Custom model file not found: {resolved}")

        spec = importlib.util.spec_from_file_location(resolved.stem, resolved)
        if spec is None or spec.loader is None:
            raise ImportError(f"Unable to load custom model file: {resolved}")

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        if class_name and hasattr(module, class_name):
            return getattr(module, class_name)

        # Auto-pick first BaseModel subclass
        from nextrec.basic.model import BaseModel

        for attr in module.__dict__.values():
            if (
                isinstance(attr, type)
                and issubclass(attr, BaseModel)
                and attr is not BaseModel
            ):
                return attr

        raise AttributeError(
            f"No BaseModel subclass found in {resolved}, please provide class_name"
        )

    # Case 2: Builtin model by short name
    if name and not module_name:
        from nextrec.basic.model import BaseModel

        candidates = [
            f"nextrec.models.{name.lower()}",
            f"nextrec.models.ranking.{name.lower()}",
            f"nextrec.models.match.{name.lower()}",
            f"nextrec.models.multi_task.{name.lower()}",
            f"nextrec.models.generative.{name.lower()}",
        ]
        errors = []

        for mod in candidates:
            try:
                module = importlib.import_module(mod)
                cls_name = class_name or camelize(name)

                if hasattr(module, cls_name):
                    return getattr(module, cls_name)

                # Fallback: first BaseModel subclass
                for attr in module.__dict__.values():
                    if (
                        isinstance(attr, type)
                        and issubclass(attr, BaseModel)
                        and attr is not BaseModel
                    ):
                        return attr

                errors.append(f"{mod} missing class {cls_name}")
            except Exception as exc:
                errors.append(f"{mod}: {exc}")

        raise ImportError(f"Unable to find model for model='{name}'. Tried: {errors}")

    # Case 3: Explicit module + class
    if module_name and class_name:
        module = importlib.import_module(module_name)
        if not hasattr(module, class_name):
            raise AttributeError(f"Class {class_name} not found in {module_name}")
        return getattr(module, class_name)

    raise ValueError(
        "model configuration must provide 'model' (builtin name), 'module_path' (custom path), or 'module'+'class_name'"
    )


def build_model_instance(
    model_cfg: Dict[str, Any],
    model_cfg_path: Path,
    dense_features: List[DenseFeature],
    sparse_features: List[SparseFeature],
    sequence_features: List[SequenceFeature],
    target: List[str],
    device: str,
) -> Any:
    """
    Build model instance from configuration and feature objects.

    Args:
        model_cfg: Model configuration dictionary
        model_cfg_path: Path to model config file (for resolving relative paths)
        dense_features: List of dense feature objects
        sparse_features: List of sparse feature objects
        sequence_features: List of sequence feature objects
        target: List of target column names
        device: Device string (e.g., 'cpu', 'cuda:0')
    """
    dense_map = {f.name: f for f in dense_features}
    sparse_map = {f.name: f for f in sparse_features}
    sequence_map = {f.name: f for f in sequence_features}
    feature_pool: Dict[str, Any] = {**dense_map, **sparse_map, **sequence_map}

    model_cls = load_model_class(model_cfg, model_cfg_path.parent)
    params_cfg = deepcopy(model_cfg.get("params") or {})
    feature_groups = params_cfg.pop("feature_groups", {}) or {}
    feature_bindings_cfg = (
        model_cfg.get("feature_bindings")
        or params_cfg.pop("feature_bindings", {})
        or {}
    )
    sig_params = inspect.signature(model_cls.__init__).parameters

    def _select(names: List[str] | None, pool: Dict[str, Any], desc: str) -> List[Any]:
        """Select features from pool by names."""
        if names is None:
            return list(pool.values())
        missing = [n for n in names if n not in feature_pool]
        if missing:
            raise ValueError(
                f"feature_groups.{desc} contains unknown features: {missing}"
            )
        return [feature_pool[n] for n in names]

    def accepts(name: str) -> bool:
        """Check if parameter name is accepted by model __init__."""
        return name in sig_params

    accepts_var_kwargs = any(
        param.kind == inspect.Parameter.VAR_KEYWORD for param in sig_params.values()
    )

    init_kwargs: Dict[str, Any] = dict(params_cfg)

    # Explicit bindings (model_config.feature_bindings) take priority
    for param_name, binding in feature_bindings_cfg.items():
        if param_name in init_kwargs:
            continue

        if isinstance(binding, (list, tuple, set)):
            if accepts(param_name) or accepts_var_kwargs:
                init_kwargs[param_name] = _select(
                    list(binding), feature_pool, f"feature_bindings.{param_name}"
                )
            continue

        if isinstance(binding, dict):
            direct_features = binding.get("features") or binding.get("feature_names")
            if direct_features and (accepts(param_name) or accepts_var_kwargs):
                init_kwargs[param_name] = _select(
                    normalize_to_list(direct_features),
                    feature_pool,
                    f"feature_bindings.{param_name}",
                )
                continue
            group_key = binding.get("group") or binding.get("group_key")
        else:
            group_key = binding

        if group_key not in feature_groups:
            print(
                f"[feature_config] feature_bindings refers to unknown group '{group_key}', skipped"
            )
            continue

        if accepts(param_name) or accepts_var_kwargs:
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Dynamic feature groups: any key in feature_groups that matches __init__ will be filled
    for group_key, names in feature_groups.items():
        if accepts(str(group_key)):
            init_kwargs.setdefault(
                str(group_key), _select(names, feature_pool, str(group_key))
            )

    # Generalized mapping: match params to feature_groups by normalized names
    def _normalize_group_key(key: str) -> str:
        """Normalize group key by removing common suffixes."""
        key = key.lower()
        for suffix in ("_features", "_feature", "_feats", "_feat", "_list", "_group"):
            if key.endswith(suffix):
                key = key[: -len(suffix)]
        return key

    normalized_groups = {}
    for gk in feature_groups:
        norm = _normalize_group_key(gk)
        normalized_groups.setdefault(norm, gk)

    for param_name in sig_params:
        if param_name in ("self",) or param_name in init_kwargs:
            continue
        norm_param = _normalize_group_key(param_name)
        if norm_param in normalized_groups and (
            accepts(param_name) or accepts_var_kwargs
        ):
            group_key = normalized_groups[norm_param]
            init_kwargs[param_name] = _select(
                feature_groups[group_key], feature_pool, str(group_key)
            )

    # Feature wiring: prefer explicit groups when provided
    if accepts("dense_features"):
        init_kwargs.setdefault("dense_features", dense_features)
    if accepts("sparse_features"):
        init_kwargs.setdefault("sparse_features", sparse_features)
    if accepts("sequence_features"):
        init_kwargs.setdefault("sequence_features", sequence_features)

    if accepts("target"):
        init_kwargs.setdefault("target", target)
    if accepts("device"):
        init_kwargs.setdefault("device", device)

    # Pass session_id if model accepts it
    if "session_id" not in init_kwargs and model_cfg.get("session_id") is not None:
        if accepts("session_id") or accepts_var_kwargs:
            init_kwargs["session_id"] = model_cfg.get("session_id")

    return model_cls(**init_kwargs)
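
On the model side, load_model_class resolves a class in one of three ways (builtin short name, a custom module_path file, or an explicit module plus class_name), and build_model_instance then matches configured feature groups against the model's __init__ signature. Below is a hypothetical model_cfg for the builtin deepfm module; the mlp_dims hyperparameter and the config path are invented, and dense_features/sparse_features/sequence_features are assumed to come from build_feature_objects above.

from pathlib import Path

# Hypothetical model configuration: "deepfm" is tried against the candidate
# module list above; if the camelized class name does not match exactly, the
# first BaseModel subclass found in the module is used instead.
model_cfg = {
    "model": "deepfm",
    "params": {
        "mlp_dims": [128, 64],  # invented hyperparameter, forwarded to __init__
        "feature_groups": {
            "user_features": ["user_id", "age"],
            "item_features": ["hist_item_ids"],
        },
    },
}

model = build_model_instance(
    model_cfg,
    model_cfg_path=Path("configs/model.yaml"),  # only its parent dir is used, for module_path resolution
    dense_features=dense_features,
    sparse_features=sparse_features,
    sequence_features=sequence_features,
    target=["label"],
    device="cpu",
)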
nextrec/utils/device.py
CHANGED
@@ -2,13 +2,13 @@
 Device management utilities for NextRec

 Date: create on 03/12/2025
+Checkpoint: edit on 06/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
-
+
 import torch
 import platform
 import logging
-import multiprocessing


 def resolve_device() -> str:
@@ -17,52 +17,62 @@ def resolve_device() -> str:
     if torch.backends.mps.is_available():
         mac_ver = platform.mac_ver()[0]
         try:
-            major,
+            major, _ = (int(x) for x in mac_ver.split(".")[:2])
         except Exception:
-            major,
+            major, _ = 0, 0
         if major >= 14:
             return "mps"
     return "cpu"

+
 def get_device_info() -> dict:
     info = {
-
-
-
-
+        "cuda_available": torch.cuda.is_available(),
+        "cuda_device_count": (
+            torch.cuda.device_count() if torch.cuda.is_available() else 0
+        ),
+        "mps_available": torch.backends.mps.is_available(),
+        "current_device": resolve_device(),
     }
-
+
     if torch.cuda.is_available():
-        info[
-        info[
-
+        info["cuda_device_name"] = torch.cuda.get_device_name(0)
+        info["cuda_capability"] = torch.cuda.get_device_capability(0)
+
     return info

+
 def configure_device(
-    distributed: bool,
-    local_rank: int,
-    base_device: torch.device | str = "cpu"
+    distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
 ) -> torch.device:
     try:
         device = torch.device(base_device)
     except Exception:
-        logging.warning(
+        logging.warning(
+            "[configure_device Warning] Invalid base_device, falling back to CPU."
+        )
         return torch.device("cpu")

     if distributed:
         if device.type == "cuda":
             if not torch.cuda.is_available():
-                logging.warning(
+                logging.warning(
+                    "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
+                )
                 return torch.device("cpu")
             if not (0 <= local_rank < torch.cuda.device_count()):
-                logging.warning(
+                logging.warning(
+                    f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
+                )
                 return torch.device("cpu")
             try:
                 torch.cuda.set_device(local_rank)
                 return torch.device(f"cuda:{local_rank}")
             except Exception as exc:
-                logging.warning(
+                logging.warning(
+                    f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
+                )
                 return torch.device("cpu")
     else:
         return torch.device("cpu")
-    return device
+    return device
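
A brief usage sketch for the reworked helpers; behavior is inferred from the hunks above, and the CUDA branch of resolve_device lies outside the shown context.

import torch

from nextrec.utils.device import configure_device, get_device_info, resolve_device

print(resolve_device())   # "mps" on macOS 14+, else "cpu" (CUDA branch not shown in this hunk)
print(get_device_info())  # cuda_available / cuda_device_count / mps_available / current_device

# Under a distributed launch each worker pins its own GPU; any failure
# (invalid base_device, missing CUDA, bad local_rank) falls back to CPU.
device = configure_device(distributed=True, local_rank=0, base_device="cuda")
x = torch.zeros(2, 3, device=device)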
nextrec/utils/distributed.py
CHANGED
@@ -15,10 +15,13 @@ from torch.utils.data import DataLoader, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from nextrec.basic.loggers import colorize

-def init_process_group(distributed: bool, rank: int, world_size: int, device_id: int | None = None) -> None:
+
+def init_process_group(
+    distributed: bool, rank: int, world_size: int, device_id: int | None = None
+) -> None:
     """
     initialize distributed process group for multi-GPU training.
-
+
     Args:
         distributed: whether to enable distributed training
         rank: global rank of the current process
@@ -29,7 +32,10 @@ def init_process_group(distributed: bool, rank: int, world_size: int, device_id:
     backend = "nccl" if device_id is not None else "gloo"
     if backend == "nccl":
         torch.cuda.set_device(device_id)
-    dist.init_process_group(
+    dist.init_process_group(
+        backend=backend, init_method="env://", rank=rank, world_size=world_size
+    )
+

 def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
     """
@@ -53,6 +59,7 @@ def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
         return None
     return np.concatenate(pieces, axis=0)

+
 def add_distributed_sampler(
     loader: DataLoader,
     distributed: bool,
@@ -64,7 +71,7 @@
     is_main_process: bool = False,
 ) -> tuple[DataLoader, DistributedSampler | None]:
     """
-    add distributedsampler to a dataloader, this for distributed training
+    add distributedsampler to a dataloader, this for distributed training
     when each device has its own dataloader
     """
     # early return if not distributed
@@ -78,11 +85,24 @@
         return loader, None
     if isinstance(dataset, IterableDataset):
         if is_main_process:
-            logging.info(
+            logging.info(
+                colorize(
+                    "[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.",
+                    color="yellow",
+                )
+            )
         return loader, None
-    sampler = DistributedSampler(
+    sampler = DistributedSampler(
+        dataset,
+        num_replicas=world_size,
+        rank=rank,
+        shuffle=shuffle,
+        drop_last=drop_last,
+    )
     loader_kwargs = {
-        "batch_size":
+        "batch_size": (
+            loader.batch_size if loader.batch_size is not None else default_batch_size
+        ),
         "shuffle": False,
         "sampler": sampler,
         "num_workers": loader.num_workers,
@@ -104,11 +124,18 @@
     if generator is not None:
         loader_kwargs["generator"] = generator
     if loader.num_workers > 0:
-        loader_kwargs["persistent_workers"] = getattr(
+        loader_kwargs["persistent_workers"] = getattr(
+            loader, "persistent_workers", False
+        )
     prefetch_factor = getattr(loader, "prefetch_factor", None)
     if prefetch_factor is not None:
         loader_kwargs["prefetch_factor"] = prefetch_factor
     distributed_loader = DataLoader(dataset, **loader_kwargs)
     if is_main_process:
-        logging.info(
+        logging.info(
+            colorize(
+                "[Distributed Info] Attached DistributedSampler to provided DataLoader",
+                color="cyan",
+            )
+        )
     return distributed_loader, sampler
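
Putting the two distributed helpers together, here is a sketch of per-rank setup. The middle parameters of add_distributed_sampler are not visible in these hunks, so the keyword names below (world_size, rank, shuffle, drop_last) are inferred from the function body, and the env:// init assumes torchrun-style RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT environment variables.

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from nextrec.utils.distributed import add_distributed_sampler, init_process_group

rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

# gloo backend is chosen automatically when device_id is None (CPU-only run)
init_process_group(distributed=True, rank=rank, world_size=world_size)

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
loader = DataLoader(dataset, batch_size=8, shuffle=True)
loader, sampler = add_distributed_sampler(
    loader,
    distributed=True,
    world_size=world_size,  # keyword names inferred from the function body
    rank=rank,
    shuffle=True,
    drop_last=False,
    is_main_process=(rank == 0),
)

for epoch in range(2):
    if sampler is not None:
        sampler.set_epoch(epoch)  # standard DistributedSampler re-shuffle per epoch
    for (batch,) in loader:
        pass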