nextrec 0.4.24-py3-none-any.whl → 0.4.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +191 -71
  5. nextrec/basic/summary.py +58 -0
  6. nextrec/cli.py +13 -0
  7. nextrec/data/data_processing.py +3 -9
  8. nextrec/data/dataloader.py +25 -2
  9. nextrec/data/preprocessor.py +283 -36
  10. nextrec/models/multi_task/[pre]aitm.py +173 -0
  11. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  12. nextrec/models/multi_task/[pre]star.py +192 -0
  13. nextrec/models/multi_task/apg.py +330 -0
  14. nextrec/models/multi_task/cross_stitch.py +229 -0
  15. nextrec/models/multi_task/escm.py +290 -0
  16. nextrec/models/multi_task/esmm.py +8 -21
  17. nextrec/models/multi_task/hmoe.py +203 -0
  18. nextrec/models/multi_task/mmoe.py +20 -28
  19. nextrec/models/multi_task/pepnet.py +68 -66
  20. nextrec/models/multi_task/ple.py +30 -44
  21. nextrec/models/multi_task/poso.py +13 -22
  22. nextrec/models/multi_task/share_bottom.py +14 -25
  23. nextrec/models/ranking/afm.py +2 -2
  24. nextrec/models/ranking/autoint.py +2 -4
  25. nextrec/models/ranking/dcn.py +2 -3
  26. nextrec/models/ranking/dcn_v2.py +2 -3
  27. nextrec/models/ranking/deepfm.py +2 -3
  28. nextrec/models/ranking/dien.py +7 -9
  29. nextrec/models/ranking/din.py +8 -10
  30. nextrec/models/ranking/eulernet.py +1 -2
  31. nextrec/models/ranking/ffm.py +1 -2
  32. nextrec/models/ranking/fibinet.py +2 -3
  33. nextrec/models/ranking/fm.py +1 -1
  34. nextrec/models/ranking/lr.py +1 -1
  35. nextrec/models/ranking/masknet.py +1 -2
  36. nextrec/models/ranking/pnn.py +1 -2
  37. nextrec/models/ranking/widedeep.py +2 -3
  38. nextrec/models/ranking/xdeepfm.py +2 -4
  39. nextrec/models/representation/rqvae.py +4 -4
  40. nextrec/models/retrieval/dssm.py +18 -26
  41. nextrec/models/retrieval/dssm_v2.py +15 -22
  42. nextrec/models/retrieval/mind.py +9 -15
  43. nextrec/models/retrieval/sdm.py +36 -33
  44. nextrec/models/retrieval/youtube_dnn.py +16 -24
  45. nextrec/models/sequential/hstu.py +2 -2
  46. nextrec/utils/__init__.py +5 -1
  47. nextrec/utils/config.py +2 -0
  48. nextrec/utils/model.py +16 -77
  49. nextrec/utils/torch_utils.py +11 -0
  50. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
  51. nextrec-0.4.27.dist-info/RECORD +90 -0
  52. nextrec/models/multi_task/aitm.py +0 -0
  53. nextrec/models/multi_task/snr_trans.py +0 -0
  54. nextrec-0.4.24.dist-info/RECORD +0 -86
  55. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
  56. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
  57. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/summary.py CHANGED
@@ -48,6 +48,27 @@ class SummarySet:
     checkpoint_path: str
     train_data_summary: dict[str, Any] | None
     valid_data_summary: dict[str, Any] | None
+    note: str | None
+
+    def collect_dataloader_summary(self, data_loader: DataLoader | None):
+        if data_loader is None:
+            return None
+
+        summary = {
+            "batch_size": data_loader.batch_size,
+            "num_workers": data_loader.num_workers,
+            "pin_memory": data_loader.pin_memory,
+            "persistent_workers": data_loader.persistent_workers,
+        }
+        prefetch_factor = getattr(data_loader, "prefetch_factor", None)
+        if prefetch_factor is not None:
+            summary["prefetch_factor"] = prefetch_factor
+
+        sampler = getattr(data_loader, "sampler", None)
+        if sampler is not None:
+            summary["sampler"] = sampler.__class__.__name__
+
+        return summary or None
 
     def build_data_summary(
         self, data: Any, data_loader: DataLoader | None, sample_key: str
@@ -66,6 +87,10 @@ class SummarySet:
         if train_size is not None:
             summary[sample_key] = int(train_size)
 
+        dataloader_summary = self.collect_dataloader_summary(data_loader)
+        if dataloader_summary:
+            summary["dataloader"] = dataloader_summary
+
         if labels:
             task_types = list(self.task) if isinstance(self.task, list) else [self.task]
             if len(task_types) != len(self.target_columns):
@@ -321,6 +346,7 @@ class SummarySet:
         logger.info(f" Session ID: {self.session_id}")
         logger.info(f" Features Config Path: {self.features_config_path}")
         logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
+        logger.info(f" Note: {self.note}")
 
         if "Data Summary" in selected_sections and (
             self.train_data_summary or self.valid_data_summary
@@ -341,6 +367,22 @@ class SummarySet:
                 for label, value in lines:
                     logger.info(f" {format_kv(label, value)}")
 
+            dataloader_info = self.train_data_summary.get("dataloader")
+            if isinstance(dataloader_info, dict):
+                logger.info("Train DataLoader:")
+                for key in (
+                    "batch_size",
+                    "num_workers",
+                    "pin_memory",
+                    "persistent_workers",
+                    "sampler",
+                ):
+                    if key in dataloader_info:
+                        label = key.replace("_", " ").title()
+                        logger.info(
+                            format_kv(label, dataloader_info[key], indent=2)
+                        )
+
         if self.valid_data_summary:
             if self.train_data_summary:
                 logger.info("")
@@ -355,3 +397,19 @@ class SummarySet:
                 logger.info(f"{target_name}:")
                 for label, value in lines:
                     logger.info(f" {format_kv(label, value)}")
+
+            dataloader_info = self.valid_data_summary.get("dataloader")
+            if isinstance(dataloader_info, dict):
+                logger.info("Valid DataLoader:")
+                for key in (
+                    "batch_size",
+                    "num_workers",
+                    "pin_memory",
+                    "persistent_workers",
+                    "sampler",
+                ):
+                    if key in dataloader_info:
+                        label = key.replace("_", " ").title()
+                        logger.info(
+                            format_kv(label, dataloader_info[key], indent=2)
+                        )
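
In effect, these hunks add a free-form `note` field to the summary and teach `SummarySet` to record and print DataLoader settings next to the data statistics. A minimal sketch of what `collect_dataloader_summary` gathers for an ordinary in-memory loader (the toy dataset and loader settings here are assumptions for illustration, not part of the package):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy loader; any loader built by RecDataLoader would be inspected the same way.
loader = DataLoader(
    TensorDataset(torch.zeros(1024, 8)),
    batch_size=256,
    num_workers=2,
    pin_memory=False,
    persistent_workers=True,
    prefetch_factor=4,
)

# Mirrors the logic added in collect_dataloader_summary above.
summary = {
    "batch_size": loader.batch_size,
    "num_workers": loader.num_workers,
    "pin_memory": loader.pin_memory,
    "persistent_workers": loader.persistent_workers,
}
if getattr(loader, "prefetch_factor", None) is not None:
    summary["prefetch_factor"] = loader.prefetch_factor
if getattr(loader, "sampler", None) is not None:
    summary["sampler"] = loader.sampler.__class__.__name__

print(summary)
# e.g. {'batch_size': 256, 'num_workers': 2, 'pin_memory': False,
#       'persistent_workers': True, 'prefetch_factor': 4, 'sampler': 'SequentialSampler'}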
nextrec/cli.py CHANGED
@@ -320,6 +320,7 @@ def train_model(train_config_path: str) -> None:
             streaming=True,
             chunk_size=dataloader_chunk_size,
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
         valid_loader = None
         if val_data_path:
@@ -331,6 +332,7 @@ def train_model(train_config_path: str) -> None:
                 streaming=True,
                 chunk_size=dataloader_chunk_size,
                 num_workers=dataloader_cfg.get("num_workers", 0),
+                prefetch_factor=dataloader_cfg.get("prefetch_factor"),
             )
         elif streaming_valid_files:
             valid_loader = dataloader.create_dataloader(
@@ -340,6 +342,7 @@ def train_model(train_config_path: str) -> None:
                 streaming=True,
                 chunk_size=dataloader_chunk_size,
                 num_workers=dataloader_cfg.get("num_workers", 0),
+                prefetch_factor=dataloader_cfg.get("prefetch_factor"),
             )
     else:
         train_loader = dataloader.create_dataloader(
@@ -347,12 +350,14 @@ def train_model(train_config_path: str) -> None:
             batch_size=dataloader_cfg.get("train_batch_size", 512),
             shuffle=dataloader_cfg.get("train_shuffle", True),
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
         valid_loader = dataloader.create_dataloader(
             data=valid_data,
             batch_size=dataloader_cfg.get("valid_batch_size", 512),
             shuffle=dataloader_cfg.get("valid_shuffle", False),
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
 
     model_cfg.setdefault("session_id", session_id)
@@ -383,6 +388,7 @@ def train_model(train_config_path: str) -> None:
         loss=train_cfg.get("loss", "focal"),
         loss_params=train_cfg.get("loss_params", {}),
         loss_weights=train_cfg.get("loss_weights"),
+        ignore_label=train_cfg.get("ignore_label", -1),
     )
 
     model.fit(
@@ -397,6 +403,12 @@ def train_model(train_config_path: str) -> None:
         num_workers=dataloader_cfg.get("num_workers", 0),
         user_id_column=id_column,
         use_tensorboard=False,
+        use_wandb=train_cfg.get("use_wandb", False),
+        use_swanlab=train_cfg.get("use_swanlab", False),
+        wandb_api=train_cfg.get("wandb_api"),
+        swanlab_api=train_cfg.get("swanlab_api"),
+        log_interval=train_cfg.get("log_interval", 1),
+        note=train_cfg.get("note"),
     )
 
 
@@ -583,6 +595,7 @@ def predict_model(predict_config_path: str) -> None:
         shuffle=False,
         streaming=predict_cfg.get("streaming", True),
         chunk_size=predict_cfg.get("chunk_size", 20000),
+        prefetch_factor=predict_cfg.get("prefetch_factor"),
     )
 
     save_format = predict_cfg.get(
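
The CLI now threads several new config keys through to model construction and `fit`. A hedged sketch of the corresponding config fragment, expressed as Python dicts (the key names come from the `get` calls above; the values, comments, and on-disk format, likely YAML, are assumptions):

# Hypothetical train-config fragment matching the cfg.get(...) calls above.
train_cfg = {
    "loss": "focal",
    "ignore_label": -1,      # presumably samples with this label are excluded from the loss
    "use_wandb": False,      # opt-in Weights & Biases logging
    "use_swanlab": False,    # opt-in SwanLab logging
    "wandb_api": None,       # API key, used when use_wandb is True
    "swanlab_api": None,
    "log_interval": 1,       # presumably log every N steps/epochs
    "note": "baseline run",  # free-form note surfaced in the training summary
}
dataloader_cfg = {
    "train_batch_size": 512,
    "num_workers": 4,
    "prefetch_factor": 2,    # forwarded to DataLoader; only applied when num_workers > 0
}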
nextrec/data/data_processing.py CHANGED
@@ -13,6 +13,8 @@ import numpy as np
 import pandas as pd
 import torch
 
+from nextrec.utils.torch_utils import to_numpy
+
 
 def get_column_data(data: dict | pd.DataFrame, name: str):
 
@@ -23,15 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
             return None
         return data[name].values
     else:
-        if hasattr(data, name):
-            return getattr(data, name)
-        raise KeyError(f"Unsupported data type for extracting column {name}")
-
-
-def to_numpy(values: Any) -> np.ndarray:
-    if isinstance(values, torch.Tensor):
-        return values.detach().cpu().numpy()
-    return np.asarray(values)
+        raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
 
 
 def get_data_length(data: Any) -> int | None:
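
The deleted `to_numpy` is a move, not a removal: it is now imported from `nextrec.utils.torch_utils` (which gains 11 lines in this release, per entry 49 in the file list). Its body, reproduced from the lines deleted above with imports added for a self-contained sketch:

from typing import Any

import numpy as np
import torch

def to_numpy(values: Any) -> np.ndarray:
    # Tensors are detached and copied to CPU before conversion.
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)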
nextrec/data/dataloader.py CHANGED
@@ -194,6 +194,7 @@ class RecDataLoader(FeatureSet):
         streaming: bool = False,
         chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
     ) -> DataLoader:
         """
@@ -206,6 +207,7 @@ class RecDataLoader(FeatureSet):
             streaming: If True, use streaming mode for large files; if False, load full data into memory.
             chunk_size: Chunk size for streaming mode (number of rows per chunk).
             num_workers: Number of worker processes for data loading.
+            prefetch_factor: Number of batches loaded in advance by each worker.
             sampler: Optional sampler for DataLoader, only used for distributed training.
         Returns:
             DataLoader instance.
@@ -234,6 +236,7 @@ class RecDataLoader(FeatureSet):
                 streaming=streaming,
                 chunk_size=chunk_size,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
             )
 
         if isinstance(data, (dict, pd.DataFrame)):
@@ -242,6 +245,7 @@ class RecDataLoader(FeatureSet):
                 batch_size=batch_size,
                 shuffle=shuffle,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
                 sampler=sampler,
             )
 
@@ -253,6 +257,7 @@ class RecDataLoader(FeatureSet):
         batch_size: int,
         shuffle: bool,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
     ) -> DataLoader:
         raw_data = data
@@ -275,6 +280,9 @@ class RecDataLoader(FeatureSet):
                 "[RecDataLoader Error] No valid tensors could be built from the provided data."
             )
         dataset = TensorDictDataset(tensors)
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
         return DataLoader(
             dataset,
             batch_size=batch_size,
@@ -284,6 +292,7 @@ class RecDataLoader(FeatureSet):
             num_workers=num_workers,
             pin_memory=torch.cuda.is_available(),
             persistent_workers=num_workers > 0,
+            **loader_kwargs,
         )
 
     def create_from_path(
@@ -294,6 +303,7 @@ class RecDataLoader(FeatureSet):
         streaming: bool,
         chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
     ) -> DataLoader:
         if isinstance(path, (str, os.PathLike)):
             file_paths, file_type = resolve_file_paths(str(Path(path)))
@@ -327,6 +337,7 @@ class RecDataLoader(FeatureSet):
                 chunk_size,
                 shuffle,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
             )
 
         dfs = []
@@ -350,7 +361,11 @@ class RecDataLoader(FeatureSet):
                 f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use streaming=True or reduce chunk_size."
             ) from exc
         return self.create_from_memory(
-            combined_df, batch_size, shuffle, num_workers=num_workers
+            combined_df,
+            batch_size,
+            shuffle,
+            num_workers=num_workers,
+            prefetch_factor=prefetch_factor,
         )
 
     def load_files_streaming(
@@ -361,6 +376,7 @@ class RecDataLoader(FeatureSet):
         chunk_size: int,
         shuffle: bool,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
    ) -> DataLoader:
        if not check_streaming_support(file_type):
            raise ValueError(
@@ -393,8 +409,15 @@ class RecDataLoader(FeatureSet):
             file_type=file_type,
             processor=self.processor,
         )
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
         return DataLoader(
-            dataset, batch_size=1, collate_fn=collate_fn, num_workers=num_workers
+            dataset,
+            batch_size=1,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+            **loader_kwargs,
         )
 
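Both `loader_kwargs` blocks guard on `num_workers > 0` because `torch.utils.data.DataLoader` rejects an explicit `prefetch_factor` in single-process mode. A minimal sketch of the failure the guard avoids (toy dataset assumed; the exact error message varies by PyTorch version):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(16, 4))

try:
    # ValueError: prefetch_factor requires num_workers > 0
    DataLoader(dataset, batch_size=4, num_workers=0, prefetch_factor=2)
except ValueError as exc:
    print(exc)

# With workers present, prefetching applies normally.
loader = DataLoader(dataset, batch_size=4, num_workers=2, prefetch_factor=2)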