nextrec 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66):
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +312 -318
  8. nextrec/cli.py +5 -10
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +12 -13
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +2 -2
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/ffm.py +0 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +0 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +8 -7
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/METADATA +6 -6
  54. nextrec-0.4.9.dist-info/RECORD +70 -0
  55. nextrec/utils/cli_utils.py +0 -58
  56. nextrec/utils/device.py +0 -78
  57. nextrec/utils/distributed.py +0 -141
  58. nextrec/utils/file.py +0 -92
  59. nextrec/utils/initializer.py +0 -79
  60. nextrec/utils/optimizer.py +0 -75
  61. nextrec/utils/tensor.py +0 -72
  62. nextrec-0.4.8.dist-info/RECORD +0 -71
  63. /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
  64. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
  65. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
  66. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,70 @@
1
+ nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
2
+ nextrec/__version__.py,sha256=LdxLMJM_JXsCQBeSvnxCNyGWmINE0yWfna3DQaT41Vs,22
3
+ nextrec/cli.py,sha256=vumtNQww-FXgGa6I90IhEjngL1Y-e0GSuSK442s4M40,19209
4
+ nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ nextrec/basic/activation.py,sha256=uzTWfCOtBSkbu_Gk9XBNTj8__s241CaYLJk6l8nGX9I,2885
6
+ nextrec/basic/callback.py,sha256=a6gg7r3x1v0xaOSya9PteLql7I14nepY7gX8tDYtins,14679
7
+ nextrec/basic/features.py,sha256=Wnbzr7UMotgv1Vzeg0o9Po-KKIvYUSYIghoVDfMPx_g,4340
8
+ nextrec/basic/layers.py,sha256=GJH2Tx3IkZrYGb7-ET976iHCC28Ubck_NO9-iyY4mDI,28911
9
+ nextrec/basic/loggers.py,sha256=JnQiFvmsVgZ63gqBLR2ZFWrVPzkxRbzWhTdeoiJKcos,6526
10
+ nextrec/basic/metrics.py,sha256=8RswR_3MGvIBkT_n6fnmON2eYH-hfD7kIKVnyJJjL3o,23131
11
+ nextrec/basic/model.py,sha256=bQPRRIOGB_0PNJv3Zr-8mruUMBC7N2dmPWy7r3L45M4,98649
12
+ nextrec/basic/session.py,sha256=UOG_-EgCOxvqZwCkiEd8sgNV2G1sm_HbzKYVQw8yYDI,4483
13
+ nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
14
+ nextrec/data/batch_utils.py,sha256=0bYGVX7RlhnHv_ZBaUngjDIpBNw-igCk98DgOsF7T6o,2879
15
+ nextrec/data/data_processing.py,sha256=lKXDBszrO5fJMAQetgSPr2mSQuzOluuz1eHV4jp0TDU,5538
16
+ nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
17
+ nextrec/data/dataloader.py,sha256=D2QZDxc9Ic7zkSoaJQBxrjmbHWyJ8d3k0QF3IqLZxfk,18793
18
+ nextrec/data/preprocessor.py,sha256=wNjivq2N-iDzBropkp3YfSkN0jSA4l4h81C-ECa6k4c,44643
19
+ nextrec/loss/__init__.py,sha256=-sibZK8QXLblVNWqdqjrPPzMCDyIXSq7yd2eZ57p9Nw,810
20
+ nextrec/loss/listwise.py,sha256=UT9vJCOTOQLogVwaeTV7Z5uxIYnngGdxk-p9e97MGkU,5744
21
+ nextrec/loss/loss_utils.py,sha256=Eg_EKm47onSCLhgs2q7IkB7TV9TwV1Dz4QgVR2yh-gc,4610
22
+ nextrec/loss/pairwise.py,sha256=X9yg-8pcPt2IWU0AiUhWAt3_4W_3wIF0uSdDYTdoPFY,3398
23
+ nextrec/loss/pointwise.py,sha256=o9J3OznY0hlbDsUXqn3k-BBzYiuUH5dopz8QBFqS_kQ,7343
24
+ nextrec/models/generative/__init__.py,sha256=QtfMCu0-R7jg_HgjrUTP8ShWzXjQnPCpQjqpGwbMzp4,176
25
+ nextrec/models/generative/hstu.py,sha256=P2Kl7HEL3afwiCApGKQ6UbUNO9eNXXrB10H7iiF8cI0,19735
26
+ nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ nextrec/models/multi_task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
+ nextrec/models/multi_task/esmm.py,sha256=3QQePhSkOcM4t472S4E5xU9_KiLiSwHb9CfdkEgmqqk,6491
29
+ nextrec/models/multi_task/mmoe.py,sha256=uFTbc0MiFBDTCIt8mTW6xs0oyOn1EesIHHZo81HR35k,8583
30
+ nextrec/models/multi_task/ple.py,sha256=z32etizNlTLwwR7CYKxy8u9owAbtiRh492Fje_y64hQ,13016
31
+ nextrec/models/multi_task/poso.py,sha256=foH7XDUz0XN0s0zoyHLuTmrcs3QOT8-x4YGxLX1Lxxg,19016
32
+ nextrec/models/multi_task/share_bottom.py,sha256=rmEnsX3LA3pNsLKfG1ir5WDLdkSY-imO_ASiclirJiA,6519
33
+ nextrec/models/ranking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
+ nextrec/models/ranking/afm.py,sha256=96jGUPL4yTWobMIVBjHpOxl9AtAzCAGR8yw7Sy2JmdQ,10125
35
+ nextrec/models/ranking/autoint.py,sha256=S6Cxnp1q2OErSYqmIix5P-b4qLWR-0dY6TMStuU6WLg,8109
36
+ nextrec/models/ranking/dcn.py,sha256=whkjiKEuadl6oSP-NJdSOCOqvWZGX4EsId9oqlfVpa8,7299
37
+ nextrec/models/ranking/dcn_v2.py,sha256=QnqQbJsrtQp4mtvnBXFUVefKyr4dw-gHNWrCbO26oHw,11163
38
+ nextrec/models/ranking/deepfm.py,sha256=EMAGhPCjJHmxpkoTWaioVgNt2yVB0PzGJpDc8cQlczs,5224
39
+ nextrec/models/ranking/dien.py,sha256=c7Zs85vxhOgKHg5s0QcSLCn1xXCCSD177TMERgM_v8g,18958
40
+ nextrec/models/ranking/din.py,sha256=gdUhuKiKXBNOALbK8fGhlbSeuDT8agcEdNSrC_wveHc,9422
41
+ nextrec/models/ranking/eulernet.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
+ nextrec/models/ranking/ffm.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
43
+ nextrec/models/ranking/fibinet.py,sha256=_eroddVHooJcaGT8MqS4mUrtv5j4pnTmfI3FoAKOZhs,7919
44
+ nextrec/models/ranking/fm.py,sha256=SsrSKK3y4xg5Lv-t3JLnZan55Hzze2AxAiVPuscy0bk,4536
45
+ nextrec/models/ranking/lr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
+ nextrec/models/ranking/masknet.py,sha256=tY1y2lO0iq82oylPN0SBnL5Bikc8weinFXpURyVT1hE,12373
47
+ nextrec/models/ranking/pnn.py,sha256=FcNIFAw5J0ORGSR6L8ZK7NeXlJPpojwe_SpsxMQqCFw,8174
48
+ nextrec/models/ranking/widedeep.py,sha256=-ghKfe_0puvlI9fBQr8lK3gXkfVvslGwP40AJTGqc7w,5077
49
+ nextrec/models/ranking/xdeepfm.py,sha256=FMtl_zYO1Ty_2d9VWRsz6Jo-Xjw8vikpIQPZCDVavVY,8156
50
+ nextrec/models/representation/__init__.py,sha256=O3QHMMXBszwM-mTl7bA3wawNZvDGet-QIv6Ys5GHGJ8,190
51
+ nextrec/models/representation/rqvae.py,sha256=JyZxVY9CibcdBGk97TxjG5O3WQC10_60tHNcP_qtegs,29290
52
+ nextrec/models/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
53
+ nextrec/models/retrieval/dssm.py,sha256=yMTu9msn6WgPqpM_zaJ0Z8R3uAYPzuJY_NS0xjJGSx0,8043
54
+ nextrec/models/retrieval/dssm_v2.py,sha256=IJzCT78Ra1DeSpR85KF1Awqnywu_6kDko3XDPbcup5s,7075
55
+ nextrec/models/retrieval/mind.py,sha256=ktNXwZxhnH1DV08c0gdsbodo4QJQL4UQGeyOqpmGwVM,15090
56
+ nextrec/models/retrieval/sdm.py,sha256=LhkCZSfGhxOxziEkUtjr_hnqcyciJ2qpMoBSFBVW9lQ,10558
57
+ nextrec/models/retrieval/youtube_dnn.py,sha256=xtGPV6_5LeSZBKkrTaU1CmtxlhgYLvZmjpwYaXYIaEA,7403
58
+ nextrec/utils/__init__.py,sha256=5ss2XQq8QZ2Ko5eiQ7oIig5cIZNrYGIptaarYEeO7Fk,2550
59
+ nextrec/utils/config.py,sha256=0HOeMyTlx8g6BZVpXzo2lEOkb-mzNwhbigQuUomsYnY,19934
60
+ nextrec/utils/console.py,sha256=D2Vax9_b7bgvAAOyk-Q2oUhSk1B-OngY5buS9Gb9-I0,11398
61
+ nextrec/utils/data.py,sha256=alruiWZFbmwy3kO12q42VXmtHmXFFjVULpHa43fx_mI,21098
62
+ nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
63
+ nextrec/utils/feature.py,sha256=rsUAv3ELyDpehVw8nPEEsLCCIjuKGTJJZuFaWB_wrPk,633
64
+ nextrec/utils/model.py,sha256=dYl1XfIZt6aVjNyV2AAhcArwFRMcEAKrjG_pr8AVHs0,1163
65
+ nextrec/utils/torch_utils.py,sha256=AKfYbSOJjEw874xsDB5IO3Ote4X7vnqzt_E0jJny0o8,13468
66
+ nextrec-0.4.9.dist-info/METADATA,sha256=FnndUmBNuNT6Odd_8XTlnW71NmqZ_Wr5lFNmxVwha1k,19463
67
+ nextrec-0.4.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
68
+ nextrec-0.4.9.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
69
+ nextrec-0.4.9.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
70
+ nextrec-0.4.9.dist-info/RECORD,,
@@ -1,58 +0,0 @@
1
- """
2
- CLI utilities for NextRec.
3
-
4
- This module provides small helpers used by command-line entrypoints, such as
5
- printing startup banners and resolving the installed package version.
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import logging
11
- import os
12
- import platform
13
- import sys
14
- from datetime import datetime
15
-
16
-
17
- def get_nextrec_version() -> str:
18
- """
19
- Best-effort version resolver for NextRec.
20
-
21
- Prefer in-repo `nextrec.__version__`, fall back to installed package metadata.
22
- """
23
- try:
24
- from nextrec import __version__ # type: ignore
25
-
26
- if __version__:
27
- return str(__version__)
28
- except Exception:
29
- pass
30
-
31
- try:
32
- from importlib.metadata import version
33
-
34
- return version("nextrec")
35
- except Exception:
36
- return "unknown"
37
-
38
-
39
- def log_startup_info(
40
- logger: logging.Logger, *, mode: str, config_path: str | None
41
- ) -> None:
42
- """Log a short, user-friendly startup banner."""
43
- version = get_nextrec_version()
44
- now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
45
-
46
- lines = [
47
- "NextRec CLI",
48
- f"- Version: {version}",
49
- f"- Time: {now}",
50
- f"- Mode: {mode}",
51
- f"- Config: {config_path or '(not set)'}",
52
- f"- Python: {platform.python_version()} ({sys.executable})",
53
- f"- Platform: {platform.system()} {platform.release()} ({platform.machine()})",
54
- f"- Workdir: {os.getcwd()}",
55
- f"- Command: {' '.join(sys.argv)}",
56
- ]
57
- for line in lines:
58
- logger.info(line)
nextrec/utils/device.py DELETED
@@ -1,78 +0,0 @@
1
- """
2
- Device management utilities for NextRec
3
-
4
- Date: create on 03/12/2025
5
- Checkpoint: edit on 06/12/2025
6
- Author: Yang Zhou, zyaztec@gmail.com
7
- """
8
-
9
- import torch
10
- import platform
11
- import logging
12
-
13
-
14
- def resolve_device() -> str:
15
- if torch.cuda.is_available():
16
- return "cuda"
17
- if torch.backends.mps.is_available():
18
- mac_ver = platform.mac_ver()[0]
19
- try:
20
- major, _ = (int(x) for x in mac_ver.split(".")[:2])
21
- except Exception:
22
- major, _ = 0, 0
23
- if major >= 14:
24
- return "mps"
25
- return "cpu"
26
-
27
-
28
- def get_device_info() -> dict:
29
- info = {
30
- "cuda_available": torch.cuda.is_available(),
31
- "cuda_device_count": (
32
- torch.cuda.device_count() if torch.cuda.is_available() else 0
33
- ),
34
- "mps_available": torch.backends.mps.is_available(),
35
- "current_device": resolve_device(),
36
- }
37
-
38
- if torch.cuda.is_available():
39
- info["cuda_device_name"] = torch.cuda.get_device_name(0)
40
- info["cuda_capability"] = torch.cuda.get_device_capability(0)
41
-
42
- return info
43
-
44
-
45
- def configure_device(
46
- distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
47
- ) -> torch.device:
48
- try:
49
- device = torch.device(base_device)
50
- except Exception:
51
- logging.warning(
52
- "[configure_device Warning] Invalid base_device, falling back to CPU."
53
- )
54
- return torch.device("cpu")
55
-
56
- if distributed:
57
- if device.type == "cuda":
58
- if not torch.cuda.is_available():
59
- logging.warning(
60
- "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
61
- )
62
- return torch.device("cpu")
63
- if not (0 <= local_rank < torch.cuda.device_count()):
64
- logging.warning(
65
- f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
66
- )
67
- return torch.device("cpu")
68
- try:
69
- torch.cuda.set_device(local_rank)
70
- return torch.device(f"cuda:{local_rank}")
71
- except Exception as exc:
72
- logging.warning(
73
- f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
74
- )
75
- return torch.device("cpu")
76
- else:
77
- return torch.device("cpu")
78
- return device
@@ -1,141 +0,0 @@
1
- """
2
- Distributed utilities for NextRec.
3
-
4
- Date: create on 04/12/2025
5
- Checkpoint: edit on 05/12/2025
6
- Author: Yang Zhou,zyaztec@gmail.com
7
- """
8
-
9
- import logging
10
- import numpy as np
11
- import torch
12
- import torch.distributed as dist
13
-
14
- from torch.utils.data import DataLoader, IterableDataset
15
- from torch.utils.data.distributed import DistributedSampler
16
- from nextrec.basic.loggers import colorize
17
-
18
-
19
- def init_process_group(
20
- distributed: bool, rank: int, world_size: int, device_id: int | None = None
21
- ) -> None:
22
- """
23
- initialize distributed process group for multi-GPU training.
24
-
25
- Args:
26
- distributed: whether to enable distributed training
27
- rank: global rank of the current process
28
- world_size: total number of processes
29
- """
30
- if (not distributed) or (not dist.is_available()) or dist.is_initialized():
31
- return
32
- backend = "nccl" if device_id is not None else "gloo"
33
- if backend == "nccl":
34
- torch.cuda.set_device(device_id)
35
- dist.init_process_group(
36
- backend=backend, init_method="env://", rank=rank, world_size=world_size
37
- )
38
-
39
-
40
- def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
41
- """
42
- Gather numpy arrays (or None) across ranks. Uses all_gather_object to avoid
43
- shape mismatches and ensures every rank participates even when local data is empty.
44
- """
45
- if not (self.distributed and dist.is_available() and dist.is_initialized()):
46
- return array
47
-
48
- world_size = dist.get_world_size()
49
- gathered: list[np.ndarray | None] = [None for _ in range(world_size)]
50
- dist.all_gather_object(gathered, array)
51
- pieces: list[np.ndarray] = []
52
- for item in gathered:
53
- if item is None:
54
- continue
55
- item_np = np.asarray(item)
56
- if item_np.size > 0:
57
- pieces.append(item_np)
58
- if not pieces:
59
- return None
60
- return np.concatenate(pieces, axis=0)
61
-
62
-
63
- def add_distributed_sampler(
64
- loader: DataLoader,
65
- distributed: bool,
66
- world_size: int,
67
- rank: int,
68
- shuffle: bool,
69
- drop_last: bool,
70
- default_batch_size: int,
71
- is_main_process: bool = False,
72
- ) -> tuple[DataLoader, DistributedSampler | None]:
73
- """
74
- add distributedsampler to a dataloader, this for distributed training
75
- when each device has its own dataloader
76
- """
77
- # early return if not distributed
78
- if not (distributed and dist.is_available() and dist.is_initialized()):
79
- return loader, None
80
- # return if already has DistributedSampler
81
- if isinstance(loader.sampler, DistributedSampler):
82
- return loader, loader.sampler
83
- dataset = getattr(loader, "dataset", None)
84
- if dataset is None:
85
- return loader, None
86
- if isinstance(dataset, IterableDataset):
87
- if is_main_process:
88
- logging.info(
89
- colorize(
90
- "[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.",
91
- color="yellow",
92
- )
93
- )
94
- return loader, None
95
- sampler = DistributedSampler(
96
- dataset,
97
- num_replicas=world_size,
98
- rank=rank,
99
- shuffle=shuffle,
100
- drop_last=drop_last,
101
- )
102
- loader_kwargs = {
103
- "batch_size": (
104
- loader.batch_size if loader.batch_size is not None else default_batch_size
105
- ),
106
- "shuffle": False,
107
- "sampler": sampler,
108
- "num_workers": loader.num_workers,
109
- "collate_fn": loader.collate_fn,
110
- "drop_last": drop_last,
111
- }
112
- if getattr(loader, "pin_memory", False):
113
- loader_kwargs["pin_memory"] = True
114
- pin_memory_device = getattr(loader, "pin_memory_device", None)
115
- if pin_memory_device:
116
- loader_kwargs["pin_memory_device"] = pin_memory_device
117
- timeout = getattr(loader, "timeout", None)
118
- if timeout:
119
- loader_kwargs["timeout"] = timeout
120
- worker_init_fn = getattr(loader, "worker_init_fn", None)
121
- if worker_init_fn is not None:
122
- loader_kwargs["worker_init_fn"] = worker_init_fn
123
- generator = getattr(loader, "generator", None)
124
- if generator is not None:
125
- loader_kwargs["generator"] = generator
126
- if loader.num_workers > 0:
127
- loader_kwargs["persistent_workers"] = getattr(
128
- loader, "persistent_workers", False
129
- )
130
- prefetch_factor = getattr(loader, "prefetch_factor", None)
131
- if prefetch_factor is not None:
132
- loader_kwargs["prefetch_factor"] = prefetch_factor
133
- distributed_loader = DataLoader(dataset, **loader_kwargs)
134
- if is_main_process:
135
- logging.info(
136
- colorize(
137
- "[Distributed Info] Attached DistributedSampler to provided DataLoader",
138
- color="cyan",
139
- )
140
- )
141
- return distributed_loader, sampler
nextrec/utils/file.py DELETED
@@ -1,92 +0,0 @@
1
- """
2
- File I/O utilities for NextRec
3
-
4
- Date: create on 03/12/2025
5
- Checkpoint: edit on 06/12/2025
6
- Author: Yang Zhou, zyaztec@gmail.com
7
- """
8
-
9
- import yaml
10
- import pandas as pd
11
- import pyarrow.parquet as pq
12
-
13
- from pathlib import Path
14
- from typing import Generator
15
-
16
-
17
- def resolve_file_paths(path: str) -> tuple[list[str], str]:
18
- """
19
- Resolve file or directory path into a sorted list of files and file type.
20
-
21
- Args: path: Path to a file or directory
22
- Returns: tuple: (list of file paths, file type)
23
- """
24
- path_obj = Path(path)
25
-
26
- if path_obj.is_file():
27
- file_type = path_obj.suffix.lower().lstrip(".")
28
- assert file_type in [
29
- "csv",
30
- "parquet",
31
- ], f"Unsupported file extension: {file_type}"
32
- return [str(path_obj)], file_type
33
-
34
- if path_obj.is_dir():
35
- collected_files = [p for p in path_obj.iterdir() if p.is_file()]
36
- csv_files = [str(p) for p in collected_files if p.suffix.lower() == ".csv"]
37
- parquet_files = [
38
- str(p) for p in collected_files if p.suffix.lower() == ".parquet"
39
- ]
40
-
41
- if csv_files and parquet_files:
42
- raise ValueError(
43
- "Directory contains both CSV and Parquet files. Please keep a single format."
44
- )
45
- file_paths = csv_files if csv_files else parquet_files
46
- if not file_paths:
47
- raise ValueError(f"No CSV or Parquet files found in directory: {path}")
48
- file_paths.sort()
49
- file_type = "csv" if csv_files else "parquet"
50
- return file_paths, file_type
51
-
52
- raise ValueError(f"Invalid path: {path}")
53
-
54
-
55
- def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame:
56
- data_path = Path(path)
57
- fmt = data_format.lower() if data_format else data_path.suffix.lower().lstrip(".")
58
- if data_path.is_dir() and not fmt:
59
- fmt = "parquet"
60
- if fmt in {"parquet", ""}:
61
- return pd.read_parquet(data_path)
62
- if fmt in {"csv", "txt"}:
63
- # Use low_memory=False to avoid mixed-type DtypeWarning on wide CSVs
64
- return pd.read_csv(data_path, low_memory=False)
65
- raise ValueError(f"Unsupported data format: {data_path}")
66
-
67
-
68
- def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
69
- return [read_table(fp, file_type) for fp in file_paths]
70
-
71
-
72
- def iter_file_chunks(
73
- file_path: str, file_type: str, chunk_size: int
74
- ) -> Generator[pd.DataFrame, None, None]:
75
- if file_type == "csv":
76
- yield from pd.read_csv(file_path, chunksize=chunk_size)
77
- return
78
- parquet_file = pq.ParquetFile(file_path)
79
- for batch in parquet_file.iter_batches(batch_size=chunk_size):
80
- yield batch.to_pandas()
81
-
82
-
83
- def default_output_dir(path: str) -> Path:
84
- path_obj = Path(path)
85
- if path_obj.is_file():
86
- return path_obj.parent / f"{path_obj.stem}_preprocessed"
87
- return path_obj.with_name(f"{path_obj.name}_preprocessed")
88
-
89
-
90
- def read_yaml(path: str | Path):
91
- with open(path, "r", encoding="utf-8") as file:
92
- return yaml.safe_load(file) or {}
@@ -1,79 +0,0 @@
1
- """
2
- Initialization utilities for NextRec
3
-
4
- Date: create on 13/11/2025
5
- Author: Yang Zhou, zyaztec@gmail.com
6
- """
7
-
8
- from typing import Any, Dict, Set
9
-
10
- import torch.nn as nn
11
-
12
- KNOWN_NONLINEARITIES: Set[str] = {
13
- "linear",
14
- "conv1d",
15
- "conv2d",
16
- "conv3d",
17
- "conv_transpose1d",
18
- "conv_transpose2d",
19
- "conv_transpose3d",
20
- "sigmoid",
21
- "tanh",
22
- "relu",
23
- "leaky_relu",
24
- "selu",
25
- "gelu",
26
- }
27
-
28
-
29
- def resolve_nonlinearity(activation: str):
30
- if activation in KNOWN_NONLINEARITIES:
31
- return activation
32
- return "linear"
33
-
34
-
35
- def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
36
- if "gain" in param:
37
- return param["gain"]
38
- nonlinearity = resolve_nonlinearity(activation)
39
- try:
40
- return nn.init.calculate_gain(nonlinearity, param.get("param")) # type: ignore
41
- except ValueError:
42
- return 1.0
43
-
44
-
45
- def get_initializer(
46
- init_type: str = "normal",
47
- activation: str = "linear",
48
- param: Dict[str, Any] | None = None,
49
- ):
50
- param = param or {}
51
- nonlinearity = resolve_nonlinearity(activation)
52
- gain = resolve_gain(activation, param)
53
-
54
- def initializer_fn(tensor):
55
- if init_type == "xavier_uniform":
56
- nn.init.xavier_uniform_(tensor, gain=gain)
57
- elif init_type == "xavier_normal":
58
- nn.init.xavier_normal_(tensor, gain=gain)
59
- elif init_type == "kaiming_uniform":
60
- nn.init.kaiming_uniform_(
61
- tensor, a=param.get("a", 0), nonlinearity=nonlinearity # type: ignore
62
- )
63
- elif init_type == "kaiming_normal":
64
- nn.init.kaiming_normal_(
65
- tensor, a=param.get("a", 0), nonlinearity=nonlinearity # type: ignore
66
- )
67
- elif init_type == "orthogonal":
68
- nn.init.orthogonal_(tensor, gain=gain)
69
- elif init_type == "normal":
70
- nn.init.normal_(
71
- tensor, mean=param.get("mean", 0.0), std=param.get("std", 0.0001)
72
- )
73
- elif init_type == "uniform":
74
- nn.init.uniform_(tensor, a=param.get("a", -0.05), b=param.get("b", 0.05))
75
- else:
76
- raise ValueError(f"Unknown init_type: {init_type}")
77
- return tensor
78
-
79
- return initializer_fn
@@ -1,75 +0,0 @@
1
- """
2
- Optimizer and Scheduler utilities for NextRec
3
-
4
- Date: create on 13/11/2025
5
- Author: Yang Zhou, zyaztec@gmail.com
6
- """
7
-
8
- import torch
9
- from typing import Iterable
10
-
11
-
12
- def get_optimizer(
13
- optimizer: str | torch.optim.Optimizer = "adam",
14
- params: Iterable[torch.nn.Parameter] | None = None,
15
- **optimizer_params,
16
- ):
17
- if params is None:
18
- raise ValueError("params cannot be None. Please provide model parameters.")
19
-
20
- if "lr" not in optimizer_params:
21
- optimizer_params["lr"] = 1e-3
22
- if isinstance(optimizer, str):
23
- opt_name = optimizer.lower()
24
- if opt_name == "adam":
25
- opt_class = torch.optim.Adam
26
- elif opt_name == "sgd":
27
- opt_class = torch.optim.SGD
28
- elif opt_name == "adamw":
29
- opt_class = torch.optim.AdamW
30
- elif opt_name == "adagrad":
31
- opt_class = torch.optim.Adagrad
32
- elif opt_name == "rmsprop":
33
- opt_class = torch.optim.RMSprop
34
- else:
35
- raise NotImplementedError(f"Unsupported optimizer: {optimizer}")
36
- optimizer_fn = opt_class(params=params, **optimizer_params)
37
- elif isinstance(optimizer, torch.optim.Optimizer):
38
- optimizer_fn = optimizer
39
- else:
40
- raise TypeError(f"Invalid optimizer type: {type(optimizer)}")
41
- return optimizer_fn
42
-
43
-
44
- def get_scheduler(
45
- scheduler: (
46
- str
47
- | torch.optim.lr_scheduler._LRScheduler
48
- | torch.optim.lr_scheduler.LRScheduler
49
- | type[torch.optim.lr_scheduler._LRScheduler]
50
- | type[torch.optim.lr_scheduler.LRScheduler]
51
- | None
52
- ),
53
- optimizer,
54
- **scheduler_params,
55
- ):
56
- if isinstance(scheduler, str):
57
- if scheduler == "step":
58
- scheduler_fn = torch.optim.lr_scheduler.StepLR(
59
- optimizer, **scheduler_params
60
- )
61
- elif scheduler == "cosine":
62
- scheduler_fn = torch.optim.lr_scheduler.CosineAnnealingLR(
63
- optimizer, **scheduler_params
64
- )
65
- else:
66
- raise NotImplementedError(f"Unsupported scheduler: {scheduler}")
67
- elif isinstance(
68
- scheduler,
69
- (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler),
70
- ):
71
- scheduler_fn = scheduler
72
- else:
73
- raise TypeError(f"Invalid scheduler type: {type(scheduler)}")
74
-
75
- return scheduler_fn
nextrec/utils/tensor.py DELETED
@@ -1,72 +0,0 @@
1
- """
2
- Tensor manipulation utilities for NextRec
3
-
4
- Date: create on 03/12/2025
5
- Author: Yang Zhou, zyaztec@gmail.com
6
- """
7
-
8
- import torch
9
- from typing import Any
10
-
11
-
12
- def to_tensor(
13
- value: Any, dtype: torch.dtype, device: torch.device | str | None = None
14
- ) -> torch.Tensor:
15
- if value is None:
16
- raise ValueError("[Tensor Utils Error] Cannot convert None to tensor.")
17
- tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
18
- if tensor.dtype != dtype:
19
- tensor = tensor.to(dtype=dtype)
20
-
21
- if device is not None:
22
- target_device = (
23
- device if isinstance(device, torch.device) else torch.device(device)
24
- )
25
- if tensor.device != target_device:
26
- tensor = tensor.to(target_device)
27
- return tensor
28
-
29
-
30
- def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
31
- if not tensors:
32
- raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
33
- return torch.stack(tensors, dim=dim)
34
-
35
-
36
- def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
37
- if not tensors:
38
- raise ValueError(
39
- "[Tensor Utils Error] Cannot concatenate empty list of tensors."
40
- )
41
- return torch.cat(tensors, dim=dim)
42
-
43
-
44
- def pad_sequence_tensors(
45
- tensors: list[torch.Tensor],
46
- max_len: int | None = None,
47
- padding_value: float = 0.0,
48
- padding_side: str = "right",
49
- ) -> torch.Tensor:
50
- if not tensors:
51
- raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
52
- if max_len is None:
53
- max_len = max(t.size(0) for t in tensors)
54
- batch_size = len(tensors)
55
- padded = torch.full(
56
- (batch_size, max_len),
57
- padding_value,
58
- dtype=tensors[0].dtype,
59
- device=tensors[0].device,
60
- )
61
-
62
- for i, tensor in enumerate(tensors):
63
- length = min(tensor.size(0), max_len)
64
- if padding_side == "right":
65
- padded[i, :length] = tensor[:length]
66
- elif padding_side == "left":
67
- padded[i, -length:] = tensor[:length]
68
- else:
69
- raise ValueError(
70
- f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}"
71
- )
72
- return padded