nextrec-0.2.1-py3-none-any.whl → nextrec-0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/layers.py +2 -2
  3. nextrec/basic/model.py +80 -47
  4. nextrec/loss/__init__.py +31 -24
  5. nextrec/loss/listwise.py +162 -4
  6. nextrec/loss/loss_utils.py +133 -105
  7. nextrec/loss/pairwise.py +103 -4
  8. nextrec/loss/pointwise.py +196 -4
  9. nextrec/models/match/dssm.py +24 -15
  10. nextrec/models/match/dssm_v2.py +18 -0
  11. nextrec/models/match/mind.py +16 -1
  12. nextrec/models/match/sdm.py +15 -0
  13. nextrec/models/match/youtube_dnn.py +21 -8
  14. nextrec/models/multi_task/esmm.py +5 -5
  15. nextrec/models/multi_task/mmoe.py +5 -5
  16. nextrec/models/multi_task/ple.py +5 -5
  17. nextrec/models/multi_task/share_bottom.py +5 -5
  18. nextrec/models/ranking/__init__.py +8 -0
  19. nextrec/models/ranking/afm.py +3 -1
  20. nextrec/models/ranking/autoint.py +3 -1
  21. nextrec/models/ranking/dcn.py +3 -1
  22. nextrec/models/ranking/deepfm.py +3 -1
  23. nextrec/models/ranking/dien.py +3 -1
  24. nextrec/models/ranking/din.py +3 -1
  25. nextrec/models/ranking/fibinet.py +3 -1
  26. nextrec/models/ranking/fm.py +3 -1
  27. nextrec/models/ranking/masknet.py +3 -1
  28. nextrec/models/ranking/pnn.py +3 -1
  29. nextrec/models/ranking/widedeep.py +3 -1
  30. nextrec/models/ranking/xdeepfm.py +3 -1
  31. nextrec/utils/__init__.py +5 -5
  32. nextrec/utils/initializer.py +3 -3
  33. nextrec/utils/optimizer.py +6 -6
  34. {nextrec-0.2.1.dist-info → nextrec-0.2.2.dist-info}/METADATA +2 -2
  35. nextrec-0.2.2.dist-info/RECORD +53 -0
  36. nextrec/loss/match_losses.py +0 -293
  37. nextrec-0.2.1.dist-info/RECORD +0 -54
  38. {nextrec-0.2.1.dist-info → nextrec-0.2.2.dist-info}/WHEEL +0 -0
  39. {nextrec-0.2.1.dist-info → nextrec-0.2.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/match/dssm_v2.py CHANGED
@@ -44,6 +44,12 @@ class DSSM_v2(BaseMatchModel):
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
+ optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer_params: dict | None = None,
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+ scheduler_params: dict | None = None,
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  **kwargs):

  super(DSSM_v2, self).__init__(
@@ -137,6 +143,18 @@ class DSSM_v2(BaseMatchModel):
  include_modules=['item_dnn']
  )

+ if optimizer_params is None:
+ optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+
+ self.compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ scheduler=scheduler,
+ scheduler_params=scheduler_params,
+ loss=loss,
+ loss_params=loss_params,
+ )
+
  self.to(device)

  def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/mind.py CHANGED
@@ -41,7 +41,7 @@ class MIND(BaseMatchModel):
  item_dnn_hidden_units: list[int] = [256, 128],
  dnn_activation: str = 'relu',
  dnn_dropout: float = 0.0,
- training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'listwise',
+ training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
  num_negative_samples: int = 100,
  temperature: float = 1.0,
  similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
@@ -51,6 +51,12 @@ class MIND(BaseMatchModel):
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
+ optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer_params: dict | None = None,
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+ scheduler_params: dict | None = None,
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  **kwargs):

  super(MIND, self).__init__(
@@ -152,6 +158,15 @@ class MIND(BaseMatchModel):
  include_modules=['item_dnn'] if self.item_dnn else []
  )

+ self.compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ scheduler=scheduler,
+ scheduler_params=scheduler_params,
+ loss=loss,
+ loss_params=loss_params,
+ )
+
  self.to(device)

  def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/sdm.py CHANGED
@@ -52,6 +52,12 @@ class SDM(BaseMatchModel):
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
+ optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer_params: dict | None = None,
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+ scheduler_params: dict | None = None,
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  **kwargs):

  super(SDM, self).__init__(
@@ -179,6 +185,15 @@ class SDM(BaseMatchModel):
  include_modules=['item_dnn'] if self.item_dnn else []
  )

+ self.compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ scheduler=scheduler,
+ scheduler_params=scheduler_params,
+ loss=loss,
+ loss_params=loss_params,
+ )
+
  self.to(device)

  def user_tower(self, user_input: dict) -> torch.Tensor:
nextrec/models/match/youtube_dnn.py CHANGED
@@ -17,11 +17,10 @@ from nextrec.basic.layers import MLP, EmbeddingLayer, AveragePooling

  class YoutubeDNN(BaseMatchModel):
  """
- YouTube Deep Neural Network for Recommendations
-
- 用户塔:历史行为序列 + 用户特征 -> 用户embedding
- 物品塔:物品特征 -> 物品embedding
- 训练:sampled softmax loss (listwise)
+ YouTube Deep Neural Network for Recommendations.
+ User tower: behavior sequence + user features -> user embedding.
+ Item tower: item features -> item embedding.
+ Training usually uses listwise / sampled softmax style objectives.
  """

  @property
@@ -50,6 +49,12 @@ class YoutubeDNN(BaseMatchModel):
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
+ optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer_params: dict | None = None,
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+ scheduler_params: dict | None = None,
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  **kwargs):

  super(YoutubeDNN, self).__init__(
@@ -94,7 +99,7 @@ class YoutubeDNN(BaseMatchModel):
  for feat in user_sparse_features or []:
  user_input_dim += feat.embedding_dim
  for feat in user_sequence_features or []:
- # 序列特征通过平均池化聚合
+ # Sequence features are pooled before entering the DNN
  user_input_dim += feat.embedding_dim

  user_dnn_units = user_dnn_hidden_units + [embedding_dim]
@@ -144,12 +149,20 @@ class YoutubeDNN(BaseMatchModel):
  include_modules=['item_dnn']
  )

+ self.compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ scheduler=scheduler,
+ scheduler_params=scheduler_params,
+ loss=loss,
+ loss_params=loss_params,
+ )
+
  self.to(device)

  def user_tower(self, user_input: dict) -> torch.Tensor:
  """
- User tower
- 处理用户历史行为序列和其他用户特征
+ User tower to encode historical behavior sequences and user features.
  """
  all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
  user_emb = self.user_embedding(user_input, all_user_features, squeeze_dim=True)
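
DSSM_v2, MIND, SDM, and YoutubeDNN all gain the same six training-related keyword arguments and now call self.compile() at the end of __init__. A rough usage sketch of the new constructor surface follows; the feature-column variables and embedding_dim value are placeholders for illustration rather than values taken from this diff, and only the optimizer/scheduler/loss keywords are new in 0.2.2.

# Hypothetical sketch: training options are now passed to the constructor and
# forwarded to self.compile() internally. Feature columns are placeholders
# assumed to be built with NextRec's feature definitions elsewhere.
from nextrec.models.match.youtube_dnn import YoutubeDNN

model = YoutubeDNN(
    user_sparse_features=user_sparse_features,      # placeholder feature columns
    user_sequence_features=user_sequence_features,  # placeholder feature columns
    embedding_dim=64,                               # placeholder value
    optimizer="adam",                               # new in 0.2.2
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    scheduler="step",                               # new in 0.2.2
    scheduler_params={"step_size": 10, "gamma": 0.1},
    loss="bce",                                     # new in 0.2.2
    loss_params=None,
    device="cpu",
)
# No separate model.compile(...) call is needed; __init__ already compiles the model.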
nextrec/models/multi_task/esmm.py CHANGED
@@ -1,9 +1,7 @@
  """
  Date: create on 09/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
- Reference:
- [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
+ Author: Yang Zhou,zyaztec@gmail.com
+ Reference: [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
  """

  import torch
@@ -46,6 +44,7 @@ class ESMM(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -106,7 +105,8 @@ class ESMM(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/multi_task/mmoe.py CHANGED
@@ -1,9 +1,7 @@
  """
  Date: create on 09/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
- Reference:
- [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
+ Author: Yang Zhou,zyaztec@gmail.com
+ Reference: [1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//KDD. 2018: 1930-1939.
  """

  import torch
@@ -44,6 +42,7 @@ class MMOE(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -128,7 +127,8 @@ class MMOE(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/multi_task/ple.py CHANGED
@@ -1,9 +1,7 @@
  """
  Date: create on 09/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
- Reference:
- [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//RecSys. 2020: 269-278.
+ Author: Yang Zhou,zyaztec@gmail.com
+ Reference: [1] Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//RecSys. 2020: 269-278.
  """

  import torch
@@ -47,6 +45,7 @@ class PLE(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -166,7 +165,8 @@ class PLE(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/multi_task/share_bottom.py CHANGED
@@ -1,9 +1,7 @@
  """
  Date: create on 09/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
- Reference:
- [1] Caruana R. Multitask learning[J]. Machine learning, 1997, 28: 41-75.
+ Author: Yang Zhou,zyaztec@gmail.com
+ Reference: [1] Caruana R. Multitask learning[J]. Machine learning, 1997, 28: 41-75.
  """

  import torch
@@ -35,6 +33,7 @@ class ShareBottom(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -105,7 +104,8 @@ class ShareBottom(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
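
The four multi-task models (ESMM, MMOE, PLE, ShareBottom) now accept loss_params and forward it to self.compile() next to loss. Since both loss and loss_params are annotated as optional lists, per-task configuration presumably pairs the two lists positionally; that pairing is an assumption read from the type hints, and every argument other than the loss keywords below is a placeholder.

# Hypothetical per-task loss configuration (assumed semantics, not shown in the diff).
from nextrec.models.multi_task.mmoe import MMOE

model = MMOE(
    # ... feature columns, task definitions, expert settings, etc. (placeholders)
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    loss=["bce", "bce"],                      # one loss per task (assumed)
    loss_params=[{}, {"reduction": "mean"}],  # per-task kwargs aligned with loss (assumed)
    device="cpu",
)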
nextrec/models/ranking/__init__.py CHANGED
@@ -1,3 +1,7 @@
+ from .fm import FM
+ from .afm import AFM
+ from .masknet import MaskNet
+ from .pnn import PNN
  from .deepfm import DeepFM
  from .autoint import AutoInt
  from .widedeep import WideDeep
@@ -14,4 +18,8 @@ __all__ = [
  'DCN',
  'DIN',
  'DIEN',
+ 'FM',
+ 'AFM',
+ 'MaskNet',
+ 'PNN',
  ]
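
With the expanded exports above, the four additional ranking models should be importable directly from the subpackage in 0.2.2:

from nextrec.models.ranking import FM, AFM, MaskNet, PNN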
nextrec/models/ranking/afm.py CHANGED
@@ -34,6 +34,7 @@ class AFM(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -88,7 +89,8 @@ class AFM(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/autoint.py CHANGED
@@ -39,6 +39,7 @@ class AutoInt(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -113,7 +114,8 @@ class AutoInt(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/dcn.py CHANGED
@@ -35,6 +35,7 @@ class DCN(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -97,7 +98,8 @@ class DCN(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/deepfm.py CHANGED
@@ -31,6 +31,7 @@ class DeepFM(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -79,7 +80,8 @@ class DeepFM(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/dien.py CHANGED
@@ -38,6 +38,7 @@ class DIEN(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -128,7 +129,8 @@ class DIEN(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/din.py CHANGED
@@ -37,6 +37,7 @@ class DIN(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -115,7 +116,8 @@ class DIN(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/fibinet.py CHANGED
@@ -42,6 +42,7 @@ class FiBiNET(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -111,7 +112,8 @@ class FiBiNET(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/fm.py CHANGED
@@ -30,6 +30,7 @@ class FM(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -76,7 +77,8 @@ class FM(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/masknet.py CHANGED
@@ -36,6 +36,7 @@ class MaskNet(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -98,7 +99,8 @@ class MaskNet(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/pnn.py CHANGED
@@ -34,6 +34,7 @@ class PNN(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -98,7 +99,8 @@ class PNN(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/widedeep.py CHANGED
@@ -34,6 +34,7 @@ class WideDeep(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -88,7 +89,8 @@ class WideDeep(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/models/ranking/xdeepfm.py CHANGED
@@ -37,6 +37,7 @@ class xDeepFM(BaseModel):
  optimizer: str = "adam",
  optimizer_params: dict = {},
  loss: str | nn.Module | None = "bce",
+ loss_params: dict | list[dict] | None = None,
  device: str = 'cpu',
  embedding_l1_reg=1e-6,
  dense_l1_reg=1e-5,
@@ -95,7 +96,8 @@ class xDeepFM(BaseModel):
  self.compile(
  optimizer=optimizer,
  optimizer_params=optimizer_params,
- loss=loss
+ loss=loss,
+ loss_params=loss_params,
  )

  def forward(self, x):
nextrec/utils/__init__.py CHANGED
@@ -1,12 +1,12 @@
- from .optimizer import get_optimizer_fn, get_scheduler_fn
- from .initializer import get_initializer_fn
+ from .optimizer import get_optimizer, get_scheduler
+ from .initializer import get_initializer
  from .embedding import get_auto_embedding_dim
  from . import optimizer, initializer, embedding

  __all__ = [
- 'get_optimizer_fn',
- 'get_scheduler_fn',
- 'get_initializer_fn',
+ 'get_optimizer',
+ 'get_scheduler',
+ 'get_initializer',
  'get_auto_embedding_dim',
  'optimizer',
  'initializer',
nextrec/utils/initializer.py CHANGED
@@ -8,14 +8,14 @@ Author: Yang Zhou, zyaztec@gmail.com
  import torch.nn as nn


- def get_initializer_fn(init_type='normal', activation='linear', param=None):
+ def get_initializer(init_type='normal', activation='linear', param=None):
  """
  Get parameter initialization function.

  Examples:
- >>> init_fn = get_initializer_fn('xavier_uniform', 'relu')
+ >>> init_fn = get_initializer('xavier_uniform', 'relu')
  >>> init_fn(tensor)
- >>> init_fn = get_initializer_fn('normal', param={'mean': 0.0, 'std': 0.01})
+ >>> init_fn = get_initializer('normal', param={'mean': 0.0, 'std': 0.01})
  """
  param = param or {}
nextrec/utils/optimizer.py CHANGED
@@ -9,7 +9,7 @@ import torch
  from typing import Iterable


- def get_optimizer_fn(
+ def get_optimizer(
  optimizer: str = "adam",
  params: Iterable[torch.nn.Parameter] | None = None,
  **optimizer_params
@@ -18,8 +18,8 @@ def get_optimizer_fn(
  Get optimizer function based on optimizer name or instance.

  Examples:
- >>> optimizer = get_optimizer_fn("adam", model.parameters(), lr=1e-3)
- >>> optimizer = get_optimizer_fn("sgd", model.parameters(), lr=0.01, momentum=0.9)
+ >>> optimizer = get_optimizer("adam", model.parameters(), lr=1e-3)
+ >>> optimizer = get_optimizer("sgd", model.parameters(), lr=0.01, momentum=0.9)
  """
  if params is None:
  raise ValueError("params cannot be None. Please provide model parameters.")
@@ -51,13 +51,13 @@ def get_optimizer_fn(
  return optimizer_fn


- def get_scheduler_fn(scheduler, optimizer, **scheduler_params):
+ def get_scheduler(scheduler, optimizer, **scheduler_params):
  """
  Get learning rate scheduler function.

  Examples:
- >>> scheduler = get_scheduler_fn("step", optimizer, step_size=10, gamma=0.1)
- >>> scheduler = get_scheduler_fn("cosine", optimizer, T_max=100)
+ >>> scheduler = get_scheduler("step", optimizer, step_size=10, gamma=0.1)
+ >>> scheduler = get_scheduler("cosine", optimizer, T_max=100)
  """
  if isinstance(scheduler, str):
  if scheduler == "step":
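
The renamed utility helpers keep the call signatures shown in the docstring examples above, and nextrec/utils/__init__.py re-exports them, so a minimal sketch of the 0.2.2 utility API looks like this (using a plain nn.Linear as a stand-in model):

import torch.nn as nn
from nextrec.utils import get_optimizer, get_scheduler, get_initializer

model = nn.Linear(16, 1)

# Build optimizer and scheduler from string names, per the docstrings above.
optimizer = get_optimizer("adam", model.parameters(), lr=1e-3)
scheduler = get_scheduler("step", optimizer, step_size=10, gamma=0.1)

# Initialize a weight tensor with the renamed initializer helper.
init_fn = get_initializer("xavier_uniform", "relu")
init_fn(model.weight)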
{nextrec-0.2.1.dist-info → nextrec-0.2.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.2.1
+ Version: 0.2.2
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -61,7 +61,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.2.1-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.2.2-orange.svg)

  English | [中文版](README_zh.md)