nextrec 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +30 -15
- nextrec/basic/features.py +1 -0
- nextrec/basic/layers.py +6 -8
- nextrec/basic/loggers.py +14 -7
- nextrec/basic/metrics.py +6 -76
- nextrec/basic/model.py +312 -318
- nextrec/cli.py +5 -10
- nextrec/data/__init__.py +13 -16
- nextrec/data/batch_utils.py +3 -2
- nextrec/data/data_processing.py +10 -2
- nextrec/data/data_utils.py +9 -14
- nextrec/data/dataloader.py +12 -13
- nextrec/data/preprocessor.py +328 -255
- nextrec/loss/__init__.py +1 -5
- nextrec/loss/loss_utils.py +2 -8
- nextrec/models/generative/__init__.py +1 -8
- nextrec/models/generative/hstu.py +6 -4
- nextrec/models/multi_task/esmm.py +2 -2
- nextrec/models/multi_task/mmoe.py +2 -2
- nextrec/models/multi_task/ple.py +2 -2
- nextrec/models/multi_task/poso.py +2 -3
- nextrec/models/multi_task/share_bottom.py +2 -2
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +2 -2
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/ffm.py +0 -0
- nextrec/models/ranking/fibinet.py +5 -5
- nextrec/models/ranking/fm.py +3 -7
- nextrec/models/ranking/lr.py +0 -0
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +2 -2
- nextrec/models/ranking/widedeep.py +2 -2
- nextrec/models/ranking/xdeepfm.py +2 -2
- nextrec/models/representation/__init__.py +9 -0
- nextrec/models/{generative → representation}/rqvae.py +9 -9
- nextrec/models/retrieval/__init__.py +0 -0
- nextrec/models/{match → retrieval}/dssm.py +8 -3
- nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
- nextrec/models/{match → retrieval}/mind.py +4 -3
- nextrec/models/{match → retrieval}/sdm.py +4 -3
- nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
- nextrec/utils/__init__.py +60 -46
- nextrec/utils/config.py +8 -7
- nextrec/utils/console.py +371 -0
- nextrec/utils/{synthetic_data.py → data.py} +102 -15
- nextrec/utils/feature.py +15 -0
- nextrec/utils/torch_utils.py +411 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/METADATA +6 -6
- nextrec-0.4.9.dist-info/RECORD +70 -0
- nextrec/utils/cli_utils.py +0 -58
- nextrec/utils/device.py +0 -78
- nextrec/utils/distributed.py +0 -141
- nextrec/utils/file.py +0 -92
- nextrec/utils/initializer.py +0 -79
- nextrec/utils/optimizer.py +0 -75
- nextrec/utils/tensor.py +0 -72
- nextrec-0.4.8.dist-info/RECORD +0 -71
- /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
|
|
2
|
+
nextrec/__version__.py,sha256=LdxLMJM_JXsCQBeSvnxCNyGWmINE0yWfna3DQaT41Vs,22
|
|
3
|
+
nextrec/cli.py,sha256=vumtNQww-FXgGa6I90IhEjngL1Y-e0GSuSK442s4M40,19209
|
|
4
|
+
nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
+
nextrec/basic/activation.py,sha256=uzTWfCOtBSkbu_Gk9XBNTj8__s241CaYLJk6l8nGX9I,2885
|
|
6
|
+
nextrec/basic/callback.py,sha256=a6gg7r3x1v0xaOSya9PteLql7I14nepY7gX8tDYtins,14679
|
|
7
|
+
nextrec/basic/features.py,sha256=Wnbzr7UMotgv1Vzeg0o9Po-KKIvYUSYIghoVDfMPx_g,4340
|
|
8
|
+
nextrec/basic/layers.py,sha256=GJH2Tx3IkZrYGb7-ET976iHCC28Ubck_NO9-iyY4mDI,28911
|
|
9
|
+
nextrec/basic/loggers.py,sha256=JnQiFvmsVgZ63gqBLR2ZFWrVPzkxRbzWhTdeoiJKcos,6526
|
|
10
|
+
nextrec/basic/metrics.py,sha256=8RswR_3MGvIBkT_n6fnmON2eYH-hfD7kIKVnyJJjL3o,23131
|
|
11
|
+
nextrec/basic/model.py,sha256=bQPRRIOGB_0PNJv3Zr-8mruUMBC7N2dmPWy7r3L45M4,98649
|
|
12
|
+
nextrec/basic/session.py,sha256=UOG_-EgCOxvqZwCkiEd8sgNV2G1sm_HbzKYVQw8yYDI,4483
|
|
13
|
+
nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
|
|
14
|
+
nextrec/data/batch_utils.py,sha256=0bYGVX7RlhnHv_ZBaUngjDIpBNw-igCk98DgOsF7T6o,2879
|
|
15
|
+
nextrec/data/data_processing.py,sha256=lKXDBszrO5fJMAQetgSPr2mSQuzOluuz1eHV4jp0TDU,5538
|
|
16
|
+
nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
|
|
17
|
+
nextrec/data/dataloader.py,sha256=D2QZDxc9Ic7zkSoaJQBxrjmbHWyJ8d3k0QF3IqLZxfk,18793
|
|
18
|
+
nextrec/data/preprocessor.py,sha256=wNjivq2N-iDzBropkp3YfSkN0jSA4l4h81C-ECa6k4c,44643
|
|
19
|
+
nextrec/loss/__init__.py,sha256=-sibZK8QXLblVNWqdqjrPPzMCDyIXSq7yd2eZ57p9Nw,810
|
|
20
|
+
nextrec/loss/listwise.py,sha256=UT9vJCOTOQLogVwaeTV7Z5uxIYnngGdxk-p9e97MGkU,5744
|
|
21
|
+
nextrec/loss/loss_utils.py,sha256=Eg_EKm47onSCLhgs2q7IkB7TV9TwV1Dz4QgVR2yh-gc,4610
|
|
22
|
+
nextrec/loss/pairwise.py,sha256=X9yg-8pcPt2IWU0AiUhWAt3_4W_3wIF0uSdDYTdoPFY,3398
|
|
23
|
+
nextrec/loss/pointwise.py,sha256=o9J3OznY0hlbDsUXqn3k-BBzYiuUH5dopz8QBFqS_kQ,7343
|
|
24
|
+
nextrec/models/generative/__init__.py,sha256=QtfMCu0-R7jg_HgjrUTP8ShWzXjQnPCpQjqpGwbMzp4,176
|
|
25
|
+
nextrec/models/generative/hstu.py,sha256=P2Kl7HEL3afwiCApGKQ6UbUNO9eNXXrB10H7iiF8cI0,19735
|
|
26
|
+
nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
27
|
+
nextrec/models/multi_task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
28
|
+
nextrec/models/multi_task/esmm.py,sha256=3QQePhSkOcM4t472S4E5xU9_KiLiSwHb9CfdkEgmqqk,6491
|
|
29
|
+
nextrec/models/multi_task/mmoe.py,sha256=uFTbc0MiFBDTCIt8mTW6xs0oyOn1EesIHHZo81HR35k,8583
|
|
30
|
+
nextrec/models/multi_task/ple.py,sha256=z32etizNlTLwwR7CYKxy8u9owAbtiRh492Fje_y64hQ,13016
|
|
31
|
+
nextrec/models/multi_task/poso.py,sha256=foH7XDUz0XN0s0zoyHLuTmrcs3QOT8-x4YGxLX1Lxxg,19016
|
|
32
|
+
nextrec/models/multi_task/share_bottom.py,sha256=rmEnsX3LA3pNsLKfG1ir5WDLdkSY-imO_ASiclirJiA,6519
|
|
33
|
+
nextrec/models/ranking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
|
+
nextrec/models/ranking/afm.py,sha256=96jGUPL4yTWobMIVBjHpOxl9AtAzCAGR8yw7Sy2JmdQ,10125
|
|
35
|
+
nextrec/models/ranking/autoint.py,sha256=S6Cxnp1q2OErSYqmIix5P-b4qLWR-0dY6TMStuU6WLg,8109
|
|
36
|
+
nextrec/models/ranking/dcn.py,sha256=whkjiKEuadl6oSP-NJdSOCOqvWZGX4EsId9oqlfVpa8,7299
|
|
37
|
+
nextrec/models/ranking/dcn_v2.py,sha256=QnqQbJsrtQp4mtvnBXFUVefKyr4dw-gHNWrCbO26oHw,11163
|
|
38
|
+
nextrec/models/ranking/deepfm.py,sha256=EMAGhPCjJHmxpkoTWaioVgNt2yVB0PzGJpDc8cQlczs,5224
|
|
39
|
+
nextrec/models/ranking/dien.py,sha256=c7Zs85vxhOgKHg5s0QcSLCn1xXCCSD177TMERgM_v8g,18958
|
|
40
|
+
nextrec/models/ranking/din.py,sha256=gdUhuKiKXBNOALbK8fGhlbSeuDT8agcEdNSrC_wveHc,9422
|
|
41
|
+
nextrec/models/ranking/eulernet.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
42
|
+
nextrec/models/ranking/ffm.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
43
|
+
nextrec/models/ranking/fibinet.py,sha256=_eroddVHooJcaGT8MqS4mUrtv5j4pnTmfI3FoAKOZhs,7919
|
|
44
|
+
nextrec/models/ranking/fm.py,sha256=SsrSKK3y4xg5Lv-t3JLnZan55Hzze2AxAiVPuscy0bk,4536
|
|
45
|
+
nextrec/models/ranking/lr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
|
+
nextrec/models/ranking/masknet.py,sha256=tY1y2lO0iq82oylPN0SBnL5Bikc8weinFXpURyVT1hE,12373
|
|
47
|
+
nextrec/models/ranking/pnn.py,sha256=FcNIFAw5J0ORGSR6L8ZK7NeXlJPpojwe_SpsxMQqCFw,8174
|
|
48
|
+
nextrec/models/ranking/widedeep.py,sha256=-ghKfe_0puvlI9fBQr8lK3gXkfVvslGwP40AJTGqc7w,5077
|
|
49
|
+
nextrec/models/ranking/xdeepfm.py,sha256=FMtl_zYO1Ty_2d9VWRsz6Jo-Xjw8vikpIQPZCDVavVY,8156
|
|
50
|
+
nextrec/models/representation/__init__.py,sha256=O3QHMMXBszwM-mTl7bA3wawNZvDGet-QIv6Ys5GHGJ8,190
|
|
51
|
+
nextrec/models/representation/rqvae.py,sha256=JyZxVY9CibcdBGk97TxjG5O3WQC10_60tHNcP_qtegs,29290
|
|
52
|
+
nextrec/models/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
53
|
+
nextrec/models/retrieval/dssm.py,sha256=yMTu9msn6WgPqpM_zaJ0Z8R3uAYPzuJY_NS0xjJGSx0,8043
|
|
54
|
+
nextrec/models/retrieval/dssm_v2.py,sha256=IJzCT78Ra1DeSpR85KF1Awqnywu_6kDko3XDPbcup5s,7075
|
|
55
|
+
nextrec/models/retrieval/mind.py,sha256=ktNXwZxhnH1DV08c0gdsbodo4QJQL4UQGeyOqpmGwVM,15090
|
|
56
|
+
nextrec/models/retrieval/sdm.py,sha256=LhkCZSfGhxOxziEkUtjr_hnqcyciJ2qpMoBSFBVW9lQ,10558
|
|
57
|
+
nextrec/models/retrieval/youtube_dnn.py,sha256=xtGPV6_5LeSZBKkrTaU1CmtxlhgYLvZmjpwYaXYIaEA,7403
|
|
58
|
+
nextrec/utils/__init__.py,sha256=5ss2XQq8QZ2Ko5eiQ7oIig5cIZNrYGIptaarYEeO7Fk,2550
|
|
59
|
+
nextrec/utils/config.py,sha256=0HOeMyTlx8g6BZVpXzo2lEOkb-mzNwhbigQuUomsYnY,19934
|
|
60
|
+
nextrec/utils/console.py,sha256=D2Vax9_b7bgvAAOyk-Q2oUhSk1B-OngY5buS9Gb9-I0,11398
|
|
61
|
+
nextrec/utils/data.py,sha256=alruiWZFbmwy3kO12q42VXmtHmXFFjVULpHa43fx_mI,21098
|
|
62
|
+
nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
|
|
63
|
+
nextrec/utils/feature.py,sha256=rsUAv3ELyDpehVw8nPEEsLCCIjuKGTJJZuFaWB_wrPk,633
|
|
64
|
+
nextrec/utils/model.py,sha256=dYl1XfIZt6aVjNyV2AAhcArwFRMcEAKrjG_pr8AVHs0,1163
|
|
65
|
+
nextrec/utils/torch_utils.py,sha256=AKfYbSOJjEw874xsDB5IO3Ote4X7vnqzt_E0jJny0o8,13468
|
|
66
|
+
nextrec-0.4.9.dist-info/METADATA,sha256=FnndUmBNuNT6Odd_8XTlnW71NmqZ_Wr5lFNmxVwha1k,19463
|
|
67
|
+
nextrec-0.4.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
68
|
+
nextrec-0.4.9.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
|
|
69
|
+
nextrec-0.4.9.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
|
|
70
|
+
nextrec-0.4.9.dist-info/RECORD,,
|
nextrec/utils/cli_utils.py
DELETED
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
CLI utilities for NextRec.
|
|
3
|
-
|
|
4
|
-
This module provides small helpers used by command-line entrypoints, such as
|
|
5
|
-
printing startup banners and resolving the installed package version.
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
from __future__ import annotations
|
|
9
|
-
|
|
10
|
-
import logging
|
|
11
|
-
import os
|
|
12
|
-
import platform
|
|
13
|
-
import sys
|
|
14
|
-
from datetime import datetime
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def get_nextrec_version() -> str:
    """
    Return the NextRec version string, or "unknown" if it cannot be determined.

    Lookup order:
      1. ``nextrec.__version__`` when it can be imported and is non-empty;
      2. the installed distribution metadata for ``nextrec``;
      3. the literal ``"unknown"``.

    Every failure along the way is swallowed on purpose: the value only feeds an
    informational banner and must never abort startup.
    """
    candidate = None
    try:
        from nextrec import __version__ as candidate  # type: ignore

        candidate = str(candidate) if candidate else None
    except Exception:
        candidate = None
    if candidate is not None:
        return candidate

    try:
        from importlib.metadata import version as dist_version

        return dist_version("nextrec")
    except Exception:
        return "unknown"
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def log_startup_info(
    logger: logging.Logger, *, mode: str, config_path: str | None
) -> None:
    """
    Emit a short startup banner through ``logger``, one ``info`` record per line.

    Args:
        logger: Logger that receives the banner lines.
        mode: Run-mode label shown in the banner.
        config_path: Config file path shown in the banner; "(not set)" when falsy.
    """
    details = {
        "Version": get_nextrec_version(),
        "Time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "Mode": mode,
        "Config": config_path or "(not set)",
        "Python": f"{platform.python_version()} ({sys.executable})",
        "Platform": f"{platform.system()} {platform.release()} ({platform.machine()})",
        "Workdir": os.getcwd(),
        "Command": " ".join(sys.argv),
    }
    logger.info("NextRec CLI")
    for label, value in details.items():
        logger.info(f"- {label}: {value}")
