nextrec 0.4.23__py3-none-any.whl → 0.4.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -45,7 +45,15 @@ from nextrec.utils.data import (
 
 
 class DataProcessor(FeatureSet):
-    def __init__(self, hash_cache_size: int = 200_000):
+    def __init__(
+        self,
+        hash_cache_size: int = 200_000,
+    ):
+        if not logging.getLogger().hasHandlers():
+            logging.basicConfig(
+                level=logging.INFO,
+                format="%(message)s",
+            )
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}
         self.sequence_features: Dict[str, Dict[str, Any]] = {}
@@ -53,9 +61,6 @@ class DataProcessor(FeatureSet):
         self.version = __version__
 
         self.is_fitted = False
-        self._transform_summary_printed = (
-            False  # Track if summary has been printed during transform
-        )
 
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
@@ -92,17 +97,19 @@ class DataProcessor(FeatureSet):
     def add_sparse_feature(
         self,
         name: str,
-        encode_method: Literal["hash", "label"] = "label",
+        encode_method: Literal["hash", "label"] = "hash",
         hash_size: Optional[int] = None,
+        min_freq: Optional[int] = None,
         fill_na: str = "<UNK>",
     ):
        """Add a sparse feature configuration.
 
        Args:
-            name (str): Feature name.
-            encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "label".
-            hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
-            fill_na (str, optional): Fill value for missing entries. Defaults to "<UNK>".
+            name: Feature name.
+            encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash" because it is more scalable and much faster.
+            hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+            min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+            fill_na: Fill value for missing entries. Defaults to "<UNK>".
        """
        if encode_method == "hash" and hash_size is None:
            raise ValueError(
@@ -111,6 +118,7 @@ class DataProcessor(FeatureSet):
         self.sparse_features[name] = {
             "encode_method": encode_method,
             "hash_size": hash_size,
+            "min_freq": min_freq,
             "fill_na": fill_na,
         }
 
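Taken together, the two hunks above change how a sparse feature is typically configured: hash encoding is now the default and a frequency cutoff can be attached to it. A minimal usage sketch, assuming DataProcessor is imported and using a hypothetical feature name and sizes:

processor = DataProcessor()
# hash_size is still required whenever encode_method is "hash"
processor.add_sparse_feature(
    "user_id",            # hypothetical feature name
    encode_method="hash",
    hash_size=100_000,
    min_freq=5,           # tokens seen fewer than 5 times map to <UNK>
)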
@@ -119,6 +127,7 @@ class DataProcessor(FeatureSet):
         name: str,
         encode_method: Literal["hash", "label"] = "hash",
         hash_size: Optional[int] = None,
+        min_freq: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
         truncate: Literal[
@@ -129,13 +138,14 @@ class DataProcessor(FeatureSet):
        """Add a sequence feature configuration.
 
        Args:
-            name (str): Feature name.
-            encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
-            hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
-            max_len (Optional[int], optional): Maximum sequence length. Defaults to 50.
-            pad_value (int, optional): Padding value for sequences shorter than max_len. Defaults to 0.
-            truncate (Literal["pre", "post"], optional): Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
-            separator (str, optional): Separator for string sequences. Defaults to ",".
+            name: Feature name.
+            encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
+            hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+            min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+            max_len: Maximum sequence length. Defaults to 50.
+            pad_value: Padding value for sequences shorter than max_len. Defaults to 0.
+            truncate: Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
+            separator: Separator for string sequences. Defaults to ",".
        """
        if encode_method == "hash" and hash_size is None:
            raise ValueError(
@@ -144,6 +154,7 @@ class DataProcessor(FeatureSet):
         self.sequence_features[name] = {
             "encode_method": encode_method,
             "hash_size": hash_size,
+            "min_freq": min_freq,
             "max_len": max_len,
             "pad_value": pad_value,
             "truncate": truncate,
@@ -175,17 +186,6 @@ class DataProcessor(FeatureSet):
     def hash_string(self, s: str, hash_size: int) -> int:
         return self.hash_fn(str(s), int(hash_size))
 
-    def clear_hash_cache(self) -> None:
-        cache_clear = getattr(self.hash_fn, "cache_clear", None)
-        if callable(cache_clear):
-            cache_clear()
-
-    def hash_cache_info(self):
-        cache_info = getattr(self.hash_fn, "cache_info", None)
-        if callable(cache_info):
-            return cache_info()
-        return None
-
     def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
         scaler_type = config["scaler"]
@@ -241,12 +241,30 @@ class DataProcessor(FeatureSet):
         return result
 
     def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
-        _ = str(data.name)
+        logger = logging.getLogger()
+
         encode_method = config["encode_method"]
         fill_na = config["fill_na"]  # <UNK>
         filled_data = data.fillna(fill_na).astype(str)
         if encode_method == "label":
-            vocab = sorted(set(filled_data.tolist()))
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                counts = filled_data.value_counts()
+                config["_token_counts"] = counts.to_dict()
+                vocab = sorted(counts[counts >= min_freq].index.tolist())
+                low_freq_types = int((counts < min_freq).sum())
+                total_types = int(counts.size)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sparse feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
+            else:
+                vocab = sorted(set(filled_data.tolist()))
             if "<UNK>" not in vocab:
                 vocab.append("<UNK>")
             token_to_idx = {token: idx for idx, token in enumerate(vocab)}
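The label path's new cutoff is a plain value_counts filter. An equivalent standalone pandas sketch (the column values are illustrative):

import pandas as pd

s = pd.Series(["a", "a", "b", "c", "c", "c"])
counts = s.value_counts()
min_freq = 2
vocab = sorted(counts[counts >= min_freq].index.tolist())
print(vocab)  # ['a', 'c'] -- 'b' falls below min_freq and will map to <UNK>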
@@ -254,6 +272,24 @@ class DataProcessor(FeatureSet):
             config["_unk_index"] = token_to_idx["<UNK>"]
             config["vocab_size"] = len(vocab)
         elif encode_method == "hash":
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                counts = filled_data.value_counts()
+                config["_token_counts"] = counts.to_dict()
+                config["_unk_hash"] = self.hash_string(
+                    "<UNK>", int(config["hash_size"])
+                )
+                low_freq_types = int((counts < min_freq).sum())
+                total_types = int(counts.size)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sparse feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
            config["vocab_size"] = config["hash_size"]
 
     def process_sparse_feature_transform(
@@ -283,22 +319,60 @@ class DataProcessor(FeatureSet):
         if encode_method == "hash":
             hash_size = config["hash_size"]
             hash_fn = self.hash_string
+            min_freq = config.get("min_freq")
+            token_counts = config.get("_token_counts")
+            if min_freq is not None and isinstance(token_counts, dict):
+                unk_hash = config.get("_unk_hash")
+                if unk_hash is None:
+                    unk_hash = hash_fn("<UNK>", hash_size)
             return np.fromiter(
-                (hash_fn(v, hash_size) for v in sparse_series.to_numpy()),
+                (
+                    (
+                        unk_hash
+                        if min_freq is not None
+                        and isinstance(token_counts, dict)
+                        and token_counts.get(v, 0) < min_freq
+                        else hash_fn(v, hash_size)
+                    )
+                    for v in sparse_series.to_numpy()
+                ),
                 dtype=np.int64,
                 count=sparse_series.size,
             )
         return np.array([], dtype=np.int64)
 
     def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+        logger = logging.getLogger()
         _ = str(data.name)
         encode_method = config["encode_method"]
         separator = config["separator"]
         if encode_method == "label":
-            all_tokens = set()
+            min_freq = config.get("min_freq")
+            token_counts: Dict[str, int] = {}
             for seq in data:
-                all_tokens.update(self.extract_sequence_tokens(seq, separator))
-            vocab = sorted(all_tokens)
+                tokens = self.extract_sequence_tokens(seq, separator)
+                for token in tokens:
+                    if str(token).strip():
+                        key = str(token)
+                        token_counts[key] = token_counts.get(key, 0) + 1
+            if min_freq is not None:
+                config["_token_counts"] = token_counts
+                vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
+                low_freq_types = sum(
+                    1 for count in token_counts.values() if count < min_freq
+                )
+                total_types = len(token_counts)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sequence feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
+            else:
+                vocab = sorted(token_counts.keys())
             if not vocab:
                 vocab = ["<PAD>"]
             if "<UNK>" not in vocab:
@@ -308,6 +382,33 @@ class DataProcessor(FeatureSet):
             config["_unk_index"] = token_to_idx["<UNK>"]
             config["vocab_size"] = len(vocab)
         elif encode_method == "hash":
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                token_counts: Dict[str, int] = {}
+                for seq in data:
+                    tokens = self.extract_sequence_tokens(seq, separator)
+                    for token in tokens:
+                        if str(token).strip():
+                            token_counts[str(token)] = (
+                                token_counts.get(str(token), 0) + 1
+                            )
+                config["_token_counts"] = token_counts
+                config["_unk_hash"] = self.hash_string(
+                    "<UNK>", int(config["hash_size"])
+                )
+                low_freq_types = sum(
+                    1 for count in token_counts.values() if count < min_freq
+                )
+                total_types = len(token_counts)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sequence feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
            config["vocab_size"] = config["hash_size"]
 
     def process_sequence_feature_transform(
@@ -338,6 +439,12 @@ class DataProcessor(FeatureSet):
         unk_index = 0
         hash_fn = self.hash_string
         hash_size = config.get("hash_size")
+        min_freq = config.get("min_freq")
+        token_counts = config.get("_token_counts")
+        if min_freq is not None and isinstance(token_counts, dict):
+            unk_hash = config.get("_unk_hash")
+            if unk_hash is None and hash_size is not None:
+                unk_hash = hash_fn("<UNK>", hash_size)
         for i, seq in enumerate(arr):
             # normalize sequence to a list of strings
             tokens = []
@@ -364,7 +471,13 @@ class DataProcessor(FeatureSet):
                        "[Data Processor Error] hash_size must be set for hash encoding"
                    )
                encoded = [
-                    hash_fn(str(token), hash_size)
+                    (
+                        unk_hash
+                        if min_freq is not None
+                        and isinstance(token_counts, dict)
+                        and token_counts.get(str(token), 0) < min_freq
+                        else hash_fn(str(token), hash_size)
+                    )
                    for token in tokens
                    if str(token).strip()
                ]
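The transform-time rule in the sparse and sequence hunks above is the same: if a token's fitted count is below min_freq, emit the precomputed <UNK> hash instead of the token's own hash. A self-contained sketch of that rule (the hash function and counts here are illustrative stand-ins, not nextrec's internals):

from typing import Dict

def encode_token(token: str, counts: Dict[str, int], min_freq: int,
                 unk_hash: int, hash_size: int) -> int:
    # Mirror of the diff's fallback: all rare tokens share one <UNK> bucket.
    if counts.get(token, 0) < min_freq:
        return unk_hash
    return hash(token) % hash_size  # stand-in for self.hash_string

counts = {"a": 10, "b": 1}
unk = hash("<UNK>") % 1000
print(encode_token("a", counts, min_freq=2, unk_hash=unk, hash_size=1000))  # a's own bucket
print(encode_token("b", counts, min_freq=2, unk_hash=unk, hash_size=1000))  # <UNK> bucket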
@@ -472,6 +585,10 @@ class DataProcessor(FeatureSet):
                bold=True,
            )
        )
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
         file_paths, file_type = resolve_file_paths(path)
         if not check_streaming_support(file_type):
             raise ValueError(
@@ -496,6 +613,26 @@ class DataProcessor(FeatureSet):
         seq_vocab: Dict[str, set[str]] = {
             name: set() for name in self.sequence_features.keys()
         }
+        sparse_label_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sparse_features.items()
+            if config.get("encode_method") == "label" and config.get("min_freq")
+        }
+        seq_label_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sequence_features.items()
+            if config.get("encode_method") == "label" and config.get("min_freq")
+        }
+        sparse_hash_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sparse_features.items()
+            if config.get("encode_method") == "hash" and config.get("min_freq")
+        }
+        seq_hash_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sequence_features.items()
+            if config.get("encode_method") == "hash" and config.get("min_freq")
+        }
         target_values: Dict[str, set[Any]] = {
             name: set() for name in self.target_features.keys()
         }
@@ -531,6 +668,14 @@ class DataProcessor(FeatureSet):
                    fill_na = config["fill_na"]
                    series = series.fillna(fill_na).astype(str)
                    sparse_vocab[name].update(series.tolist())
+                    if name in sparse_label_counts:
+                        counts = sparse_label_counts[name]
+                        for token in series.tolist():
+                            counts[token] = counts.get(token, 0) + 1
+                    if name in sparse_hash_counts:
+                        counts = sparse_hash_counts[name]
+                        for token in series.tolist():
+                            counts[token] = counts.get(token, 0) + 1
                else:
                    separator = config["separator"]
                    tokens = []
@@ -539,6 +684,18 @@ class DataProcessor(FeatureSet):
                            self.extract_sequence_tokens(val, separator)
                        )
                    seq_vocab[name].update(tokens)
+                    if name in seq_label_counts:
+                        counts = seq_label_counts[name]
+                        for token in tokens:
+                            if str(token).strip():
+                                key = str(token)
+                                counts[key] = counts.get(key, 0) + 1
+                    if name in seq_hash_counts:
+                        counts = seq_hash_counts[name]
+                        for token in tokens:
+                            if str(token).strip():
+                                key = str(token)
+                                counts[key] = counts.get(key, 0) + 1
 
         # target features
         missing_features.update(self.target_features.keys() - columns)
@@ -605,7 +762,30 @@ class DataProcessor(FeatureSet):
         # finalize sparse label encoders
         for name, config in self.sparse_features.items():
             if config["encode_method"] == "label":
-                vocab = sparse_vocab[name]
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = sparse_label_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    vocab = {
+                        token
+                        for token, count in token_counts.items()
+                        if count >= min_freq
+                    }
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sparse feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
+                else:
+                    vocab = sparse_vocab[name]
                 if not vocab:
                     logger.warning(f"Sparse feature {name} has empty vocabulary")
                     continue
@@ -617,12 +797,55 @@ class DataProcessor(FeatureSet):
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
             elif config["encode_method"] == "hash":
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = sparse_hash_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    config["_unk_hash"] = self.hash_string(
+                        "<UNK>", int(config["hash_size"])
+                    )
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sparse feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
                config["vocab_size"] = config["hash_size"]
 
         # finalize sequence vocabularies
         for name, config in self.sequence_features.items():
             if config["encode_method"] == "label":
-                vocab_set = seq_vocab[name]
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = seq_label_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    vocab_set = {
+                        token
+                        for token, count in token_counts.items()
+                        if count >= min_freq
+                    }
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sequence feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
+                else:
+                    vocab_set = seq_vocab[name]
                 vocab_list = sorted(vocab_set) if vocab_set else ["<PAD>"]
                 if "<UNK>" not in vocab_list:
                     vocab_list.append("<UNK>")
@@ -631,6 +854,26 @@ class DataProcessor(FeatureSet):
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
             elif config["encode_method"] == "hash":
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = seq_hash_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    config["_unk_hash"] = self.hash_string(
+                        "<UNK>", int(config["hash_size"])
+                    )
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sequence feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
                config["vocab_size"] = config["hash_size"]
 
         # finalize targets
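Across these hunks, what appears to be the streaming fit path follows one pattern: accumulate per-token counts chunk by chunk, then apply the min_freq cutoff only after the full pass. A compressed sketch of that accumulation (the chunk iteration is illustrative, not nextrec's reader):

from typing import Dict, Iterable, List

def count_tokens(chunks: Iterable[List[str]], min_freq: int) -> List[str]:
    counts: Dict[str, int] = {}
    for chunk in chunks:            # one pass over the stream
        for token in chunk:
            counts[token] = counts.get(token, 0) + 1
    # cutoff applied once all chunks are seen, as in the diff
    return sorted(t for t, c in counts.items() if c >= min_freq)

print(count_tokens([["a", "b"], ["a", "c"], ["a", "b"]], min_freq=2))  # ['a', 'b']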
@@ -961,6 +1204,10 @@ class DataProcessor(FeatureSet):
        """
 
         logger = logging.getLogger()
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
         if isinstance(data, (str, os.PathLike)):
             path_str = str(data)
             uses_robust = any(
File without changes
File without changes
File without changes
@@ -116,10 +116,10 @@ class ESMM(BaseModel):
         input_dim = self.embedding.input_dim
 
         # CTR tower
-        self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
+        self.ctr_tower = MLP(input_dim=input_dim, output_dim=1, **ctr_params)
 
         # CVR tower
-        self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
+        self.cvr_tower = MLP(input_dim=input_dim, output_dim=1, **cvr_params)
         self.grad_norm_shared_modules = ["embedding"]
         self.prediction_layer = TaskHead(task_type=self.task, task_dims=[1, 1])
         # Register regularization weights
@@ -134,12 +134,12 @@ class MMOE(BaseModel):
         # Expert networks (shared by all tasks)
         self.experts = nn.ModuleList()
         for _ in range(num_experts):
-            expert = MLP(input_dim=input_dim, output_layer=False, **expert_params)
+            expert = MLP(input_dim=input_dim, output_dim=None, **expert_params)
             self.experts.append(expert)
 
         # Get expert output dimension
-        if "dims" in expert_params and len(expert_params["dims"]) > 0:
-            expert_output_dim = expert_params["dims"][-1]
+        if "hidden_dims" in expert_params and len(expert_params["hidden_dims"]) > 0:
+            expert_output_dim = expert_params["hidden_dims"][-1]
         else:
             expert_output_dim = input_dim
 
@@ -153,7 +153,7 @@ class MMOE(BaseModel):
         # Task-specific towers
         self.towers = nn.ModuleList()
         for tower_params in tower_params_list:
-            tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
+            tower = MLP(input_dim=expert_output_dim, output_dim=1, **tower_params)
             self.towers.append(tower)
         self.prediction_layer = TaskHead(
             task_type=self.task, task_dims=[1] * self.nums_task
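The model-side hunks track one MLP API change: the boolean output_layer flag gives way to an explicit output_dim (an int for a projection head, None for no head), and the hidden-layer key is now "hidden_dims" rather than "dims". A sketch of a caller updated for the new convention, assuming MLP is in scope and with illustrative dimensions:

# Before: MLP(input_dim=64, output_layer=True, dims=[128, 64])
# After, per this diff:
tower = MLP(
    input_dim=64,
    output_dim=1,              # scalar head, replacing output_layer=True
    hidden_dims=[128, 64],     # renamed from dims
)
expert = MLP(input_dim=64, output_dim=None, hidden_dims=[128, 64])  # no head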