nextrec 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/loss/__init__.py CHANGED
@@ -37,6 +37,5 @@ __all__ = [
37
37
  # Utilities
38
38
  "get_loss_fn",
39
39
  "get_loss_kwargs",
40
- "validate_training_mode",
41
40
  "VALID_TASK_TYPES",
42
41
  ]
@@ -21,138 +21,69 @@ from nextrec.loss.pointwise import (
21
21
  WeightedBCELoss,
22
22
  )
23
23
 
24
- # Valid task types for validation
24
+
25
25
  VALID_TASK_TYPES = [
26
- "binary",
27
- "multiclass",
28
- "regression",
29
- "multivariate_regression",
30
- "match",
31
- "ranking",
32
- "multitask",
33
- "multilabel",
26
+ "binary",
27
+ "multiclass",
28
+ "multilabel",
29
+ "regression",
34
30
  ]
35
31
 
32
+ def _build_cb_focal(kw):
33
+ if "class_counts" not in kw:
34
+ raise ValueError("class_balanced_focal requires class_counts")
35
+ return ClassBalancedFocalLoss(**kw)
36
36
 
37
- def get_loss_fn(
38
- task_type: str = "binary",
39
- training_mode: str | None = None,
40
- loss: str | nn.Module | None = None,
41
- **loss_kwargs,
42
- ) -> nn.Module:
43
- """
44
- Get loss function based on task type and training mode.
45
- """
46
37
 
38
+ def get_loss_fn(loss=None, **kw):
47
39
  if isinstance(loss, nn.Module):
48
40
  return loss
49
-
50
- # Common mappings
51
- if task_type == "match":
52
- return _get_match_loss(training_mode, loss, **loss_kwargs)
53
-
54
- if task_type in ["ranking", "multitask", "binary", "multilabel"]:
55
- return _get_classification_loss(loss, **loss_kwargs)
56
-
57
- if task_type == "multiclass":
58
- return _get_multiclass_loss(loss, **loss_kwargs)
59
-
60
- if task_type == "regression":
61
- if loss is None or loss == "mse":
62
- return nn.MSELoss(**loss_kwargs)
63
- if loss == "mae":
64
- return nn.L1Loss(**loss_kwargs)
65
- if isinstance(loss, str):
66
- raise ValueError(f"Unsupported regression loss: {loss}")
67
-
68
- raise ValueError(f"Unsupported task_type: {task_type}")
69
-
70
-
71
- def _get_match_loss(training_mode: str | None, loss: str | None, **loss_kwargs) -> nn.Module:
72
- if training_mode == "pointwise":
73
- if loss is None or loss in {"bce", "binary_crossentropy"}:
74
- return nn.BCELoss(**loss_kwargs)
75
- if loss == "weighted_bce":
76
- return WeightedBCELoss(**loss_kwargs)
77
- if loss == "focal":
78
- return FocalLoss(**loss_kwargs)
79
- if loss == "class_balanced_focal":
80
- return _build_cb_focal(loss_kwargs)
81
- if loss == "cosine_contrastive":
82
- return CosineContrastiveLoss(**loss_kwargs)
83
- if isinstance(loss, str):
84
- raise ValueError(f"Unsupported pointwise loss: {loss}")
85
-
86
- if training_mode == "pairwise":
87
- if loss is None or loss == "bpr":
88
- return BPRLoss(**loss_kwargs)
89
- if loss == "hinge":
90
- return HingeLoss(**loss_kwargs)
91
- if loss == "triplet":
92
- return TripletLoss(**loss_kwargs)
93
- if isinstance(loss, str):
94
- raise ValueError(f"Unsupported pairwise loss: {loss}")
95
-
96
- if training_mode == "listwise":
97
- if loss is None or loss in {"sampled_softmax", "softmax"}:
98
- return SampledSoftmaxLoss(**loss_kwargs)
99
- if loss == "infonce":
100
- return InfoNCELoss(**loss_kwargs)
101
- if loss == "listnet":
102
- return ListNetLoss(**loss_kwargs)
103
- if loss == "listmle":
104
- return ListMLELoss(**loss_kwargs)
105
- if loss == "approx_ndcg":
106
- return ApproxNDCGLoss(**loss_kwargs)
107
- if loss in {"crossentropy", "ce"}:
108
- return nn.CrossEntropyLoss(**loss_kwargs)
109
- if isinstance(loss, str):
110
- raise ValueError(f"Unsupported listwise loss: {loss}")
111
-
112
- raise ValueError(f"Unknown training_mode: {training_mode}")
113
-
114
-
115
- def _get_classification_loss(loss: str | None, **loss_kwargs) -> nn.Module:
116
- if loss is None or loss in {"bce", "binary_crossentropy"}:
117
- return nn.BCELoss(**loss_kwargs)
41
+ if loss is None:
42
+ raise ValueError("loss must be provided explicitly")
43
+ if loss in ["bce", "binary_crossentropy"]:
44
+ return nn.BCELoss(**kw)
118
45
  if loss == "weighted_bce":
