nextrec 0.4.20-py3-none-any.whl → 0.4.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +3 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +259 -326
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/__init__.py +0 -4
  16. nextrec/loss/grad_norm.py +3 -3
  17. nextrec/models/multi_task/esmm.py +4 -6
  18. nextrec/models/multi_task/mmoe.py +4 -6
  19. nextrec/models/multi_task/ple.py +6 -8
  20. nextrec/models/multi_task/poso.py +5 -7
  21. nextrec/models/multi_task/share_bottom.py +6 -8
  22. nextrec/models/ranking/afm.py +4 -6
  23. nextrec/models/ranking/autoint.py +4 -6
  24. nextrec/models/ranking/dcn.py +8 -7
  25. nextrec/models/ranking/dcn_v2.py +4 -6
  26. nextrec/models/ranking/deepfm.py +5 -7
  27. nextrec/models/ranking/dien.py +8 -7
  28. nextrec/models/ranking/din.py +8 -7
  29. nextrec/models/ranking/eulernet.py +5 -7
  30. nextrec/models/ranking/ffm.py +5 -7
  31. nextrec/models/ranking/fibinet.py +4 -6
  32. nextrec/models/ranking/fm.py +4 -6
  33. nextrec/models/ranking/lr.py +4 -6
  34. nextrec/models/ranking/masknet.py +8 -9
  35. nextrec/models/ranking/pnn.py +4 -6
  36. nextrec/models/ranking/widedeep.py +5 -7
  37. nextrec/models/ranking/xdeepfm.py +8 -7
  38. nextrec/models/retrieval/dssm.py +4 -10
  39. nextrec/models/retrieval/dssm_v2.py +0 -6
  40. nextrec/models/retrieval/mind.py +4 -10
  41. nextrec/models/retrieval/sdm.py +4 -10
  42. nextrec/models/retrieval/youtube_dnn.py +4 -10
  43. nextrec/models/sequential/hstu.py +1 -3
  44. nextrec/utils/__init__.py +17 -15
  45. nextrec/utils/config.py +15 -5
  46. nextrec/utils/console.py +2 -2
  47. nextrec/utils/feature.py +2 -2
  48. nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
  49. nextrec/utils/torch_utils.py +57 -112
  50. nextrec/utils/types.py +63 -0
  51. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
  52. nextrec-0.4.22.dist-info/RECORD +81 -0
  53. nextrec-0.4.20.dist-info/RECORD +0 -79
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
  55. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
  56. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.20"
+ __version__ = "0.4.22"
nextrec/basic/activation.py CHANGED
@@ -1,8 +1,8 @@
  """
- Activation function definitions
+ Activation function definitions for NextRec models.
 
  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 28/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
@@ -10,6 +10,9 @@ import torch
  import torch.nn as nn
 
 
+ from nextrec.utils.types import ActivationName
+
+
  class Dice(nn.Module):
      """
      Dice activation function from the paper:
@@ -41,9 +44,11 @@ class Dice(nn.Module):
          return output
 
 
- def activation_layer(activation: str, emb_size: int | None = None):
+ def activation_layer(
+     activation: ActivationName = "none",
+     emb_size: int | None = None,
+ ):
      """Create an activation layer based on the given activation name."""
-     activation = activation.lower()
      if activation == "dice":
          if emb_size is None:
              raise ValueError(
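
Note: activation_layer now takes a typed ActivationName literal and no longer lower-cases its argument. A minimal usage sketch, assuming ActivationName includes common lowercase names such as "relu" (only "dice" and "none" are visible in this diff):

# Sketch only: names other than "dice"/"none" are assumed members of ActivationName.
from nextrec.basic.activation import activation_layer

act = activation_layer("relu")                # assumed lowercase member of ActivationName
dice = activation_layer("dice", emb_size=16)  # "dice" still requires emb_size
# The removed .lower() call means "ReLU" would no longer be normalized to "relu".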
nextrec/basic/callback.py CHANGED
@@ -2,7 +2,7 @@
  Callback System for Training Process
 
  Date: create on 27/10/2025
- Checkpoint: edit on 19/12/2025
+ Checkpoint: edit on 27/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
@@ -61,16 +61,16 @@ class Callback:
          self.params = params
 
      def should_run(self) -> bool:
-         if not getattr(self, "run_on_main_process_only", False):
+         if not self.run_on_main_process_only:
              return True
-         model = getattr(self, "model", None)
-         if model is None:
-             return True
-         return bool(getattr(model, "is_main_process", True))
+         model = self.model
+         return bool(model.is_main_process)
 
 
  class CallbackList:
-     """Generates a list of callbacks"""
+     """
+     Generates a list of callbacks
+     """
 
      def __init__(self, callbacks: Optional[list[Callback]] = None):
          self.callbacks = callbacks or []
@@ -85,7 +85,8 @@ class CallbackList:
              getattr(callback, fn_name)(*args, **kwargs)
 
      def set_model(self, model):
-         self.call("set_model", model)
+         for callback in self.callbacks:
+             callback.set_model(model)
 
      def set_params(self, params: dict):
          self.call("set_params", params)
@@ -194,9 +195,8 @@ class EarlyStopper(Callback):
              self.wait += 1
              if self.wait >= self.patience:
                  self.stopped_epoch = epoch
-                 if hasattr(self.model, "stop_training"):
-                     self.model.stop_training = True
-                 if self.verbose > 0:
+                 self.model.stop_training = True
+                 if self.verbose == 1:
                      logging.info(
                          f"Early stopping triggered at epoch {epoch + 1}. "
                          f"Best {self.monitor}: {self.best_value:.6f} at epoch {self.best_epoch + 1}"
@@ -218,14 +218,15 @@ class EarlyStopper(Callback):
 
 
  class CheckpointSaver(Callback):
-     """Callback to save model checkpoints during training.
+     """
+     Callback to save model checkpoints during training.
 
      Args:
          save_path: Path to save checkpoints.
          monitor: Metric name to monitor for saving best model.
          mode: One of {'min', 'max'}.
          save_best_only: If True, only save when the model is considered the "best".
-         save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
+         save_freq: Frequency of checkpoint saving (integer for every N epochs).
          verbose: Verbosity mode.
          run_on_main_process_only: Whether to run this callback only on the main process in DDP.
      """
@@ -237,7 +238,7 @@ class CheckpointSaver(Callback):
          monitor: str = "val_auc",
          mode: str = "max",
          save_best_only: bool = False,
-         save_freq: str | int = "epoch",
+         save_freq: int = 1,
          verbose: int = 1,
          run_on_main_process_only: bool = True,
      ):
@@ -272,7 +273,7 @@ class CheckpointSaver(Callback):
          logs = logs or {}
 
          should_save = False
-         if self.save_freq == "epoch":
+         if self.save_freq == 1:
              should_save = True
          elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
              should_save = True
@@ -306,12 +307,10 @@ class CheckpointSaver(Callback):
 
      def save_checkpoint(self, path: Path, epoch: int, logs: dict):
 
-         # Get the actual model (unwrap DDP if needed)
-         model_to_save = (
-             self.model.ddp_model.module
-             if getattr(self.model, "ddp_model", None) is not None
-             else self.model
-         )
+         if hasattr(self.model, "ddp_model") and self.model.ddp_model is not None:
+             model_to_save = self.model.ddp_model.module
+         else:
+             model_to_save = self.model
 
          # Save only state_dict to match BaseModel.save_model() format
          torch.save(model_to_save.state_dict(), path)
@@ -328,12 +327,13 @@ class CheckpointSaver(Callback):
              with open(config_path, "wb") as f:
                  pickle.dump(features_config, f)
 
-         if self.verbose > 1:
+         if self.verbose == 1:
              logging.info(f"Saved checkpoint to {path}")
 
 
  class LearningRateScheduler(Callback):
-     """Callback for learning rate scheduling.
+     """
+     Callback for learning rate scheduling.
 
      Args:
          scheduler: Learning rate scheduler instance or name.
@@ -346,73 +346,25 @@ class LearningRateScheduler(Callback):
          self.verbose = verbose
 
      def on_train_begin(self, logs: Optional[dict] = None):
-         if self.scheduler is None and hasattr(self.model, "scheduler_fn"):
+         if self.scheduler is None:
              self.scheduler = self.model.scheduler_fn
 
      def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
          if self.scheduler is not None:
-             # Get current lr before step
-             if hasattr(self.model, "optimizer_fn"):
-                 old_lr = self.model.optimizer_fn.param_groups[0]["lr"]
-
-             # Step the scheduler
-             if hasattr(self.scheduler, "step"):
-                 # Some schedulers need metrics
-                 if logs is None:
-                     logs = {}
-                 if "val_loss" in logs and hasattr(self.scheduler, "mode"):
-                     self.scheduler.step(logs["val_loss"])
-                 else:
-                     self.scheduler.step()
+             old_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+             if logs is None:
+                 logs = {}
 
-             # Log new lr
-             if self.verbose > 0 and hasattr(self.model, "optimizer_fn"):
-                 if getattr(self.model, "is_main_process", True):
-                     new_lr = self.model.optimizer_fn.param_groups[0]["lr"]
-                     if new_lr != old_lr:
-                         logging.info(
-                             f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
-                         )
+             # step for ReduceLROnPlateau
+             if "val_loss" in logs and hasattr(self.scheduler, "mode"):
+                 self.scheduler.step(logs["val_loss"])
+             else:
+                 self.scheduler.step()
 
-
- class MetricsLogger(Callback):
-     """Callback for logging training metrics.
-
-     Args:
-         log_freq: Frequency of logging ('epoch', 'batch', or integer for every N epochs/batches).
-         verbose: Verbosity mode.
-     """
-
-     def __init__(self, log_freq: str | int = "epoch", verbose: int = 1):
-         super().__init__()
-         self.run_on_main_process_only = True
-         self.log_freq = log_freq
-         self.verbose = verbose
-
-     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
-         if self.verbose > 0 and (
-             self.log_freq == "epoch"
-             or (isinstance(self.log_freq, int) and (epoch + 1) % self.log_freq == 0)
-         ):
-             logs = logs or {}
-             metrics_str = " - ".join(
-                 [
-                     f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
-                     for k, v in logs.items()
-                 ]
-             )
-             logging.info(f"Epoch {epoch + 1}: {metrics_str}")
-
-     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
-         if self.verbose > 1 and (
-             self.log_freq == "batch"
-             or (isinstance(self.log_freq, int) and (batch + 1) % self.log_freq == 0)
-         ):
-             logs = logs or {}
-             metrics_str = " - ".join(
-                 [
-                     f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
-                     for k, v in logs.items()
-                 ]
-             )
-             logging.info(f"Batch {batch}: {metrics_str}")
+             # Log new lr
+             if self.verbose == 1:
+                 new_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+                 if new_lr != old_lr:
+                     logging.info(
+                         f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
+                     )
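
Note: MetricsLogger is removed in 0.4.22 and CheckpointSaver.save_freq is now an integer epoch interval. A minimal wiring sketch, assuming the callbacks are passed to the model's fit routine and that LearningRateScheduler's scheduler argument is optional; EarlyStopper's constructor arguments are inferred from the attributes referenced above:

# Sketch only: CheckpointSaver arguments come from the diff above;
# EarlyStopper(monitor=..., mode=..., patience=...) is assumed from its attributes.
from nextrec.basic.callback import CheckpointSaver, EarlyStopper, LearningRateScheduler

callbacks = [
    CheckpointSaver(save_path="./checkpoints", monitor="val_auc", mode="max", save_freq=1),  # every epoch; the "epoch" string form is gone
    EarlyStopper(monitor="val_auc", mode="max", patience=3),
    LearningRateScheduler(),  # falls back to model.scheduler_fn in on_train_begin
]
# model.fit(..., callbacks=callbacks)  # exact fit signature is not shown in this diff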
nextrec/basic/features.py CHANGED
@@ -1,15 +1,17 @@
  """
- Feature definitions
+ Feature definitions for NextRec models.
 
  Date: create on 27/10/2025
- Checkpoint: edit on 20/12/2025
+ Checkpoint: edit on 27/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
  import torch
 
+ from typing import Literal
+
  from nextrec.utils.embedding import get_auto_embedding_dim
- from nextrec.utils.feature import normalize_to_list
+ from nextrec.utils.feature import to_list
 
 
  class BaseFeature:
@@ -25,12 +27,20 @@ class EmbeddingFeature(BaseFeature):
          name: str,
          vocab_size: int,
          embedding_name: str = "",
-         embedding_dim: int | None = 4,
-         padding_idx: int | None = None,
-         init_type: str = "normal",
+         embedding_dim: int | None = None,
+         padding_idx: int = 0,
+         init_type: Literal[
+             "normal",
+             "uniform",
+             "xavier_uniform",
+             "xavier_normal",
+             "kaiming_uniform",
+             "kaiming_normal",
+             "orthogonal",
+         ] = "normal",
          init_params: dict | None = None,
          l1_reg: float = 0.0,
-         l2_reg: float = 1e-5,
+         l2_reg: float = 0.0,
          trainable: bool = True,
          pretrained_weight: torch.Tensor | None = None,
          freeze_pretrained: bool = False,
@@ -55,23 +65,57 @@
 
 
  class SequenceFeature(EmbeddingFeature):
+
      def __init__(
          self,
          name: str,
          vocab_size: int,
-         max_len: int = 20,
+         max_len: int = 50,
          embedding_name: str = "",
-         embedding_dim: int | None = 4,
-         combiner: str = "mean",
-         padding_idx: int | None = None,
-         init_type: str = "normal",
+         embedding_dim: int | None = None,
+         combiner: Literal[
+             "mean",
+             "sum",
+             "concat",
+             "dot_attention",
+             "self_attention",
+         ] = "mean",
+         padding_idx: int = 0,
+         init_type: Literal[
+             "normal",
+             "uniform",
+             "xavier_uniform",
+             "xavier_normal",
+             "kaiming_uniform",
+             "kaiming_normal",
+             "orthogonal",
+         ] = "normal",
          init_params: dict | None = None,
          l1_reg: float = 0.0,
-         l2_reg: float = 1e-5,
+         l2_reg: float = 0.0,
          trainable: bool = True,
          pretrained_weight: torch.Tensor | None = None,
          freeze_pretrained: bool = False,
      ):
+         """
+         Sequence feature for variable-length categorical id sequences.
+
+         Args:
+             name: Feature name used as input key.
+             vocab_size: Number of unique ids in the sequence vocabulary.
+             max_len: Maximum sequence length for padding/truncation.
+             embedding_name: Shared embedding table name. Defaults to ``name``.
+             embedding_dim: Embedding dimension. Set to ``None`` for auto sizing.
+             combiner: Pooling method for sequence embeddings, e.g. ``"mean"`` or ``"sum"``.
+             padding_idx: Index used for padding tokens.
+             init_type: Embedding initializer type.
+             init_params: Initializer parameters.
+             l1_reg: L1 regularization weight on embedding.
+             l2_reg: L2 regularization weight on embedding.
+             trainable: Whether the embedding is trainable. [TODO] This is for representation learning.
+             pretrained_weight: Optional pretrained embedding weights. [TODO] This is for representation learning.
+             freeze_pretrained: If True, keep pretrained weights frozen. [TODO] This is for representation learning.
+         """
          super().__init__(
              name=name,
              vocab_size=vocab_size,
@@ -91,28 +135,105 @@ class SequenceFeature(EmbeddingFeature):
 
 
  class SparseFeature(EmbeddingFeature):
-     pass
+
+     def __init__(
+         self,
+         name: str,
+         vocab_size: int,
+         embedding_name: str = "",
+         embedding_dim: int | None = None,
+         padding_idx: int = 0,
+         init_type: Literal[
+             "normal",
+             "uniform",
+             "xavier_uniform",
+             "xavier_normal",
+             "kaiming_uniform",
+             "kaiming_normal",
+             "orthogonal",
+         ] = "normal",
+         init_params: dict | None = None,
+         l1_reg: float = 0.0,
+         l2_reg: float = 0.0,
+         trainable: bool = True,
+         pretrained_weight: torch.Tensor | None = None,
+         freeze_pretrained: bool = False,
+     ):
+         """
+         Sparse feature for categorical ids.
+
+         Args:
+             name: Feature name used as input key.
+             vocab_size: Number of unique categorical ids.
+             embedding_name: Shared embedding table name. Defaults to ``name``.
+             embedding_dim: Embedding dimension. Set to ``None`` for auto sizing.
+             padding_idx: Index used for padding tokens.
+             init_type: Embedding initializer type.
+             init_params: Initializer parameters.
+             l1_reg: L1 regularization weight on embedding.
+             l2_reg: L2 regularization weight on embedding.
+             trainable: Whether the embedding is trainable.
+             pretrained_weight: Optional pretrained embedding weights.
+             freeze_pretrained: If True, keep pretrained weights frozen.
+         """
+         super().__init__(
+             name=name,
+             vocab_size=vocab_size,
+             embedding_name=embedding_name,
+             embedding_dim=embedding_dim,
+             padding_idx=padding_idx,
+             init_type=init_type,
+             init_params=init_params,
+             l1_reg=l1_reg,
+             l2_reg=l2_reg,
+             trainable=trainable,
+             pretrained_weight=pretrained_weight,
+             freeze_pretrained=freeze_pretrained,
+         )
 
 
  class DenseFeature(BaseFeature):
+
      def __init__(
          self,
          name: str,
-         embedding_dim: int | None = 1,
          input_dim: int = 1,
-         use_embedding: bool = False,
+         proj_dim: int | None = 0,
+         use_projection: bool = False,
+         trainable: bool = True,
+         pretrained_weight: torch.Tensor | None = None,
+         freeze_pretrained: bool = False,
      ):
+         """
+         Dense feature for continuous values.
+
+         Args:
+             name: Feature name used as input key.
+             input_dim: Input dimension for continuous values.
+             proj_dim: Projection dimension. If None or 0, no projection is applied.
+             use_projection: Whether to project inputs to higher dimension.
+             trainable: Whether the projection is trainable.
+             pretrained_weight: Optional pretrained projection weights.
+             freeze_pretrained: If True, keep pretrained weights frozen.
+         """
          self.name = name
-         self.input_dim = max(int(input_dim or 1), 1)
-         self.embedding_dim = self.input_dim if embedding_dim is None else embedding_dim
-         if use_embedding and self.embedding_dim == 0:
+         self.input_dim = max(int(input_dim), 1)
+         self.proj_dim = self.input_dim if proj_dim is None else proj_dim
+         if use_projection and self.proj_dim == 0:
              raise ValueError(
-                 "[Features Error] DenseFeature: use_embedding=True is incompatible with embedding_dim=0"
+                 "[Features Error] DenseFeature: use_projection=True is incompatible with proj_dim=0"
              )
-         if embedding_dim is not None and embedding_dim > 1:
-             self.use_embedding = True
+         if proj_dim is not None and proj_dim > 1:
+             self.use_projection = True
          else:
-             self.use_embedding = use_embedding  # user decides for dim <= 1
+             self.use_projection = use_projection
+         self.embedding_dim = (
+             self.input_dim if not self.use_projection else self.proj_dim
+         )  # for compatibility
+
+         self.trainable = trainable
+         self.pretrained_weight = pretrained_weight
+         self.freeze_pretrained = freeze_pretrained
 
 
  class FeatureSet:
@@ -123,7 +244,7 @@ class FeatureSet:
          sequence_features: list[SequenceFeature] | None = None,
          target: str | list[str] | None = None,
          id_columns: str | list[str] | None = None,
-     ) -> None:
+     ):
          self.dense_features = list(dense_features) if dense_features else []
          self.sparse_features = list(sparse_features) if sparse_features else []
          self.sequence_features = list(sequence_features) if sequence_features else []
@@ -132,13 +253,13 @@ class FeatureSet:
              self.dense_features + self.sparse_features + self.sequence_features
          )
          self.feature_names = [feat.name for feat in self.all_features]
-         self.target_columns = normalize_to_list(target)
-         self.id_columns = normalize_to_list(id_columns)
+         self.target_columns = to_list(target)
+         self.id_columns = to_list(id_columns)
 
      def set_target_id(
          self,
          target: str | list[str] | None = None,
          id_columns: str | list[str] | None = None,
      ) -> None:
-         self.target_columns = normalize_to_list(target)
-         self.id_columns = normalize_to_list(id_columns)
+         self.target_columns = to_list(target)
+         self.id_columns = to_list(id_columns)
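
Note: DenseFeature's embedding_dim/use_embedding arguments are renamed to proj_dim/use_projection, embedding_dim now defaults to None (auto sizing), padding_idx defaults to 0, and the l2_reg default drops from 1e-5 to 0.0, so embedding regularization must now be requested explicitly. A minimal sketch of the updated constructors (feature names and sizes are illustrative):

# Sketch only: argument names follow the diff above; values are made up.
from nextrec.basic.features import DenseFeature, FeatureSet, SequenceFeature, SparseFeature

feature_set = FeatureSet(
    dense_features=[DenseFeature("price", input_dim=1, proj_dim=8, use_projection=True)],
    sparse_features=[SparseFeature("item_id", vocab_size=10000, embedding_dim=16)],  # embedding_dim=None would auto-size
    sequence_features=[SequenceFeature("click_seq", vocab_size=10000, max_len=50, combiner="mean")],
    target="label",
)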
nextrec/basic/heads.py CHANGED
@@ -2,6 +2,7 @@
  Task head implementations for NextRec models.
 
  Date: create on 23/12/2025
+ Checkpoint: edit on 27/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
@@ -14,6 +15,7 @@ import torch.nn as nn
  import torch.nn.functional as F
 
  from nextrec.basic.layers import PredictionLayer
+ from nextrec.utils.types import TaskTypeName
 
 
  class TaskHead(nn.Module):
@@ -26,7 +28,7 @@ class TaskHead(nn.Module):
 
      def __init__(
          self,
-         task_type: str | list[str] = "binary",
+         task_type: TaskTypeName | list[TaskTypeName] = "binary",
          task_dims: int | list[int] | None = None,
          use_bias: bool = True,
          return_logits: bool = False,