|
nextrec/utils/device.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Device management utilities for NextRec
|
|
3
|
-
|
|
4
|
-
Date: create on 03/12/2025
|
|
5
|
-
Checkpoint: edit on 06/12/2025
|
|
6
|
-
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
import torch
|
|
10
|
-
import platform
|
|
11
|
-
import logging
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def resolve_device() -> str:
|
|
15
|
-
if torch.cuda.is_available():
|
|
16
|
-
return "cuda"
|
|
17
|
-
if torch.backends.mps.is_available():
|
|
18
|
-
mac_ver = platform.mac_ver()[0]
|
|
19
|
-
try:
|
|
20
|
-
major, _ = (int(x) for x in mac_ver.split(".")[:2])
|
|
21
|
-
except Exception:
|
|
22
|
-
major, _ = 0, 0
|
|
23
|
-
if major >= 14:
|
|
24
|
-
return "mps"
|
|
25
|
-
return "cpu"
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def get_device_info() -> dict:
    """
    Summarize the accelerators visible to PyTorch.

    Returns:
        dict with keys "cuda_available", "cuda_device_count" (0 without CUDA),
        "mps_available" and "current_device" (the value of ``resolve_device()``).
        When CUDA is available, "cuda_device_name" and "cuda_capability" are
        added for device 0.
    """
    cuda_ok = torch.cuda.is_available()
    summary = dict(
        cuda_available=cuda_ok,
        cuda_device_count=torch.cuda.device_count() if cuda_ok else 0,
        mps_available=torch.backends.mps.is_available(),
        current_device=resolve_device(),
    )
    if cuda_ok:
        summary.update(
            cuda_device_name=torch.cuda.get_device_name(0),
            cuda_capability=torch.cuda.get_device_capability(0),
        )
    return summary
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def configure_device(
|
|
46
|
-
distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
|
|
47
|
-
) -> torch.device:
|
|
48
|
-
try:
|
|
49
|
-
device = torch.device(base_device)
|
|
50
|
-
except Exception:
|
|
51
|
-
logging.warning(
|
|
52
|
-
"[configure_device Warning] Invalid base_device, falling back to CPU."
|
|
53
|
-
)
|
|
54
|
-
return torch.device("cpu")
|
|
55
|
-
|
|
56
|
-
if distributed:
|
|
57
|
-
if device.type == "cuda":
|
|
58
|
-
if not torch.cuda.is_available():
|
|
59
|
-
logging.warning(
|
|
60
|
-
"[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
|
|
61
|
-
)
|
|
62
|
-
return torch.device("cpu")
|
|
63
|
-
if not (0 <= local_rank < torch.cuda.device_count()):
|
|
64
|
-
logging.warning(
|
|
65
|
-
f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
|
|
66
|
-
)
|
|
67
|
-
return torch.device("cpu")
|
|
68
|
-
try:
|
|
69
|
-
torch.cuda.set_device(local_rank)
|
|
70
|
-
return torch.device(f"cuda:{local_rank}")
|
|
71
|
-
except Exception as exc:
|
|
72
|
-
logging.warning(
|
|
73
|
-
f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
|
|
74
|
-
)
|
|
75
|
-
return torch.device("cpu")
|
|
76
|
-
else:
|
|
77
|
-
return torch.device("cpu")
|
|
78
|
-
return device
|
nextrec/utils/distributed.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Distributed utilities for NextRec.
|
|
3
|
-
|
|
4
|
-
Date: create on 04/12/2025
|
|
5
|
-
Checkpoint: edit on 05/12/2025
|
|
6
|
-
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
import logging
|
|
10
|
-
import numpy as np
|
|
11
|
-
import torch
|
|
12
|
-
import torch.distributed as dist
|
|
13
|
-
|
|
14
|
-
from torch.utils.data import DataLoader, IterableDataset
|
|
15
|
-
from torch.utils.data.distributed import DistributedSampler
|
|
16
|
-
from nextrec.basic.loggers import colorize
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def init_process_group(
|
|
20
|
-
distributed: bool, rank: int, world_size: int, device_id: int | None = None
|
|
21
|
-
) -> None:
|
|
22
|
-
"""
|
|
23
|
-
initialize distributed process group for multi-GPU training.
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
distributed: whether to enable distributed training
|
|
27
|
-
rank: global rank of the current process
|
|
28
|
-
world_size: total number of processes
|
|
29
|
-
"""
|
|
30
|
-
if (not distributed) or (not dist.is_available()) or dist.is_initialized():
|
|
31
|
-
return
|
|
32
|
-
backend = "nccl" if device_id is not None else "gloo"
|
|
33
|
-
if backend == "nccl":
|
|
34
|
-
torch.cuda.set_device(device_id)
|
|
35
|
-
dist.init_process_group(
|
|
36
|
-
backend=backend, init_method="env://", rank=rank, world_size=world_size
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
|
|
41
|
-
"""
|
|
42
|
-
Gather numpy arrays (or None) across ranks. Uses all_gather_object to avoid
|
|
43
|
-
shape mismatches and ensures every rank participates even when local data is empty.
|
|
44
|
-
"""
|
|
45
|
-
if not (self.distributed and dist.is_available() and dist.is_initialized()):
|
|
46
|
-
return array
|
|
47
|
-
|
|
48
|
-
world_size = dist.get_world_size()
|
|
49
|
-
gathered: list[np.ndarray | None] = [None for _ in range(world_size)]
|
|
50
|
-
dist.all_gather_object(gathered, array)
|
|
51
|
-
pieces: list[np.ndarray] = []
|
|
52
|
-
for item in gathered:
|
|
53
|
-
if item is None:
|
|
54
|
-
continue
|
|
55
|
-
item_np = np.asarray(item)
|
|
56
|
-
if item_np.size > 0:
|
|
57
|
-
pieces.append(item_np)
|
|
58
|
-
if not pieces:
|
|
59
|
-
return None
|
|
60
|
-
return np.concatenate(pieces, axis=0)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def add_distributed_sampler(
|
|
64
|
-
loader: DataLoader,
|
|
65
|
-
distributed: bool,
|
|
66
|
-
world_size: int,
|
|
67
|
-
rank: int,
|
|
68
|
-
shuffle: bool,
|
|
69
|
-
drop_last: bool,
|
|
70
|
-
default_batch_size: int,
|
|
71
|
-
is_main_process: bool = False,
|
|
72
|
-
) -> tuple[DataLoader, DistributedSampler | None]:
|
|
73
|
-
"""
|
|
74
|
-
add distributedsampler to a dataloader, this for distributed training
|
|
75
|
-
when each device has its own dataloader
|
|
76
|
-
"""
|
|
77
|
-
# early return if not distributed
|
|
78
|
-
if not (distributed and dist.is_available() and dist.is_initialized()):
|
|
79
|
-
return loader, None
|
|
80
|
-
# return if already has DistributedSampler
|
|
81
|
-
if isinstance(loader.sampler, DistributedSampler):
|
|
82
|
-
return loader, loader.sampler
|
|
83
|
-
dataset = getattr(loader, "dataset", None)
|
|
84
|
-
if dataset is None:
|
|
85
|
-
return loader, None
|
|
86
|
-
if isinstance(dataset, IterableDataset):
|
|
87
|
-
if is_main_process:
|
|
88
|
-
logging.info(
|
|
89
|
-
colorize(
|
|
90
|
-
"[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.",
|
|
91
|
-
color="yellow",
|
|
92
|
-
)
|
|
93
|
-
)
|
|
94
|
-
return loader, None
|
|
95
|
-
sampler = DistributedSampler(
|
|
96
|
-
dataset,
|
|
97
|
-
num_replicas=world_size,
|
|
98
|
-
rank=rank,
|
|
99
|
-
shuffle=shuffle,
|
|
100
|
-
drop_last=drop_last,
|
|
101
|
-
)
|
|
102
|
-
loader_kwargs = {
|
|
103
|
-
"batch_size": (
|
|
104
|
-
loader.batch_size if loader.batch_size is not None else default_batch_size
|
|
105
|
-
),
|
|
106
|
-
"shuffle": False,
|
|
107
|
-
"sampler": sampler,
|
|
108
|
-
"num_workers": loader.num_workers,
|
|
109
|
-
"collate_fn": loader.collate_fn,
|
|
110
|
-
"drop_last": drop_last,
|
|
111
|
-
}
|
|
112
|
-
if getattr(loader, "pin_memory", False):
|
|
113
|
-
loader_kwargs["pin_memory"] = True
|
|
114
|
-
pin_memory_device = getattr(loader, "pin_memory_device", None)
|
|
115
|
-
if pin_memory_device:
|
|
116
|
-
loader_kwargs["pin_memory_device"] = pin_memory_device
|
|
117
|
-
timeout = getattr(loader, "timeout", None)
|
|
118
|
-
if timeout:
|
|
119
|
-
loader_kwargs["timeout"] = timeout
|
|
120
|
-
worker_init_fn = getattr(loader, "worker_init_fn", None)
|
|
121
|
-
if worker_init_fn is not None:
|
|
122
|
-
loader_kwargs["worker_init_fn"] = worker_init_fn
|
|
123
|
-
generator = getattr(loader, "generator", None)
|
|
124
|
-
if generator is not None:
|
|
125
|
-
loader_kwargs["generator"] = generator
|
|
126
|
-
if loader.num_workers > 0:
|
|
127
|
-
loader_kwargs["persistent_workers"] = getattr(
|
|
128
|
-
loader, "persistent_workers", False
|
|
129
|
-
)
|
|
130
|
-
prefetch_factor = getattr(loader, "prefetch_factor", None)
|
|
131
|
-
if prefetch_factor is not None:
|
|
132
|
-
loader_kwargs["prefetch_factor"] = prefetch_factor
|
|
133
|
-
distributed_loader = DataLoader(dataset, **loader_kwargs)
|
|
134
|
-
if is_main_process:
|
|
135
|
-
logging.info(
|
|
136
|
-
colorize(
|
|
137
|
-
"[Distributed Info] Attached DistributedSampler to provided DataLoader",
|
|
138
|
-
color="cyan",
|
|
139
|
-
)
|
|
140
|
-
)
|
|
141
|
-
return distributed_loader, sampler
|
nextrec/utils/file.py
DELETED
|
@@ -1,92 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
File I/O utilities for NextRec
|
|
3
|
-
|
|
4
|
-
Date: create on 03/12/2025
|
|
5
|
-
Checkpoint: edit on 06/12/2025
|
|
6
|
-
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
import yaml
|
|
10
|
-
import pandas as pd
|
|
11
|
-
import pyarrow.parquet as pq
|
|
12
|
-
|
|
13
|
-
from pathlib import Path
|
|
14
|
-
from typing import Generator
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def resolve_file_paths(path: str) -> tuple[list[str], str]:
    """
    Resolve a file or directory path into a list of data files and their type.

    Args:
        path: Either a single ``.csv``/``.parquet`` file, or a directory whose
            top-level files are all CSV or all Parquet. Other files in the
            directory are ignored and subdirectories are not searched.

    Returns:
        tuple: (list of file paths as strings, file type "csv" or "parquet").
        Directory results are sorted by path.

    Raises:
        ValueError: if a file has an unsupported extension, a directory mixes
            CSV and Parquet files or contains neither, or the path is neither
            an existing file nor an existing directory.
    """
    supported = ("csv", "parquet")
    path_obj = Path(path)

    if path_obj.is_file():
        file_type = path_obj.suffix.lower().lstrip(".")
        # An explicit raise (instead of `assert`) survives `python -O` and keeps
        # every validation failure in this function a ValueError.
        if file_type not in supported:
            raise ValueError(f"Unsupported file extension: {file_type}")
        return [str(path_obj)], file_type

    if path_obj.is_dir():
        collected_files = [p for p in path_obj.iterdir() if p.is_file()]
        csv_files = [str(p) for p in collected_files if p.suffix.lower() == ".csv"]
        parquet_files = [
            str(p) for p in collected_files if p.suffix.lower() == ".parquet"
        ]
        if csv_files and parquet_files:
            raise ValueError(
                "Directory contains both CSV and Parquet files. Please keep a single format."
            )
        if not csv_files and not parquet_files:
            raise ValueError(f"No CSV or Parquet files found in directory: {path}")
        file_type = "csv" if csv_files else "parquet"
        return sorted(csv_files or parquet_files), file_type

    raise ValueError(f"Invalid path: {path}")
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame:
|
|
56
|
-
data_path = Path(path)
|
|
57
|
-
fmt = data_format.lower() if data_format else data_path.suffix.lower().lstrip(".")
|
|
58
|
-
if data_path.is_dir() and not fmt:
|
|
59
|
-
fmt = "parquet"
|
|
60
|
-
if fmt in {"parquet", ""}:
|
|
61
|
-
return pd.read_parquet(data_path)
|
|
62
|
-
if fmt in {"csv", "txt"}:
|
|
63
|
-
# Use low_memory=False to avoid mixed-type DtypeWarning on wide CSVs
|
|
64
|
-
return pd.read_csv(data_path, low_memory=False)
|
|
65
|
-
raise ValueError(f"Unsupported data format: {data_path}")
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
    """Read each path in ``file_paths`` with ``read_table`` as ``file_type``, preserving order."""
    frames: list[pd.DataFrame] = [
        read_table(file_path, data_format=file_type) for file_path in file_paths
    ]
    return frames
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def iter_file_chunks(
    file_path: str, file_type: str, chunk_size: int
) -> Generator[pd.DataFrame, None, None]:
    """
    Yield a data file as consecutive DataFrames of at most ``chunk_size`` rows.

    Args:
        file_path: Path to the file.
        file_type: "csv" uses pandas' chunked CSV reader. Any other value is
            read as Parquet, one record batch at a time.
        chunk_size: Maximum number of rows per yielded DataFrame.
    """
    if file_type == "csv":
        # The reader is used as a context manager so its file handle is closed
        # even when the consumer stops early, i.e. when this generator is
        # closed or garbage-collected before the file is exhausted.
        with pd.read_csv(file_path, chunksize=chunk_size) as reader:
            yield from reader
        return
    parquet_file = pq.ParquetFile(file_path)
    for batch in parquet_file.iter_batches(batch_size=chunk_size):
        yield batch.to_pandas()
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def default_output_dir(path: str) -> Path:
    """
    Derive the default output location for preprocessed data.

    For an existing file ``dir/name.ext`` the result is ``dir/name_preprocessed``.
    For anything else (a directory or a non-existent path), ``_preprocessed`` is
    appended to the last path component, so ``data`` becomes ``data_preprocessed``.
    """
    source = Path(path)
    if not source.is_file():
        return source.with_name(source.name + "_preprocessed")
    return source.parent / (source.stem + "_preprocessed")
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
def read_yaml(path: str | Path):
    """
    Parse a UTF-8 YAML file with ``yaml.safe_load`` and return the result.

    An empty document, or any other falsy top-level value, yields ``{}``.
    """
    with open(path, encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    return loaded if loaded else {}
|
nextrec/utils/initializer.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Initialization utilities for NextRec
|
|
3
|
-
|
|
4
|
-
Date: create on 13/11/2025
|
|
5
|
-
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
from typing import Any, Dict, Set
|
|
9
|
-
|
|
10
|
-
import torch.nn as nn
|
|
11
|
-
|
|
12
|
-
# Activation / layer names that resolve_nonlinearity passes through unchanged;
# any other name is mapped to "linear".
KNOWN_NONLINEARITIES: Set[str] = {
    # Linear and convolutional layer names.
    "linear",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    # Activation function names.
    "sigmoid",
    "tanh",
    "relu",
    "leaky_relu",
    "selu",
    "gelu",
}
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def resolve_nonlinearity(activation: str):
    """Return ``activation`` if it is listed in KNOWN_NONLINEARITIES, otherwise "linear"."""
    return activation if activation in KNOWN_NONLINEARITIES else "linear"
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
    """
    Return the initialization gain to use for ``activation``.

    An explicit ``param["gain"]`` always wins. Otherwise the activation is mapped
    through ``resolve_nonlinearity`` and passed to ``nn.init.calculate_gain``,
    with ``param.get("param")`` as the optional extra argument (for leaky_relu
    this is the negative slope). If torch rejects the name with ValueError, a
    neutral gain of 1.0 is returned.
    """
    if "gain" in param:
        return param["gain"]
    torch_name = resolve_nonlinearity(activation)
    try:
        gain = nn.init.calculate_gain(torch_name, param.get("param"))  # type: ignore
    except ValueError:
        gain = 1.0
    return gain
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def get_initializer(
    init_type: str = "normal",
    activation: str = "linear",
    param: Dict[str, Any] | None = None,
):
    """
    Build an in-place tensor initializer.

    Args:
        init_type: One of "xavier_uniform", "xavier_normal", "kaiming_uniform",
            "kaiming_normal", "orthogonal", "normal", "uniform".
        activation: Activation name used for the gain and for the kaiming
            ``nonlinearity``. Names outside KNOWN_NONLINEARITIES become "linear".
        param: Optional settings read by the chosen scheme:
            "gain" / "param" (gain resolution for xavier/orthogonal),
            "a" (kaiming negative slope, default 0; or uniform lower bound,
            default -0.05), "b" (uniform upper bound, default 0.05),
            "mean" / "std" (normal, defaults 0.0 / 0.0001).

    Returns:
        A callable ``fn(tensor) -> tensor`` that initializes ``tensor`` in place
        and returns it. An unknown ``init_type`` raises ValueError when the
        callable is invoked, not when it is built.
    """
    options = param or {}
    nonlinearity = resolve_nonlinearity(activation)
    gain = resolve_gain(activation, options)

    # NOTE(review): kaiming_* forward `nonlinearity` to torch, which accepts only
    # the names nn.init.calculate_gain supports. "gelu" is a known name here but
    # may be rejected by torch, so confirm before pairing kaiming with gelu.
    schemes = {
        "xavier_uniform": lambda t: nn.init.xavier_uniform_(t, gain=gain),
        "xavier_normal": lambda t: nn.init.xavier_normal_(t, gain=gain),
        "kaiming_uniform": lambda t: nn.init.kaiming_uniform_(
            t, a=options.get("a", 0), nonlinearity=nonlinearity  # type: ignore
        ),
        "kaiming_normal": lambda t: nn.init.kaiming_normal_(
            t, a=options.get("a", 0), nonlinearity=nonlinearity  # type: ignore
        ),
        "orthogonal": lambda t: nn.init.orthogonal_(t, gain=gain),
        "normal": lambda t: nn.init.normal_(
            t, mean=options.get("mean", 0.0), std=options.get("std", 0.0001)
        ),
        "uniform": lambda t: nn.init.uniform_(
            t, a=options.get("a", -0.05), b=options.get("b", 0.05)
        ),
    }

    def initializer_fn(tensor):
        apply_scheme = schemes.get(init_type)
        if apply_scheme is None:
            raise ValueError(f"Unknown init_type: {init_type}")
        apply_scheme(tensor)
        return tensor

    return initializer_fn
|
nextrec/utils/optimizer.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Optimizer and Scheduler utilities for NextRec
|
|
3
|
-
|
|
4
|
-
Date: create on 13/11/2025
|
|
5
|
-
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
from typing import Iterable
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def get_optimizer(
|
|
13
|
-
optimizer: str | torch.optim.Optimizer = "adam",
|
|
14
|
-
params: Iterable[torch.nn.Parameter] | None = None,
|
|
15
|
-
**optimizer_params,
|
|
16
|
-
):
|
|
17
|
-
if params is None:
|
|
18
|
-
raise ValueError("params cannot be None. Please provide model parameters.")
|
|
19
|
-
|
|
20
|
-
if "lr" not in optimizer_params:
|
|
21
|
-
optimizer_params["lr"] = 1e-3
|
|
22
|
-
if isinstance(optimizer, str):
|
|
23
|
-
opt_name = optimizer.lower()
|
|
24
|
-
if opt_name == "adam":
|
|
25
|
-
opt_class = torch.optim.Adam
|
|
26
|
-
elif opt_name == "sgd":
|
|
27
|
-
opt_class = torch.optim.SGD
|
|
28
|
-
elif opt_name == "adamw":
|
|
29
|
-
opt_class = torch.optim.AdamW
|
|
30
|
-
elif opt_name == "adagrad":
|
|
31
|
-
opt_class = torch.optim.Adagrad
|
|
32
|
-
elif opt_name == "rmsprop":
|
|
33
|
-
opt_class = torch.optim.RMSprop
|
|
34
|
-
else:
|
|
35
|
-
raise NotImplementedError(f"Unsupported optimizer: {optimizer}")
|
|
36
|
-
optimizer_fn = opt_class(params=params, **optimizer_params)
|
|
37
|
-
elif isinstance(optimizer, torch.optim.Optimizer):
|
|
38
|
-
optimizer_fn = optimizer
|
|
39
|
-
else:
|
|
40
|
-
raise TypeError(f"Invalid optimizer type: {type(optimizer)}")
|
|
41
|
-
return optimizer_fn
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def get_scheduler(
|
|
45
|
-
scheduler: (
|
|
46
|
-
str
|
|
47
|
-
| torch.optim.lr_scheduler._LRScheduler
|
|
48
|
-
| torch.optim.lr_scheduler.LRScheduler
|
|
49
|
-
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
50
|
-
| type[torch.optim.lr_scheduler.LRScheduler]
|
|
51
|
-
| None
|
|
52
|
-
),
|
|
53
|
-
optimizer,
|
|
54
|
-
**scheduler_params,
|
|
55
|
-
):
|
|
56
|
-
if isinstance(scheduler, str):
|
|
57
|
-
if scheduler == "step":
|
|
58
|
-
scheduler_fn = torch.optim.lr_scheduler.StepLR(
|
|
59
|
-
optimizer, **scheduler_params
|
|
60
|
-
)
|
|
61
|
-
elif scheduler == "cosine":
|
|
62
|
-
scheduler_fn = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
63
|
-
optimizer, **scheduler_params
|
|
64
|
-
)
|
|
65
|
-
else:
|
|
66
|
-
raise NotImplementedError(f"Unsupported scheduler: {scheduler}")
|
|
67
|
-
elif isinstance(
|
|
68
|
-
scheduler,
|
|
69
|
-
(torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler),
|
|
70
|
-
):
|
|
71
|
-
scheduler_fn = scheduler
|
|
72
|
-
else:
|
|
73
|
-
raise TypeError(f"Invalid scheduler type: {type(scheduler)}")
|
|
74
|
-
|
|
75
|
-
return scheduler_fn
|
nextrec/utils/tensor.py
DELETED
|
@@ -1,72 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Tensor manipulation utilities for NextRec
|
|
3
|
-
|
|
4
|
-
Date: create on 03/12/2025
|
|
5
|
-
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def to_tensor(
|
|
13
|
-
value: Any, dtype: torch.dtype, device: torch.device | str | None = None
|
|
14
|
-
) -> torch.Tensor:
|
|
15
|
-
if value is None:
|
|
16
|
-
raise ValueError("[Tensor Utils Error] Cannot convert None to tensor.")
|
|
17
|
-
tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
|
|
18
|
-
if tensor.dtype != dtype:
|
|
19
|
-
tensor = tensor.to(dtype=dtype)
|
|
20
|
-
|
|
21
|
-
if device is not None:
|
|
22
|
-
target_device = (
|
|
23
|
-
device if isinstance(device, torch.device) else torch.device(device)
|
|
24
|
-
)
|
|
25
|
-
if tensor.device != target_device:
|
|
26
|
-
tensor = tensor.to(target_device)
|
|
27
|
-
return tensor
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Stack equally shaped tensors along a new dimension ``dim`` using ``torch.stack``.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    if tensors:
        return torch.stack(tensors, dim=dim)
    raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Concatenate tensors along the existing dimension ``dim`` using ``torch.cat``.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    if tensors:
        return torch.cat(tensors, dim=dim)
    raise ValueError(
        "[Tensor Utils Error] Cannot concatenate empty list of tensors."
    )
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def pad_sequence_tensors(
|
|
45
|
-
tensors: list[torch.Tensor],
|
|
46
|
-
max_len: int | None = None,
|
|
47
|
-
padding_value: float = 0.0,
|
|
48
|
-
padding_side: str = "right",
|
|
49
|
-
) -> torch.Tensor:
|
|
50
|
-
if not tensors:
|
|
51
|
-
raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
|
|
52
|
-
if max_len is None:
|
|
53
|
-
max_len = max(t.size(0) for t in tensors)
|
|
54
|
-
batch_size = len(tensors)
|
|
55
|
-
padded = torch.full(
|
|
56
|
-
(batch_size, max_len),
|
|
57
|
-
padding_value,
|
|
58
|
-
dtype=tensors[0].dtype,
|
|
59
|
-
device=tensors[0].device,
|
|
60
|
-
)
|
|
61
|
-
|
|
62
|
-
for i, tensor in enumerate(tensors):
|
|
63
|
-
length = min(tensor.size(0), max_len)
|
|
64
|
-
if padding_side == "right":
|
|
65
|
-
padded[i, :length] = tensor[:length]
|
|
66
|
-
elif padding_side == "left":
|
|
67
|
-
padded[i, -length:] = tensor[:length]
|
|
68
|
-
else:
|
|
69
|
-
raise ValueError(
|
|
70
|
-
f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}"
|
|
71
|
-
)
|
|
72
|
-
return padded
|