119
- return WeightedBCELoss(**loss_kwargs)
120
- if loss == "focal":
121
- return FocalLoss(**loss_kwargs)
122
- if loss == "class_balanced_focal":
123
- return _build_cb_focal(loss_kwargs)
46
+ return WeightedBCELoss(**kw)
47
+ if loss in ["focal", "focal_loss"]:
48
+ return FocalLoss(**kw)
49
+ if loss in ["cb_focal", "class_balanced_focal"]:
50
+ return _build_cb_focal(kw)
51
+ if loss in ["crossentropy", "ce"]:
52
+ return nn.CrossEntropyLoss(**kw)
124
53
  if loss == "mse":
125
- return nn.MSELoss(**loss_kwargs)
54
+ return nn.MSELoss(**kw)
126
55
  if loss == "mae":
127
- return nn.L1Loss(**loss_kwargs)
128
- if loss in {"crossentropy", "ce"}:
129
- return nn.CrossEntropyLoss(**loss_kwargs)
130
- if isinstance(loss, str):
131
- raise ValueError(f"Unsupported loss function: {loss}")
132
- raise ValueError("Loss must be specified for classification task.")
133
-
134
-
135
- def _get_multiclass_loss(loss: str | None, **loss_kwargs) -> nn.Module:
136
- if loss is None or loss in {"crossentropy", "ce"}:
137
- return nn.CrossEntropyLoss(**loss_kwargs)
138
- if loss == "focal":
139
- return FocalLoss(**loss_kwargs)
140
- if loss == "class_balanced_focal":
141
- return _build_cb_focal(loss_kwargs)
142
- if isinstance(loss, str):
143
- raise ValueError(f"Unsupported multiclass loss: {loss}")
144
- raise ValueError("Loss must be specified for multiclass task.")
145
-
146
-
147
- def _build_cb_focal(loss_kwargs: dict) -> ClassBalancedFocalLoss:
148
- if "class_counts" not in loss_kwargs:
149
- raise ValueError("class_balanced_focal requires `class_counts` argument.")
150
- return ClassBalancedFocalLoss(**loss_kwargs)
151
-
56
+ return nn.L1Loss(**kw)
57
+
58
+ # Pairwise ranking Loss
59
+ if loss == "bpr":
60
+ return BPRLoss(**kw)
61
+ if loss == "hinge":
62
+ return HingeLoss(**kw)
63
+ if loss == "triplet":
64
+ return TripletLoss(**kw)
65
+
66
+ # Listwise ranking Loss
67
+ if loss in ["sampled_softmax", "softmax"]:
68
+ return SampledSoftmaxLoss(**kw)
69
+ if loss == "infonce":
70
+ return InfoNCELoss(**kw)
71
+ if loss == "listnet":
72
+ return ListNetLoss(**kw)
73
+ if loss == "listmle":
74
+ return ListMLELoss(**kw)
75
+ if loss == "approx_ndcg":
76
+ return ApproxNDCGLoss(**kw)
77
+
78
+ raise ValueError(f"Unsupported loss: {loss}")
152
79
 
153
80
  def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
