nextrec 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +5 -1
- nextrec/basic/layers.py +3 -7
- nextrec/basic/model.py +495 -664
- nextrec/data/data_utils.py +44 -12
- nextrec/data/dataloader.py +84 -285
- nextrec/data/preprocessor.py +91 -213
- nextrec/loss/__init__.py +0 -1
- nextrec/loss/loss_utils.py +51 -120
- nextrec/models/multi_task/esmm.py +1 -1
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/__init__.py +4 -1
- nextrec/utils/common.py +16 -0
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/METADATA +2 -2
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/RECORD +17 -16
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/WHEEL +0 -0
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py
CHANGED
nextrec/loss/loss_utils.py
CHANGED
|
@@ -21,138 +21,69 @@ from nextrec.loss.pointwise import (
|
|
|
21
21
|
WeightedBCELoss,
|
|
22
22
|
)
|
|
23
23
|
|
|
24
|
-
|
|
24
|
+
|
|
25
25
|
VALID_TASK_TYPES = [
|
|
26
|
-
"binary",
|
|
27
|
-
"multiclass",
|
|
28
|
-
"
|
|
29
|
-
"
|
|
30
|
-
"match",
|
|
31
|
-
"ranking",
|
|
32
|
-
"multitask",
|
|
33
|
-
"multilabel",
|
|
26
|
+
"binary",
|
|
27
|
+
"multiclass",
|
|
28
|
+
"multilabel",
|
|
29
|
+
"regression",
|
|
34
30
|
]
|
|
35
31
|
|
|
32
|
+
def _build_cb_focal(kw):
|
|
33
|
+
if "class_counts" not in kw:
|
|
34
|
+
raise ValueError("class_balanced_focal requires class_counts")
|
|
35
|
+
return ClassBalancedFocalLoss(**kw)
|
|
36
36
|
|
|
37
|
-
def get_loss_fn(
|
|
38
|
-
task_type: str = "binary",
|
|
39
|
-
training_mode: str | None = None,
|
|
40
|
-
loss: str | nn.Module | None = None,
|
|
41
|
-
**loss_kwargs,
|
|
42
|
-
) -> nn.Module:
|
|
43
|
-
"""
|
|
44
|
-
Get loss function based on task type and training mode.
|
|
45
|
-
"""
|
|
46
37
|
|
|
38
|
+
def get_loss_fn(loss=None, **kw):
|
|
47
39
|
if isinstance(loss, nn.Module):
|
|
48
40
|
return loss
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
if
|
|
52
|
-
return
|
|
53
|
-
|
|
54
|
-
if task_type in ["ranking", "multitask", "binary", "multilabel"]:
|
|
55
|
-
return _get_classification_loss(loss, **loss_kwargs)
|
|
56
|
-
|
|
57
|
-
if task_type == "multiclass":
|
|
58
|
-
return _get_multiclass_loss(loss, **loss_kwargs)
|
|
59
|
-
|
|
60
|
-
if task_type == "regression":
|
|
61
|
-
if loss is None or loss == "mse":
|
|
62
|
-
return nn.MSELoss(**loss_kwargs)
|
|
63
|
-
if loss == "mae":
|
|
64
|
-
return nn.L1Loss(**loss_kwargs)
|
|
65
|
-
if isinstance(loss, str):
|
|
66
|
-
raise ValueError(f"Unsupported regression loss: {loss}")
|
|
67
|
-
|
|
68
|
-
raise ValueError(f"Unsupported task_type: {task_type}")
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
|
|
72
|
-
if training_mode == "pointwise":
|
|
73
|
-
if loss is None or loss in {"bce", "binary_crossentropy"}:
|
|
74
|
-
return nn.BCELoss(**loss_kwargs)
|
|
75
|
-
if loss == "weighted_bce":
|
|
76
|
-
return WeightedBCELoss(**loss_kwargs)
|
|
77
|
-
if loss == "focal":
|
|
78
|
-
return FocalLoss(**loss_kwargs)
|
|
79
|
-
if loss == "class_balanced_focal":
|
|
80
|
-
return _build_cb_focal(loss_kwargs)
|
|
81
|
-
if loss == "cosine_contrastive":
|
|
82
|
-
return CosineContrastiveLoss(**loss_kwargs)
|
|
83
|
-
if isinstance(loss, str):
|
|
84
|
-
raise ValueError(f"Unsupported pointwise loss: {loss}")
|
|
85
|
-
|
|
86
|
-
if training_mode == "pairwise":
|
|
87
|
-
if loss is None or loss == "bpr":
|
|
88
|
-
return BPRLoss(**loss_kwargs)
|
|
89
|
-
if loss == "hinge":
|
|
90
|
-
return HingeLoss(**loss_kwargs)
|
|
91
|
-
if loss == "triplet":
|
|
92
|
-
return TripletLoss(**loss_kwargs)
|
|
93
|
-
if isinstance(loss, str):
|
|
94
|
-
raise ValueError(f"Unsupported pairwise loss: {loss}")
|
|
95
|
-
|
|
96
|
-
if training_mode == "listwise":
|
|
97
|
-
if loss is None or loss in {"sampled_softmax", "softmax"}:
|
|
98
|
-
return SampledSoftmaxLoss(**loss_kwargs)
|
|
99
|
-
if loss == "infonce":
|
|
100
|
-
return InfoNCELoss(**loss_kwargs)
|
|
101
|
-
if loss == "listnet":
|
|
102
|
-
return ListNetLoss(**loss_kwargs)
|
|
103
|
-
if loss == "listmle":
|
|
104
|
-
return ListMLELoss(**loss_kwargs)
|
|
105
|
-
if loss == "approx_ndcg":
|
|
106
|
-
return ApproxNDCGLoss(**loss_kwargs)
|
|
107
|
-
if loss in {"crossentropy", "ce"}:
|
|
108
|
-
return nn.CrossEntropyLoss(**loss_kwargs)
|
|
109
|
-
if isinstance(loss, str):
|
|
110
|
-
raise ValueError(f"Unsupported listwise loss: {loss}")
|
|
111
|
-
|
|
112
|
-
raise ValueError(f"Unknown training_mode: {training_mode}")
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
|
|
116
|
-
if loss is None or loss in {"bce", "binary_crossentropy"}:
|
|
117
|
-
return nn.BCELoss(**loss_kwargs)
|
|
41
|
+
if loss is None:
|
|
42
|
+
raise ValueError("loss must be provided explicitly")
|
|
43
|
+
if loss in ["bce", "binary_crossentropy"]:
|
|
44
|
+
return nn.BCELoss(**kw)
|
|
118
45
|
if loss == "weighted_bce":
|
|
119
|
-
return WeightedBCELoss(**
|
|
120
|
-
if loss
|
|
121
|
-
return FocalLoss(**
|
|
122
|
-
if loss
|
|
123
|
-
return _build_cb_focal(
|
|
46
|
+
return WeightedBCELoss(**kw)
|
|
47
|
+
if loss in ["focal", "focal_loss"]:
|
|
48
|
+
return FocalLoss(**kw)
|
|
49
|
+
if loss in ["cb_focal", "class_balanced_focal"]:
|
|
50
|
+
return _build_cb_focal(kw)
|
|
51
|
+
if loss in ["crossentropy", "ce"]:
|
|
52
|
+
return nn.CrossEntropyLoss(**kw)
|
|
124
53
|
if loss == "mse":
|
|
125
|
-
return nn.MSELoss(**
|
|
54
|
+
return nn.MSELoss(**kw)
|
|
126
55
|
if loss == "mae":
|
|
127
|
-
return nn.L1Loss(**
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
if
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
if loss
|
|
139
|
-
return
|
|
140
|
-
if loss == "
|
|
141
|
-
return
|
|
142
|
-
if
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
return ClassBalancedFocalLoss(**loss_kwargs)
|
|
151
|
-
|
|
56
|
+
return nn.L1Loss(**kw)
|
|
57
|
+
|
|
58
|
+
# Pairwise ranking Loss
|
|
59
|
+
if loss == "bpr":
|
|
60
|
+
return BPRLoss(**kw)
|
|
61
|
+
if loss == "hinge":
|
|
62
|
+
return HingeLoss(**kw)
|
|
63
|
+
if loss == "triplet":
|
|
64
|
+
return TripletLoss(**kw)
|
|
65
|
+
|
|
66
|
+
# Listwise ranking Loss
|
|
67
|
+
if loss in ["sampled_softmax", "softmax"]:
|
|
68
|
+
return SampledSoftmaxLoss(**kw)
|
|
69
|
+
if loss == "infonce":
|
|
70
|
+
return InfoNCELoss(**kw)
|
|
71
|
+
if loss == "listnet":
|
|
72
|
+
return ListNetLoss(**kw)
|
|
73
|
+
if loss == "listmle":
|
|
74
|
+
return ListMLELoss(**kw)
|
|
75
|
+
if loss == "approx_ndcg":
|
|
76
|
+
return ApproxNDCGLoss(**kw)
|
|
77
|
+
|
|
78
|
+
raise ValueError(f"Unsupported loss: {loss}")
|
|
152
79
|
|
|
153
80
|
def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
|
|
154
81
|
"""
|
|
155
|
-
|
|
82
|
+
解析每个 head 对应的 loss_kwargs。
|
|
83
|
+
|
|
84
|
+
- loss_params 为 None -> {}
|
|
85
|
+
- loss_params 为 dict -> 所有 head 共用
|
|
86
|
+
- loss_params 为 list[dict] -> 用 loss_params[index](若存在且非 None),否则 {}
|
|
156
87
|
"""
|
|
157
88
|
if loss_params is None:
|
|
158
89
|
return {}
|
|
@@ -160,4 +91,4 @@ def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> di
|
|
|
160
91
|
if index < len(loss_params) and loss_params[index] is not None:
|
|
161
92
|
return loss_params[index]
|
|
162
93
|
return {}
|
|
163
|
-
return loss_params
|
|
94
|
+
return loss_params
|
|
@@ -40,7 +40,7 @@ class ESMM(BaseModel):
|
|
|
40
40
|
ctr_params: dict,
|
|
41
41
|
cvr_params: dict,
|
|
42
42
|
target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
|
|
43
|
-
task:
|
|
43
|
+
task: list[str] = ['binary', 'binary'],
|
|
44
44
|
optimizer: str = "adam",
|
|
45
45
|
optimizer_params: dict = {},
|
|
46
46
|
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
nextrec/utils/__init__.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
|
1
1
|
from .optimizer import get_optimizer, get_scheduler
|
|
2
2
|
from .initializer import get_initializer
|
|
3
3
|
from .embedding import get_auto_embedding_dim
|
|
4
|
-
from . import
|
|
4
|
+
from .common import resolve_device
|
|
5
|
+
from . import optimizer, initializer, embedding, common
|
|
5
6
|
|
|
6
7
|
__all__ = [
|
|
7
8
|
'get_optimizer',
|
|
8
9
|
'get_scheduler',
|
|
9
10
|
'get_initializer',
|
|
10
11
|
'get_auto_embedding_dim',
|
|
12
|
+
'resolve_device',
|
|
11
13
|
'optimizer',
|
|
12
14
|
'initializer',
|
|
13
15
|
'embedding',
|
|
16
|
+
'common',
|
|
14
17
|
]
|
nextrec/utils/common.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import platform
|
|
3
|
+
|
|
4
|
+
def resolve_device() -> str:
|
|
5
|
+
"""Select a usable device with graceful fallback."""
|
|
6
|
+
if torch.cuda.is_available():
|
|
7
|
+
return "cuda"
|
|
8
|
+
if torch.backends.mps.is_available():
|
|
9
|
+
mac_ver = platform.mac_ver()[0]
|
|
10
|
+
try:
|
|
11
|
+
major, minor = (int(x) for x in mac_ver.split(".")[:2])
|
|
12
|
+
except Exception:
|
|
13
|
+
major, minor = 0, 0
|
|
14
|
+
if major >= 14:
|
|
15
|
+
return "mps"
|
|
16
|
+
return "cpu"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -61,7 +61,7 @@ Description-Content-Type: text/markdown
|
|
|
61
61
|

|
|
62
62
|

|
|
63
63
|

|
|
64
|
-

|
|
65
65
|
|
|
66
66
|
English | [中文版](README_zh.md)
|
|
67
67
|
|
|
@@ -1,21 +1,21 @@
|
|
|
1
1
|
nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
|
|
2
|
-
nextrec/__version__.py,sha256=
|
|
2
|
+
nextrec/__version__.py,sha256=Xsa3ayOMVkhUWm4t06YeyHE0apjpZefxLH4ylp0CDtU,22
|
|
3
3
|
nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
nextrec/basic/activation.py,sha256=9EfYmwE0brTSKwx_0FIGQ_rybFBT9n_G-UWA7NAhMsI,2804
|
|
5
5
|
nextrec/basic/callback.py,sha256=qkq3k8rP0g4BW2C3FSCdVt_CyCcJwJ-rUXjhT2p4LP8,1035
|
|
6
|
-
nextrec/basic/features.py,sha256=
|
|
7
|
-
nextrec/basic/layers.py,sha256=
|
|
6
|
+
nextrec/basic/features.py,sha256=pQyqs-hGFN8T6goTxIjaXxw9S5PG57OWX7EpsKFlb4c,4194
|
|
7
|
+
nextrec/basic/layers.py,sha256=AezH_WYvU45eLw5EHmyAC69MCUkRTv_ZKbNW4WBC5iE,38071
|
|
8
8
|
nextrec/basic/loggers.py,sha256=x8lzyyK-uqBN5XGOm1Cb33dmfc2bl114n6QeFTtE54k,3752
|
|
9
9
|
nextrec/basic/metrics.py,sha256=w8tGe2tTbBNz9A1TNZF3jSpxcNC6QvFP5I0lWRd0Nw4,20398
|
|
10
|
-
nextrec/basic/model.py,sha256=
|
|
10
|
+
nextrec/basic/model.py,sha256=k9dbV4CP-1wvr-QLJ0dF6nYtGugXQIFR6J8kZDS9iSs,63968
|
|
11
11
|
nextrec/basic/session.py,sha256=2kogEjgKAN1_ygelbwoqOs187BAcUnDTqXG1w_Pgb9I,4791
|
|
12
12
|
nextrec/data/__init__.py,sha256=HLnARJrqDEVPTcofPSAEimy2Oj15vbomj-7UvT4ze_4,767
|
|
13
|
-
nextrec/data/data_utils.py,sha256=
|
|
14
|
-
nextrec/data/dataloader.py,sha256=
|
|
15
|
-
nextrec/data/preprocessor.py,sha256=
|
|
16
|
-
nextrec/loss/__init__.py,sha256=
|
|
13
|
+
nextrec/data/data_utils.py,sha256=xz0xVBA7UzHXz7r_Yf0eMB5RrarPKg_1ZTdWvAqRZCM,7623
|
|
14
|
+
nextrec/data/dataloader.py,sha256=vtgt2B7rUmIG7wg-HE2ZesBaD6cuS2PwklFCWGA9tCw,14142
|
|
15
|
+
nextrec/data/preprocessor.py,sha256=J-3fo_LIz100spqCHoSpewYcneiZwhaCKyRdroPSjeY,41548
|
|
16
|
+
nextrec/loss/__init__.py,sha256=mO5t417BneZ8Ysa51GyjDaffjWyjzFgPXIQrrggasaQ,827
|
|
17
17
|
nextrec/loss/listwise.py,sha256=LcYIPf6PGRtjV_AoWaAyp3rse904S2MghE5t032I07I,5628
|
|
18
|
-
nextrec/loss/loss_utils.py,sha256=
|
|
18
|
+
nextrec/loss/loss_utils.py,sha256=cFNSvv-eaFwcfjLgxN3yNmf0L7ofC0ysgkUYjliLBpE,2535
|
|
19
19
|
nextrec/loss/pairwise.py,sha256=RuQuTE-EkLaHQvT9m0CTAXxneTnVQLF1Pi9wblEClI8,3289
|
|
20
20
|
nextrec/loss/pointwise.py,sha256=6QveizdohzQTxAoBKTVSoCBpp-fy3JC8vCjImXa7jL0,7157
|
|
21
21
|
nextrec/models/generative/hstu.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -26,7 +26,7 @@ nextrec/models/match/dssm_v2.py,sha256=ywtqTy3YN9ke_7kzcDp7Fhtldw9RJz6yfewxALJb6
|
|
|
26
26
|
nextrec/models/match/mind.py,sha256=XSUDlZ-V95JXHHBDUl5sz99SaVuQKDvf3TArVjwUexs,9417
|
|
27
27
|
nextrec/models/match/sdm.py,sha256=96yfMQ6arP6JRhAkDTGEjlBiTteznMykrDV_3jqvvVk,10920
|
|
28
28
|
nextrec/models/match/youtube_dnn.py,sha256=pnrz9LYu65Fj4neOriFF45B5k2-yYiiREtQICxxYXZ0,7546
|
|
29
|
-
nextrec/models/multi_task/esmm.py,sha256=
|
|
29
|
+
nextrec/models/multi_task/esmm.py,sha256=U9DkYxwAhD_uWB8H2tsx1zfIStp5Xp8bvM6sc8S9tu4,4889
|
|
30
30
|
nextrec/models/multi_task/mmoe.py,sha256=zhQr43Vfz7Kgi6B9pKPmaenp_38a_D7w4VvlpwCyF6Y,6165
|
|
31
31
|
nextrec/models/multi_task/ple.py,sha256=otP6oLgzrJhwkLFItzNE-AtIPouObDkafRvWzTCxfNo,11335
|
|
32
32
|
nextrec/models/multi_task/share_bottom.py,sha256=LL5HBVlvvBzHV2fLBRQMGIwpqmlxILTgU4c51XyTCo4,4517
|
|
@@ -39,15 +39,16 @@ nextrec/models/ranking/dien.py,sha256=E6s9TDwQfGSwtzzh8hG2F5gwgVxzVZPcptYvHLNzOL
|
|
|
39
39
|
nextrec/models/ranking/din.py,sha256=j5tkT5k91CbsMlMr5vJOySrcY2_rFGxmEgJJ0McW7-Q,7196
|
|
40
40
|
nextrec/models/ranking/fibinet.py,sha256=X6CbQbritvq5jql_Tvs4bn_tRla2zpWPplftZv8k6f0,4853
|
|
41
41
|
nextrec/models/ranking/fm.py,sha256=3Qx_Fgowegr6UPQtEeTmHtOrbWzkvqH94ZTjOqRLu-E,2961
|
|
42
|
-
nextrec/models/ranking/masknet.py,sha256=
|
|
42
|
+
nextrec/models/ranking/masknet.py,sha256=Tx5deIv7oShm4DdXX1IJL8Hz8-5uGqcPMK7pj00xTHg,12230
|
|
43
43
|
nextrec/models/ranking/pnn.py,sha256=5RxIKdxD0XcGq-b_QDdwGRwk6b_5BQjyMvCw3Ibv2Kk,4957
|
|
44
44
|
nextrec/models/ranking/widedeep.py,sha256=b6ctElaZPv5WSYDA4piYUBo3je0eJpWpWECwcuWavM4,3716
|
|
45
45
|
nextrec/models/ranking/xdeepfm.py,sha256=I00J5tfE4tPluqeW-qrNtE4V_9fC7-rgFvA0Fxqka7o,4274
|
|
46
|
-
nextrec/utils/__init__.py,sha256=
|
|
46
|
+
nextrec/utils/__init__.py,sha256=A3mH6M-DmDBWQ1stIIaTsNzvUy_AKaUWtRmrzU5R3FE,429
|
|
47
|
+
nextrec/utils/common.py,sha256=-n4wSbP-EptpzLcJv6fV-ytBzPliOj6m-mrK_Qk6s4A,458
|
|
47
48
|
nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
|
|
48
49
|
nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
|
|
49
50
|
nextrec/utils/optimizer.py,sha256=85ifoy2IQgjPHOqLqr1ho7XBGE_0ry1yEB9efS6C2lM,2446
|
|
50
|
-
nextrec-0.2.
|
|
51
|
-
nextrec-0.2.
|
|
52
|
-
nextrec-0.2.
|
|
53
|
-
nextrec-0.2.
|
|
51
|
+
nextrec-0.2.5.dist-info/METADATA,sha256=Ya8KTj9x1ozIaciXYTKnTFBLGiC4buUBIz-jVHHAM3s,11425
|
|
52
|
+
nextrec-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
53
|
+
nextrec-0.2.5.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
|
|
54
|
+
nextrec-0.2.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|