nextrec 0.4.16__py3-none-any.whl → 0.4.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/heads.py +99 -0
  3. nextrec/basic/loggers.py +5 -5
  4. nextrec/basic/model.py +217 -88
  5. nextrec/cli.py +1 -1
  6. nextrec/data/dataloader.py +93 -95
  7. nextrec/data/preprocessor.py +108 -46
  8. nextrec/loss/grad_norm.py +13 -13
  9. nextrec/models/multi_task/esmm.py +10 -11
  10. nextrec/models/multi_task/mmoe.py +20 -19
  11. nextrec/models/multi_task/ple.py +35 -34
  12. nextrec/models/multi_task/poso.py +23 -21
  13. nextrec/models/multi_task/share_bottom.py +18 -17
  14. nextrec/models/ranking/afm.py +4 -3
  15. nextrec/models/ranking/autoint.py +4 -3
  16. nextrec/models/ranking/dcn.py +4 -3
  17. nextrec/models/ranking/dcn_v2.py +4 -3
  18. nextrec/models/ranking/deepfm.py +4 -3
  19. nextrec/models/ranking/dien.py +2 -2
  20. nextrec/models/ranking/din.py +2 -2
  21. nextrec/models/ranking/eulernet.py +4 -3
  22. nextrec/models/ranking/ffm.py +4 -3
  23. nextrec/models/ranking/fibinet.py +2 -2
  24. nextrec/models/ranking/fm.py +4 -3
  25. nextrec/models/ranking/lr.py +4 -3
  26. nextrec/models/ranking/masknet.py +4 -5
  27. nextrec/models/ranking/pnn.py +5 -4
  28. nextrec/models/ranking/widedeep.py +8 -8
  29. nextrec/models/ranking/xdeepfm.py +5 -4
  30. nextrec/utils/console.py +20 -6
  31. nextrec/utils/data.py +154 -32
  32. nextrec/utils/model.py +86 -1
  33. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/METADATA +5 -6
  34. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/RECORD +37 -36
  35. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/WHEEL +0 -0
  36. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/entry_points.txt +0 -0
  37. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/licenses/LICENSE +0 -0

nextrec/models/ranking/afm.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 09/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Xiao J, Ye H, He X, et al. Attentional factorization machines: Learning the weight of
@@ -40,7 +40,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import EmbeddingLayer, InputMask, PredictionLayer
+from nextrec.basic.layers import EmbeddingLayer, InputMask
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -141,7 +142,7 @@ class AFM(BaseModel):
         self.attention_p = nn.Linear(attention_dim, 1, bias=False)
         self.attention_dropout = nn.Dropout(attention_dropout)
         self.output_projection = nn.Linear(self.embedding_dim, 1, bias=False)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
         self.input_mask = InputMask()
 
         # Register regularization weights
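
Every ranking model in this release makes the same swap: PredictionLayer (from nextrec.basic.layers) is replaced by TaskHead from the new nextrec.basic.heads module (+99 lines), with the task_type keyword unchanged. The body of heads.py is not shown in this diff, so the following is only a minimal sketch of what a task-typed head could look like, assuming TaskHead does what PredictionLayer did and applies the output activation implied by task_type; the set of supported task_type values is a guess:

import torch
import torch.nn as nn

class TaskHead(nn.Module):
    # Hypothetical sketch only: the real nextrec.basic.heads.TaskHead is not
    # visible in this diff, just its TaskHead(task_type=...) call signature.
    def __init__(self, task_type: str = "binary"):
        super().__init__()
        if task_type not in {"binary", "multiclass", "regression"}:
            raise ValueError(f"unsupported task_type: {task_type}")
        self.task_type = task_type

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        if self.task_type == "binary":
            return torch.sigmoid(logits)          # CTR-style probability
        if self.task_type == "multiclass":
            return torch.softmax(logits, dim=-1)  # distribution over classes
        return logits                             # regression: identity
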
nextrec/models/ranking/autoint.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 09/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Song W, Shi C, Xiao Z, et al. Autoint: Automatic feature interaction learning via
@@ -58,7 +58,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention, PredictionLayer
+from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -162,7 +163,7 @@ class AutoInt(BaseModel):
 
         # Final prediction layer
         self.fc = nn.Linear(num_fields * att_embedding_dim, 1)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         # Register regularization weights
         self.register_regularization_weights(

nextrec/models/ranking/dcn.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 09/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Wang R, Fu B, Fu G, et al. Deep & cross network for ad click predictions[C]
@@ -54,7 +54,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -163,7 +164,7 @@ class DCN(BaseModel):
         # Final layer only uses cross network output
         self.final_layer = nn.Linear(input_dim, 1)
 
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         # Register regularization weights
         self.register_regularization_weights(

nextrec/models/ranking/dcn_v2.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 09/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] R. Wang et al. DCN V2: Improved Deep & Cross Network and Practical Lessons for
@@ -47,7 +47,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -272,7 +273,7 @@ class DCNv2(BaseModel):
         final_input_dim = input_dim
 
         self.final_layer = nn.Linear(final_input_dim, 1)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         self.register_regularization_weights(
             embedding_attr="embedding",

nextrec/models/ranking/deepfm.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 27/10/2025
-Checkpoint: edit on 24/11/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Guo H, Tang R, Ye Y, et al. DeepFM: A factorization-machine based neural network
@@ -45,7 +45,8 @@ embedding,无需手工构造交叉特征即可端到端训练,常用于 CTR/
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -111,7 +112,7 @@ class DeepFM(BaseModel):
         self.linear = LR(fm_emb_dim_total)
         self.fm = FM(reduce_sum=True)
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         # Register regularization weights
         self.register_regularization_weights(
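
Because the attribute keeps its name (self.prediction_layer) and its constructor keyword (task_type=...), the swap is source-compatible at every call site; the forward() bodies, which this diff does not show, need no edits. A minimal sketch of the call-site pattern, with a hypothetical TinyModel standing in for the real models:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in: any head constructed with task_type=... and
    # called on a [B, 1] logit tensor can be dropped in without touching
    # forward().
    def __init__(self, head: nn.Module):
        super().__init__()
        self.score = nn.Linear(8, 1)
        self.prediction_layer = head  # PredictionLayer before, TaskHead after

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.prediction_layer(self.score(x))

model = TinyModel(head=nn.Sigmoid())  # e.g. a binary head
probs = model(torch.randn(4, 8))      # shape [4, 1]
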
nextrec/models/ranking/dien.py CHANGED
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
     MLP,
     AttentionPoolingLayer,
     EmbeddingLayer,
-    PredictionLayer,
 )
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -346,7 +346,7 @@ class DIEN(BaseModel):
         )
 
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         self.register_regularization_weights(
             embedding_attr="embedding",

nextrec/models/ranking/din.py CHANGED
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
     MLP,
     AttentionPoolingLayer,
     EmbeddingLayer,
-    PredictionLayer,
 )
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -173,7 +173,7 @@ class DIN(BaseModel):
 
         # MLP for final prediction
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         # Register regularization weights
         self.register_regularization_weights(

nextrec/models/ranking/eulernet.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 09/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Zhao Z, Zhang H, Tang H, et al. EulerNet: Efficient and Effective Feature
@@ -38,7 +38,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import LR, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -295,7 +296,7 @@ class EulerNet(BaseModel):
         else:
             self.linear = None
 
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         modules = ["mapping", "layers", "w", "w_im"]
         if self.use_linear:

nextrec/models/ranking/ffm.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 19/12/2025
-Checkpoint: edit on 19/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Juan Y, Zhuang Y, Chin W-S, et al. Field-aware Factorization Machines for CTR
@@ -43,7 +43,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import AveragePooling, InputMask, PredictionLayer, SumPooling
+from nextrec.basic.layers import AveragePooling, InputMask, SumPooling
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.torch_utils import get_initializer
 
@@ -140,7 +141,7 @@ class FFM(BaseModel):
             nn.Linear(dense_input_dim, 1, bias=True) if dense_input_dim > 0 else None
         )
 
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
         self.input_mask = InputMask()
         self.mean_pool = AveragePooling()
         self.sum_pool = SumPooling()

nextrec/models/ranking/fibinet.py CHANGED
@@ -50,9 +50,9 @@ from nextrec.basic.layers import (
     BiLinearInteractionLayer,
     EmbeddingLayer,
     HadamardInteractionLayer,
-    PredictionLayer,
     SENETLayer,
 )
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -168,7 +168,7 @@ class FiBiNET(BaseModel):
         num_pairs = self.num_fields * (self.num_fields - 1) // 2
         interaction_dim = num_pairs * self.embedding_dim * 2
         self.mlp = MLP(input_dim=interaction_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         # Register regularization weights
         self.register_regularization_weights(

nextrec/models/ranking/fm.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 09/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Rendle S. Factorization machines[C]//ICDM. 2010: 995-1000.
@@ -42,7 +42,8 @@ import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import FM as FMInteraction
-from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
+from nextrec.basic.heads import TaskHead
+from nextrec.basic.layers import LR, EmbeddingLayer
 from nextrec.basic.model import BaseModel
 
 
@@ -105,7 +106,7 @@ class FM(BaseModel):
         fm_input_dim = sum([f.embedding_dim for f in self.fm_features])
         self.linear = LR(fm_input_dim)
         self.fm = FMInteraction(reduce_sum=True)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         # Register regularization weights
         self.register_regularization_weights(

nextrec/models/ranking/lr.py CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 09/12/2025
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Hosmer D W, Lemeshow S, Sturdivant R X. Applied Logistic Regression.
@@ -41,7 +41,8 @@ LR 是 CTR/排序任务中最经典的线性基线模型。它将稠密、稀疏
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import EmbeddingLayer, LR as LinearLayer, PredictionLayer
+from nextrec.basic.layers import EmbeddingLayer, LR as LinearLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -99,7 +100,7 @@ class LR(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)
         linear_input_dim = self.embedding.input_dim
         self.linear = LinearLayer(linear_input_dim)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         self.register_regularization_weights(
             embedding_attr="embedding", include_modules=["linear"]
1
1
  """
2
2
  Date: create on 09/11/2025
3
- Checkpoint: edit on 29/11/2025
3
+ Checkpoint: edit on 23/12/2025
4
4
  Author: Yang Zhou, zyaztec@gmail.com
5
5
  Reference:
6
6
  [1] Wang Z, She Q, Zhang J. MaskNet: Introducing Feature-Wise
@@ -58,7 +58,8 @@ import torch.nn as nn
58
58
  import torch.nn.functional as F
59
59
 
60
60
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
61
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
61
+ from nextrec.basic.layers import MLP, EmbeddingLayer
62
+ from nextrec.basic.heads import TaskHead
62
63
  from nextrec.basic.model import BaseModel
63
64
 
64
65
 
@@ -282,14 +283,13 @@ class MaskNet(BaseModel):
282
283
  input_dim=self.num_blocks * block_hidden_dim, **mlp_params
283
284
  )
284
285
  self.output_layer = None
285
- self.prediction_layer = PredictionLayer(task_type=self.task)
286
+ self.prediction_layer = TaskHead(task_type=self.task)
286
287
 
287
288
  if self.architecture == "serial":
288
289
  self.register_regularization_weights(
289
290
  embedding_attr="embedding",
290
291
  include_modules=["mask_blocks", "output_layer"],
291
292
  )
292
- # serial
293
293
  else:
294
294
  self.register_regularization_weights(
295
295
  embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"]
@@ -314,7 +314,6 @@ class MaskNet(BaseModel):
314
314
  block_outputs.append(h)
315
315
  concat_hidden = torch.cat(block_outputs, dim=-1)
316
316
  logit = self.final_mlp(concat_hidden) # [B, 1]
317
- # serial
318
317
  else:
319
318
  hidden = self.first_block(field_emb, v_emb_flat)
320
319
  hidden = self.block_dropout(hidden)

nextrec/models/ranking/pnn.py CHANGED
@@ -1,7 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
+Checkpoint: edit on 23/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response
     prediction[C]//ICDM. 2016: 1149-1154. (https://arxiv.org/abs/1611.00144)
@@ -38,7 +38,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -136,7 +137,7 @@ class PNN(BaseModel):
         product_dim = 2 * self.num_pairs
 
         self.mlp = MLP(input_dim=linear_dim + product_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         modules = ["mlp"]
         if self.kernel is not None:

nextrec/models/ranking/widedeep.py CHANGED
@@ -1,12 +1,11 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on 24/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
+Checkpoint: edit on 23/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
-[1] Cheng H T, Koc L, Harmsen J, et al. Wide & Deep learning for recommender systems[C]
-//Proceedings of the 1st Workshop on Deep Learning for Recommender Systems. 2016: 7-10.
-(https://arxiv.org/abs/1606.07792)
+[1] Cheng H T, Koc L, Harmsen J, et al. Wide & Deep learning for recommender systems[C]
+    //Proceedings of the 1st Workshop on Deep Learning for Recommender Systems. 2016: 7-10.
+    (https://arxiv.org/abs/1606.07792)
 
 Wide & Deep blends a linear wide component (memorization of cross features) with a
 deep neural network (generalization) sharing the same feature space. The wide part
@@ -42,7 +41,8 @@ Wide & Deep 同时使用宽线性部分(记忆共现/手工交叉)与深网
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import LR, MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import LR, MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -114,7 +114,7 @@ class WideDeep(BaseModel):
         # deep_emb_dim_total = sum([f.embedding_dim for f in self.deep_features if not isinstance(f, DenseFeature)])
         # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
         self.mlp = MLP(input_dim=input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
         # Register regularization weights
         self.register_regularization_weights(
             embedding_attr="embedding", include_modules=["linear", "mlp"]

nextrec/models/ranking/xdeepfm.py CHANGED
@@ -1,7 +1,7 @@
 """
 Date: create on 09/11/2025
-Author:
-    Yang Zhou,zyaztec@gmail.com
+Checkpoint: edit on 23/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Lian J, Zhou X, Zhang F, et al. xdeepfm: Combining explicit and implicit feature interactions
     for recommender systems[C]//Proceedings of the 24th ACM SIGKDD international conference on
@@ -56,7 +56,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import LR, MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import LR, MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -186,7 +187,7 @@ class xDeepFM(BaseModel):
             [getattr(f, "embedding_dim", 1) or 1 for f in dense_features]
         )
         self.mlp = MLP(input_dim=deep_emb_dim_total + dense_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         # Register regularization weights
         self.register_regularization_weights(
nextrec/utils/console.py CHANGED
@@ -203,18 +203,32 @@ def progress(iterable, *, description=None, total=None, disable=False):
         console=console,
     )
 
-    with progress_bar:
-        task_id = progress_bar.add_task(description or "Working", total=resolved_total)
-        for item in iterable:
-            yield item
-            progress_bar.advance(task_id, 1)
+    if hasattr(progress_bar, "__enter__"):
+        with progress_bar:
+            task_id = progress_bar.add_task(
+                description or "Working", total=resolved_total
+            )
+            for item in iterable:
+                yield item
+                progress_bar.advance(task_id, 1)
+    else:
+        progress_bar.start()
+        try:
+            task_id = progress_bar.add_task(
+                description or "Working", total=resolved_total
+            )
+            for item in iterable:
+                yield item
+                progress_bar.advance(task_id, 1)
+        finally:
+            progress_bar.stop()
 
 
 def group_metrics_by_task(
     metrics: Mapping[str, Any] | None,
     target_names: list[str] | str | None,
     default_task_name: str = "overall",
-) -> tuple[list[str], dict[str, dict[str, float]]]:
+):
     if not metrics:
         return [], {}
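
The rewritten progress() probes hasattr(progress_bar, "__enter__") so that progress objects without context-manager support are driven through explicit start()/stop() instead, at the cost of repeating the loop in both branches. The same fallback can be factored into a small adapter; a sketch of the pattern (not part of nextrec, names hypothetical):

from contextlib import contextmanager

@contextmanager
def managed(obj):
    # Enter the object as a context manager when it supports the protocol,
    # otherwise fall back to explicit start()/stop().
    if hasattr(obj, "__enter__"):
        with obj:
            yield obj
    else:
        obj.start()
        try:
            yield obj
        finally:
            obj.stop()

With this adapter the iteration body would appear once, under a single "with managed(progress_bar):" block.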