154
81
  """
155
- Resolve per-task loss kwargs from a dict or list of dicts.
82
+ 解析每个 head 对应的 loss_kwargs。
83
+
84
+ - loss_params 为 None -> {}
85
+ - loss_params 为 dict -> 所有 head 共用
86
+ - loss_params 为 list[dict] -> 用 loss_params[index](若存在且非 None),否则 {}
156
87
  """
157
88
  if loss_params is None:
158
89
  return {}
@@ -160,4 +91,4 @@ def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> di
160
91
  if index < len(loss_params) and loss_params[index] is not None:
161
92
  return loss_params[index]
162
93
  return {}
163
- return loss_params
94
+ return loss_params
@@ -40,7 +40,7 @@ class ESMM(BaseModel):
40
40
  ctr_params: dict,
41
41
  cvr_params: dict,
42
42
  target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
43
- task: str | list[str] = 'binary',
43
+ task: list[str] = ['binary', 'binary'],
44
44
  optimizer: str = "adam",
45
45
  optimizer_params: dict = {},
46
46
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -144,7 +144,7 @@ class MaskNet(BaseModel):
144
144
 
145
145
  @property
146
146
  def task_type(self):
147
- return "binary"
147
+ return "binary_classification"
148
148
 
149
149
  def __init__(
150
150
  self,
nextrec/utils/__init__.py CHANGED
@@ -1,14 +1,17 @@
1
1
  from .optimizer import get_optimizer, get_scheduler
2
2
  from .initializer import get_initializer
3
3
  from .embedding import get_auto_embedding_dim
4
- from . import optimizer, initializer, embedding
4
+ from .common import resolve_device
5
+ from . import optimizer, initializer, embedding, common
5
6
 
6
7
  __all__ = [
7
8
  'get_optimizer',
8
9
  'get_scheduler',
9
10
  'get_initializer',
10
11
  'get_auto_embedding_dim',
12
+ 'resolve_device',
11
13
  'optimizer',
12
14
  'initializer',
13
15
  'embedding',
16
+ 'common',
14
17
  ]
@@ -0,0 +1,16 @@
1
+ import torch
2
+ import platform
3
+
4
+ def resolve_device() -> str:
5
+ """Select a usable device with graceful fallback."""
6
+ if torch.cuda.is_available():
7
+ return "cuda"
8
+ if torch.backends.mps.is_available():
9
+ mac_ver = platform.mac_ver()[0]
10
+ try:
11
+ major, minor = (int(x) for x in mac_ver.split(".")[:2])
12
+ except Exception:
13
+ major, minor = 0, 0
14
+ if major >= 14:
15
+ return "mps"
16
+ return "cpu"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.2.4
3
+ Version: 0.2.5
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -61,7 +61,7 @@ Description-Content-Type: text/markdown
61
61
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
62
62
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
63
63
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
64
- ![Version](https://img.shields.io/badge/Version-0.2.4-orange.svg)
64
+ ![Version](https://img.shields.io/badge/Version-0.2.5-orange.svg)
65
65
 
66
66
  English | [中文版](README_zh.md)
67
67
 
@@ -1,21 +1,21 @@
1
1
  nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
2
- nextrec/__version__.py,sha256=SBl2EPFW-ltPvQ7vbVWItyAsz3aKYIpjO7vcfr84GkU,22
2
+ nextrec/__version__.py,sha256=Xsa3ayOMVkhUWm4t06YeyHE0apjpZefxLH4ylp0CDtU,22
3
3
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  nextrec/basic/activation.py,sha256=9EfYmwE0brTSKwx_0FIGQ_rybFBT9n_G-UWA7NAhMsI,2804
5
5
  nextrec/basic/callback.py,sha256=qkq3k8rP0g4BW2C3FSCdVt_CyCcJwJ-rUXjhT2p4LP8,1035
6
- nextrec/basic/features.py,sha256=TJfJgzWuy68lBKOeCzztcUK3ZtjHhK8oSMs8k0vXGlg,3961
7
- nextrec/basic/layers.py,sha256=mDNApSlPkmPSnIPj3BDHfDEjviLybWuSGrh61Zog2uk,38290
6
+ nextrec/basic/features.py,sha256=pQyqs-hGFN8T6goTxIjaXxw9S5PG57OWX7EpsKFlb4c,4194
7
+ nextrec/basic/layers.py,sha256=AezH_WYvU45eLw5EHmyAC69MCUkRTv_ZKbNW4WBC5iE,38071
8
8
  nextrec/basic/loggers.py,sha256=x8lzyyK-uqBN5XGOm1Cb33dmfc2bl114n6QeFTtE54k,3752
9
9
  nextrec/basic/metrics.py,sha256=w8tGe2tTbBNz9A1TNZF3jSpxcNC6QvFP5I0lWRd0Nw4,20398
10
- nextrec/basic/model.py,sha256=mtVttDWmrhdW-L1PAelJ90a1BW0q6bzG9roMvrPTU0U,66342
10
+ nextrec/basic/model.py,sha256=k9dbV4CP-1wvr-QLJ0dF6nYtGugXQIFR6J8kZDS9iSs,63968
11
11
  nextrec/basic/session.py,sha256=2kogEjgKAN1_ygelbwoqOs187BAcUnDTqXG1w_Pgb9I,4791
12
12
  nextrec/data/__init__.py,sha256=HLnARJrqDEVPTcofPSAEimy2Oj15vbomj-7UvT4ze_4,767
13
- nextrec/data/data_utils.py,sha256=vGZ378YM_JQXO9npRB7JqojJx1ovjbJCWI-7lQJkicA,6298
14
- nextrec/data/dataloader.py,sha256=LAKpcSHhq53scq8PKwF8uqxa8wQLG0FshjY3TQwIvBU,20459
15
- nextrec/data/preprocessor.py,sha256=N7m4PYGZE6AND0XyYRvXKYAUub9aHGb1qmxbBRxlZKA,42294
16
- nextrec/loss/__init__.py,sha256=t-wkqxcu5wdYlrb67-CxX9aOGom0CpMJK8Fe8KGDSEE,857
13
+ nextrec/data/data_utils.py,sha256=xz0xVBA7UzHXz7r_Yf0eMB5RrarPKg_1ZTdWvAqRZCM,7623
14
+ nextrec/data/dataloader.py,sha256=vtgt2B7rUmIG7wg-HE2ZesBaD6cuS2PwklFCWGA9tCw,14142
15
+ nextrec/data/preprocessor.py,sha256=J-3fo_LIz100spqCHoSpewYcneiZwhaCKyRdroPSjeY,41548
16
+ nextrec/loss/__init__.py,sha256=mO5t417BneZ8Ysa51GyjDaffjWyjzFgPXIQrrggasaQ,827
17
17
  nextrec/loss/listwise.py,sha256=LcYIPf6PGRtjV_AoWaAyp3rse904S2MghE5t032I07I,5628
18
- nextrec/loss/loss_utils.py,sha256=LnTkpMTS2bhbq4Lsjf3AUn1uBaOg1TaH5VO2R8hwARc,5324
18
+ nextrec/loss/loss_utils.py,sha256=cFNSvv-eaFwcfjLgxN3yNmf0L7ofC0ysgkUYjliLBpE,2535
19
19
  nextrec/loss/pairwise.py,sha256=RuQuTE-EkLaHQvT9m0CTAXxneTnVQLF1Pi9wblEClI8,3289
20
20
  nextrec/loss/pointwise.py,sha256=6QveizdohzQTxAoBKTVSoCBpp-fy3JC8vCjImXa7jL0,7157
21
21
  nextrec/models/generative/hstu.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -26,7 +26,7 @@ nextrec/models/match/dssm_v2.py,sha256=ywtqTy3YN9ke_7kzcDp7Fhtldw9RJz6yfewxALJb6
26
26
  nextrec/models/match/mind.py,sha256=XSUDlZ-V95JXHHBDUl5sz99SaVuQKDvf3TArVjwUexs,9417
27
27
  nextrec/models/match/sdm.py,sha256=96yfMQ6arP6JRhAkDTGEjlBiTteznMykrDV_3jqvvVk,10920
28
28
  nextrec/models/match/youtube_dnn.py,sha256=pnrz9LYu65Fj4neOriFF45B5k2-yYiiREtQICxxYXZ0,7546
29
- nextrec/models/multi_task/esmm.py,sha256=E9B6TlpnPUeyldTofyFg4B7SKByyxbiW2fUGHLOryO4,4883
29
+ nextrec/models/multi_task/esmm.py,sha256=U9DkYxwAhD_uWB8H2tsx1zfIStp5Xp8bvM6sc8S9tu4,4889
30
30
  nextrec/models/multi_task/mmoe.py,sha256=zhQr43Vfz7Kgi6B9pKPmaenp_38a_D7w4VvlpwCyF6Y,6165
31
31
  nextrec/models/multi_task/ple.py,sha256=otP6oLgzrJhwkLFItzNE-AtIPouObDkafRvWzTCxfNo,11335
32
32
  nextrec/models/multi_task/share_bottom.py,sha256=LL5HBVlvvBzHV2fLBRQMGIwpqmlxILTgU4c51XyTCo4,4517
@@ -39,15 +39,16 @@ nextrec/models/ranking/dien.py,sha256=E6s9TDwQfGSwtzzh8hG2F5gwgVxzVZPcptYvHLNzOL
39
39
  nextrec/models/ranking/din.py,sha256=j5tkT5k91CbsMlMr5vJOySrcY2_rFGxmEgJJ0McW7-Q,7196
40
40
  nextrec/models/ranking/fibinet.py,sha256=X6CbQbritvq5jql_Tvs4bn_tRla2zpWPplftZv8k6f0,4853
41
41
  nextrec/models/ranking/fm.py,sha256=3Qx_Fgowegr6UPQtEeTmHtOrbWzkvqH94ZTjOqRLu-E,2961
42
- nextrec/models/ranking/masknet.py,sha256=hU3m270vd9DWH2_Hh1hYiCGaF9fKC3eIsWQLSA-Gdf8,12215
42
+ nextrec/models/ranking/masknet.py,sha256=Tx5deIv7oShm4DdXX1IJL8Hz8-5uGqcPMK7pj00xTHg,12230
43
43
  nextrec/models/ranking/pnn.py,sha256=5RxIKdxD0XcGq-b_QDdwGRwk6b_5BQjyMvCw3Ibv2Kk,4957
44
44
  nextrec/models/ranking/widedeep.py,sha256=b6ctElaZPv5WSYDA4piYUBo3je0eJpWpWECwcuWavM4,3716
45
45
  nextrec/models/ranking/xdeepfm.py,sha256=I00J5tfE4tPluqeW-qrNtE4V_9fC7-rgFvA0Fxqka7o,4274
46
- nextrec/utils/__init__.py,sha256=6x3OZbqks2gtgJd00y_-Y8QiAT42x5t14ARHQ-ULQDo,350
46
+ nextrec/utils/__init__.py,sha256=A3mH6M-DmDBWQ1stIIaTsNzvUy_AKaUWtRmrzU5R3FE,429
47
+ nextrec/utils/common.py,sha256=-n4wSbP-EptpzLcJv6fV-ytBzPliOj6m-mrK_Qk6s4A,458
47
48
  nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
48
49
  nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
49
50
  nextrec/utils/optimizer.py,sha256=85ifoy2IQgjPHOqLqr1ho7XBGE_0ry1yEB9efS6C2lM,2446
50
- nextrec-0.2.4.dist-info/METADATA,sha256=VurhzAYPQ_PbBi6WJFHvgbbk08OVd_1udwLPHxqApag,11425
51
- nextrec-0.2.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
52
- nextrec-0.2.4.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
53
- nextrec-0.2.4.dist-info/RECORD,,
51
+ nextrec-0.2.5.dist-info/METADATA,sha256=Ya8KTj9x1ozIaciXYTKnTFBLGiC4buUBIz-jVHHAM3s,11425
52
+ nextrec-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
53
+ nextrec-0.2.5.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
54
+ nextrec-0.2.5.dist-info/RECORD,,