nextrec-0.4.24-py3-none-any.whl → nextrec-0.4.27-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +191 -71
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +68 -66
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +16 -77
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
- nextrec-0.4.27.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.24.dist-info/RECORD +0 -86
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/summary.py
CHANGED

@@ -48,6 +48,27 @@ class SummarySet:
     checkpoint_path: str
     train_data_summary: dict[str, Any] | None
     valid_data_summary: dict[str, Any] | None
+    note: str | None
+
+    def collect_dataloader_summary(self, data_loader: DataLoader | None):
+        if data_loader is None:
+            return None
+
+        summary = {
+            "batch_size": data_loader.batch_size,
+            "num_workers": data_loader.num_workers,
+            "pin_memory": data_loader.pin_memory,
+            "persistent_workers": data_loader.persistent_workers,
+        }
+        prefetch_factor = getattr(data_loader, "prefetch_factor", None)
+        if prefetch_factor is not None:
+            summary["prefetch_factor"] = prefetch_factor
+
+        sampler = getattr(data_loader, "sampler", None)
+        if sampler is not None:
+            summary["sampler"] = sampler.__class__.__name__
+
+        return summary or None
 
     def build_data_summary(
         self, data: Any, data_loader: DataLoader | None, sample_key: str
@@ -66,6 +87,10 @@ class SummarySet:
         if train_size is not None:
             summary[sample_key] = int(train_size)
 
+        dataloader_summary = self.collect_dataloader_summary(data_loader)
+        if dataloader_summary:
+            summary["dataloader"] = dataloader_summary
+
         if labels:
             task_types = list(self.task) if isinstance(self.task, list) else [self.task]
             if len(task_types) != len(self.target_columns):
@@ -321,6 +346,7 @@ class SummarySet:
         logger.info(f"  Session ID: {self.session_id}")
         logger.info(f"  Features Config Path: {self.features_config_path}")
         logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")
+        logger.info(f"  Note: {self.note}")
 
         if "Data Summary" in selected_sections and (
             self.train_data_summary or self.valid_data_summary
@@ -341,6 +367,22 @@ class SummarySet:
                 for label, value in lines:
                     logger.info(f"  {format_kv(label, value)}")
 
+            dataloader_info = self.train_data_summary.get("dataloader")
+            if isinstance(dataloader_info, dict):
+                logger.info("Train DataLoader:")
+                for key in (
+                    "batch_size",
+                    "num_workers",
+                    "pin_memory",
+                    "persistent_workers",
+                    "sampler",
+                ):
+                    if key in dataloader_info:
+                        label = key.replace("_", " ").title()
+                        logger.info(
+                            format_kv(label, dataloader_info[key], indent=2)
+                        )
+
         if self.valid_data_summary:
             if self.train_data_summary:
                 logger.info("")
@@ -355,3 +397,19 @@ class SummarySet:
                 logger.info(f"{target_name}:")
                 for label, value in lines:
                     logger.info(f"  {format_kv(label, value)}")
+
+            dataloader_info = self.valid_data_summary.get("dataloader")
+            if isinstance(dataloader_info, dict):
+                logger.info("Valid DataLoader:")
+                for key in (
+                    "batch_size",
+                    "num_workers",
+                    "pin_memory",
+                    "persistent_workers",
+                    "sampler",
+                ):
+                    if key in dataloader_info:
+                        label = key.replace("_", " ").title()
+                        logger.info(
+                            format_kv(label, dataloader_info[key], indent=2)
+                        )
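
The new `collect_dataloader_summary` hook reads only attributes that every `torch.utils.data.DataLoader` exposes, so the dict it builds can be previewed outside `SummarySet`. A minimal standalone sketch (same fields as the method above; the `SummarySet` wiring is elided):

    # Sketch: the dict collect_dataloader_summary builds, reproduced standalone.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(
        TensorDataset(torch.zeros(100, 4)),
        batch_size=32,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )

    summary = {
        "batch_size": loader.batch_size,
        "num_workers": loader.num_workers,
        "pin_memory": loader.pin_memory,
        "persistent_workers": loader.persistent_workers,
    }
    prefetch_factor = getattr(loader, "prefetch_factor", None)
    if prefetch_factor is not None:
        summary["prefetch_factor"] = prefetch_factor
    sampler = getattr(loader, "sampler", None)
    if sampler is not None:
        summary["sampler"] = sampler.__class__.__name__

    print(summary)
    # {'batch_size': 32, 'num_workers': 2, 'pin_memory': True,
    #  'persistent_workers': True, 'prefetch_factor': 4, 'sampler': 'SequentialSampler'}

`build_data_summary` nests this dict under the "dataloader" key, which is what the new Train DataLoader / Valid DataLoader log sections iterate over.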
nextrec/cli.py
CHANGED

@@ -320,6 +320,7 @@ def train_model(train_config_path: str) -> None:
             streaming=True,
             chunk_size=dataloader_chunk_size,
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
         valid_loader = None
         if val_data_path:
@@ -331,6 +332,7 @@
                 streaming=True,
                 chunk_size=dataloader_chunk_size,
                 num_workers=dataloader_cfg.get("num_workers", 0),
+                prefetch_factor=dataloader_cfg.get("prefetch_factor"),
             )
         elif streaming_valid_files:
             valid_loader = dataloader.create_dataloader(
@@ -340,6 +342,7 @@
                 streaming=True,
                 chunk_size=dataloader_chunk_size,
                 num_workers=dataloader_cfg.get("num_workers", 0),
+                prefetch_factor=dataloader_cfg.get("prefetch_factor"),
             )
     else:
         train_loader = dataloader.create_dataloader(
@@ -347,12 +350,14 @@
             batch_size=dataloader_cfg.get("train_batch_size", 512),
             shuffle=dataloader_cfg.get("train_shuffle", True),
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
         valid_loader = dataloader.create_dataloader(
             data=valid_data,
             batch_size=dataloader_cfg.get("valid_batch_size", 512),
             shuffle=dataloader_cfg.get("valid_shuffle", False),
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
 
     model_cfg.setdefault("session_id", session_id)
@@ -383,6 +388,7 @@
         loss=train_cfg.get("loss", "focal"),
         loss_params=train_cfg.get("loss_params", {}),
         loss_weights=train_cfg.get("loss_weights"),
+        ignore_label=train_cfg.get("ignore_label", -1),
     )
 
     model.fit(
@@ -397,6 +403,12 @@
         num_workers=dataloader_cfg.get("num_workers", 0),
         user_id_column=id_column,
         use_tensorboard=False,
+        use_wandb=train_cfg.get("use_wandb", False),
+        use_swanlab=train_cfg.get("use_swanlab", False),
+        wandb_api=train_cfg.get("wandb_api"),
+        swanlab_api=train_cfg.get("swanlab_api"),
+        log_interval=train_cfg.get("log_interval", 1),
+        note=train_cfg.get("note"),
     )
 
 
@@ -583,6 +595,7 @@ def predict_model(predict_config_path: str) -> None:
         shuffle=False,
         streaming=predict_cfg.get("streaming", True),
         chunk_size=predict_cfg.get("chunk_size", 20000),
+        prefetch_factor=predict_cfg.get("prefetch_factor"),
     )
 
     save_format = predict_cfg.get(
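
Taken together, the cli.py changes surface several new optional config keys, all read with `.get` so older configs keep working. A hedged sketch of the parsed-config fragments involved, shown as the Python dicts the CLI consumes; defaults match the calls above, and the comments are inferred from the key names, not confirmed by this diff:

    # Hypothetical parsed train config, showing only keys this release starts reading.
    train_cfg = {
        "ignore_label": -1,    # presumably marks label values to skip in the loss
        "use_wandb": False,    # experiment-tracking toggles forwarded to fit()
        "use_swanlab": False,
        "wandb_api": None,     # API keys for the trackers
        "swanlab_api": None,
        "log_interval": 1,     # logging frequency forwarded to fit()
        "note": None,          # free-form note, echoed by the summary's new Note line
    }
    dataloader_cfg = {
        "num_workers": 4,
        "prefetch_factor": 2,  # now forwarded to every create_dataloader call
    }
    predict_cfg = {
        "prefetch_factor": 2,  # same knob on the predict path
    }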
nextrec/data/data_processing.py
CHANGED

@@ -13,6 +13,8 @@ import numpy as np
 import pandas as pd
 import torch
 
+from nextrec.utils.torch_utils import to_numpy
+
 
 def get_column_data(data: dict | pd.DataFrame, name: str):
 
@@ -23,15 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
             return None
         return data[name].values
     else:
-
-        return getattr(data, name)
-    raise KeyError(f"Unsupported data type for extracting column {name}")
-
-
-def to_numpy(values: Any) -> np.ndarray:
-    if isinstance(values, torch.Tensor):
-        return values.detach().cpu().numpy()
-    return np.asarray(values)
+        raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
 
 
 def get_data_length(data: Any) -> int | None:
nextrec/data/dataloader.py
CHANGED

@@ -194,6 +194,7 @@ class RecDataLoader(FeatureSet):
         streaming: bool = False,
         chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
     ) -> DataLoader:
         """
@@ -206,6 +207,7 @@ class RecDataLoader(FeatureSet):
             streaming: If True, use streaming mode for large files; if False, load full data into memory.
             chunk_size: Chunk size for streaming mode (number of rows per chunk).
             num_workers: Number of worker processes for data loading.
+            prefetch_factor: Number of batches loaded in advance by each worker.
             sampler: Optional sampler for DataLoader, only used for distributed training.
         Returns:
             DataLoader instance.
@@ -234,6 +236,7 @@ class RecDataLoader(FeatureSet):
                 streaming=streaming,
                 chunk_size=chunk_size,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
             )
 
         if isinstance(data, (dict, pd.DataFrame)):
@@ -242,6 +245,7 @@ class RecDataLoader(FeatureSet):
                 batch_size=batch_size,
                 shuffle=shuffle,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
                 sampler=sampler,
             )
 
@@ -253,6 +257,7 @@ class RecDataLoader(FeatureSet):
         batch_size: int,
         shuffle: bool,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
     ) -> DataLoader:
         raw_data = data
@@ -275,6 +280,9 @@ class RecDataLoader(FeatureSet):
                 "[RecDataLoader Error] No valid tensors could be built from the provided data."
             )
         dataset = TensorDictDataset(tensors)
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
         return DataLoader(
             dataset,
             batch_size=batch_size,
@@ -284,6 +292,7 @@ class RecDataLoader(FeatureSet):
             num_workers=num_workers,
             pin_memory=torch.cuda.is_available(),
             persistent_workers=num_workers > 0,
+            **loader_kwargs,
         )
 
     def create_from_path(
@@ -294,6 +303,7 @@ class RecDataLoader(FeatureSet):
         streaming: bool,
         chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
     ) -> DataLoader:
         if isinstance(path, (str, os.PathLike)):
             file_paths, file_type = resolve_file_paths(str(Path(path)))
@@ -327,6 +337,7 @@ class RecDataLoader(FeatureSet):
                 chunk_size,
                 shuffle,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
             )
 
         dfs = []
@@ -350,7 +361,11 @@ class RecDataLoader(FeatureSet):
                 f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use streaming=True or reduce chunk_size."
             ) from exc
         return self.create_from_memory(
-            combined_df,
+            combined_df,
+            batch_size,
+            shuffle,
+            num_workers=num_workers,
+            prefetch_factor=prefetch_factor,
         )
 
     def load_files_streaming(
@@ -361,6 +376,7 @@ class RecDataLoader(FeatureSet):
         chunk_size: int,
         shuffle: bool,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
     ) -> DataLoader:
         if not check_streaming_support(file_type):
             raise ValueError(
@@ -393,8 +409,15 @@ class RecDataLoader(FeatureSet):
             file_type=file_type,
             processor=self.processor,
         )
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
         return DataLoader(
-            dataset,
+            dataset,
+            batch_size=1,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+            **loader_kwargs,
         )
 
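
The `loader_kwargs` indirection in both `create_from_memory` and `load_files_streaming` exists because `prefetch_factor` is only legal with worker processes: PyTorch 2.x raises a `ValueError` when a non-None `prefetch_factor` is passed with `num_workers=0`. A minimal sketch of the constraint and the guard the diff adds:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.zeros(8, 2))

    # PyTorch 2.x rejects this: prefetching needs worker processes.
    try:
        DataLoader(dataset, batch_size=4, num_workers=0, prefetch_factor=2)
    except ValueError as exc:
        print(exc)  # "prefetch_factor option could only be specified in multiprocessing..."

    # The guard used in the diff: forward prefetch_factor only when it applies.
    num_workers, prefetch_factor = 0, 2
    loader_kwargs = {}
    if num_workers > 0 and prefetch_factor is not None:
        loader_kwargs["prefetch_factor"] = prefetch_factor
    loader = DataLoader(dataset, batch_size=4, num_workers=num_workers, **loader_kwargs)

This keeps `prefetch_factor=None` (the new default throughout) a true no-op, so callers that never set the option see identical behavior.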