nextrec-0.4.24-py3-none-any.whl → nextrec-0.4.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +191 -71
  5. nextrec/basic/summary.py +58 -0
  6. nextrec/cli.py +13 -0
  7. nextrec/data/data_processing.py +3 -9
  8. nextrec/data/dataloader.py +25 -2
  9. nextrec/data/preprocessor.py +283 -36
  10. nextrec/models/multi_task/[pre]aitm.py +173 -0
  11. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  12. nextrec/models/multi_task/[pre]star.py +192 -0
  13. nextrec/models/multi_task/apg.py +330 -0
  14. nextrec/models/multi_task/cross_stitch.py +229 -0
  15. nextrec/models/multi_task/escm.py +290 -0
  16. nextrec/models/multi_task/esmm.py +8 -21
  17. nextrec/models/multi_task/hmoe.py +203 -0
  18. nextrec/models/multi_task/mmoe.py +20 -28
  19. nextrec/models/multi_task/pepnet.py +68 -66
  20. nextrec/models/multi_task/ple.py +30 -44
  21. nextrec/models/multi_task/poso.py +13 -22
  22. nextrec/models/multi_task/share_bottom.py +14 -25
  23. nextrec/models/ranking/afm.py +2 -2
  24. nextrec/models/ranking/autoint.py +2 -4
  25. nextrec/models/ranking/dcn.py +2 -3
  26. nextrec/models/ranking/dcn_v2.py +2 -3
  27. nextrec/models/ranking/deepfm.py +2 -3
  28. nextrec/models/ranking/dien.py +7 -9
  29. nextrec/models/ranking/din.py +8 -10
  30. nextrec/models/ranking/eulernet.py +1 -2
  31. nextrec/models/ranking/ffm.py +1 -2
  32. nextrec/models/ranking/fibinet.py +2 -3
  33. nextrec/models/ranking/fm.py +1 -1
  34. nextrec/models/ranking/lr.py +1 -1
  35. nextrec/models/ranking/masknet.py +1 -2
  36. nextrec/models/ranking/pnn.py +1 -2
  37. nextrec/models/ranking/widedeep.py +2 -3
  38. nextrec/models/ranking/xdeepfm.py +2 -4
  39. nextrec/models/representation/rqvae.py +4 -4
  40. nextrec/models/retrieval/dssm.py +18 -26
  41. nextrec/models/retrieval/dssm_v2.py +15 -22
  42. nextrec/models/retrieval/mind.py +9 -15
  43. nextrec/models/retrieval/sdm.py +36 -33
  44. nextrec/models/retrieval/youtube_dnn.py +16 -24
  45. nextrec/models/sequential/hstu.py +2 -2
  46. nextrec/utils/__init__.py +5 -1
  47. nextrec/utils/config.py +2 -0
  48. nextrec/utils/model.py +16 -77
  49. nextrec/utils/torch_utils.py +11 -0
  50. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
  51. nextrec-0.4.27.dist-info/RECORD +90 -0
  52. nextrec/models/multi_task/aitm.py +0 -0
  53. nextrec/models/multi_task/snr_trans.py +0 -0
  54. nextrec-0.4.24.dist-info/RECORD +0 -86
  55. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
  56. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
  57. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
@@ -45,7 +45,15 @@ from nextrec.utils.data import (
 
 
 class DataProcessor(FeatureSet):
-    def __init__(self, hash_cache_size: int = 200_000):
+    def __init__(
+        self,
+        hash_cache_size: int = 200_000,
+    ):
+        if not logging.getLogger().hasHandlers():
+            logging.basicConfig(
+                level=logging.INFO,
+                format="%(message)s",
+            )
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}
         self.sequence_features: Dict[str, Dict[str, Any]] = {}
@@ -53,9 +61,6 @@ class DataProcessor(FeatureSet):
         self.version = __version__
 
         self.is_fitted = False
-        self._transform_summary_printed = (
-            False  # Track if summary has been printed during transform
-        )
 
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
@@ -92,17 +97,19 @@ class DataProcessor(FeatureSet):
     def add_sparse_feature(
         self,
         name: str,
-        encode_method: Literal["hash", "label"] = "label",
+        encode_method: Literal["hash", "label"] = "hash",
         hash_size: Optional[int] = None,
+        min_freq: Optional[int] = None,
         fill_na: str = "<UNK>",
     ):
         """Add a sparse feature configuration.
 
         Args:
-            name (str): Feature name.
-            encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "label".
-            hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
-            fill_na (str, optional): Fill value for missing entries. Defaults to "<UNK>".
+            name: Feature name.
+            encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash" because it is more scalable and much faster.
+            hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+            min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+            fill_na: Fill value for missing entries. Defaults to "<UNK>".
         """
         if encode_method == "hash" and hash_size is None:
             raise ValueError(
@@ -111,6 +118,7 @@ class DataProcessor(FeatureSet):
         self.sparse_features[name] = {
             "encode_method": encode_method,
             "hash_size": hash_size,
+            "min_freq": min_freq,
             "fill_na": fill_na,
         }
 
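
The new min_freq knob composes with both encoders. A minimal usage sketch (the feature names and sizes here are invented for illustration; only DataProcessor, add_sparse_feature, and their keyword arguments come from the diff):

    processor = DataProcessor()

    # Hash-encode "city" into 2**20 buckets; per the hunks below, tokens seen
    # fewer than 5 times during fit are remapped to the <UNK> hash at transform time.
    processor.add_sparse_feature("city", hash_size=2**20, min_freq=5)

    # Label encoding is still available; min_freq then prunes the fitted
    # vocabulary, and pruned tokens fall back to the <UNK> index.
    processor.add_sparse_feature("device_type", encode_method="label", min_freq=10)
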
@@ -119,6 +127,7 @@ class DataProcessor(FeatureSet):
         name: str,
         encode_method: Literal["hash", "label"] = "hash",
         hash_size: Optional[int] = None,
+        min_freq: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
         truncate: Literal[
@@ -129,13 +138,14 @@ class DataProcessor(FeatureSet):
         """Add a sequence feature configuration.
 
         Args:
-            name (str): Feature name.
-            encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
-            hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
-            max_len (Optional[int], optional): Maximum sequence length. Defaults to 50.
-            pad_value (int, optional): Padding value for sequences shorter than max_len. Defaults to 0.
-            truncate (Literal["pre", "post"], optional): Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
-            separator (str, optional): Separator for string sequences. Defaults to ",".
+            name: Feature name.
+            encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
+            hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+            min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+            max_len: Maximum sequence length. Defaults to 50.
+            pad_value: Padding value for sequences shorter than max_len. Defaults to 0.
+            truncate: Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
+            separator: Separator for string sequences. Defaults to ",".
         """
         if encode_method == "hash" and hash_size is None:
             raise ValueError(
@@ -144,6 +154,7 @@ class DataProcessor(FeatureSet):
         self.sequence_features[name] = {
             "encode_method": encode_method,
             "hash_size": hash_size,
+            "min_freq": min_freq,
             "max_len": max_len,
             "pad_value": pad_value,
             "truncate": truncate,
@@ -175,17 +186,6 @@ class DataProcessor(FeatureSet):
     def hash_string(self, s: str, hash_size: int) -> int:
         return self.hash_fn(str(s), int(hash_size))
 
-    def clear_hash_cache(self) -> None:
-        cache_clear = getattr(self.hash_fn, "cache_clear", None)
-        if callable(cache_clear):
-            cache_clear()
-
-    def hash_cache_info(self):
-        cache_info = getattr(self.hash_fn, "cache_info", None)
-        if callable(cache_info):
-            return cache_info()
-        return None
-
     def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
         scaler_type = config["scaler"]
@@ -241,12 +241,30 @@ class DataProcessor(FeatureSet):
         return result
 
     def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
-        _ = str(data.name)
+        logger = logging.getLogger()
+
         encode_method = config["encode_method"]
         fill_na = config["fill_na"]  # <UNK>
         filled_data = data.fillna(fill_na).astype(str)
         if encode_method == "label":
-            vocab = sorted(set(filled_data.tolist()))
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                counts = filled_data.value_counts()
+                config["_token_counts"] = counts.to_dict()
+                vocab = sorted(counts[counts >= min_freq].index.tolist())
+                low_freq_types = int((counts < min_freq).sum())
+                total_types = int(counts.size)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sparse feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
+            else:
+                vocab = sorted(set(filled_data.tolist()))
             if "<UNK>" not in vocab:
                 vocab.append("<UNK>")
             token_to_idx = {token: idx for idx, token in enumerate(vocab)}
@@ -254,6 +272,24 @@ class DataProcessor(FeatureSet):
             config["_unk_index"] = token_to_idx["<UNK>"]
             config["vocab_size"] = len(vocab)
         elif encode_method == "hash":
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                counts = filled_data.value_counts()
+                config["_token_counts"] = counts.to_dict()
+                config["_unk_hash"] = self.hash_string(
+                    "<UNK>", int(config["hash_size"])
+                )
+                low_freq_types = int((counts < min_freq).sum())
+                total_types = int(counts.size)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sparse feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
             config["vocab_size"] = config["hash_size"]
 
     def process_sparse_feature_transform(
@@ -283,22 +319,60 @@ class DataProcessor(FeatureSet):
         if encode_method == "hash":
             hash_size = config["hash_size"]
             hash_fn = self.hash_string
+            min_freq = config.get("min_freq")
+            token_counts = config.get("_token_counts")
+            if min_freq is not None and isinstance(token_counts, dict):
+                unk_hash = config.get("_unk_hash")
+                if unk_hash is None:
+                    unk_hash = hash_fn("<UNK>", hash_size)
             return np.fromiter(
-                (hash_fn(v, hash_size) for v in sparse_series.to_numpy()),
+                (
+                    (
+                        unk_hash
+                        if min_freq is not None
+                        and isinstance(token_counts, dict)
+                        and token_counts.get(v, 0) < min_freq
+                        else hash_fn(v, hash_size)
+                    )
+                    for v in sparse_series.to_numpy()
+                ),
                 dtype=np.int64,
                 count=sparse_series.size,
             )
         return np.array([], dtype=np.int64)
 
     def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+        logger = logging.getLogger()
         _ = str(data.name)
         encode_method = config["encode_method"]
         separator = config["separator"]
         if encode_method == "label":
-            all_tokens = set()
+            min_freq = config.get("min_freq")
+            token_counts: Dict[str, int] = {}
             for seq in data:
-                all_tokens.update(self.extract_sequence_tokens(seq, separator))
-            vocab = sorted(all_tokens)
+                tokens = self.extract_sequence_tokens(seq, separator)
+                for token in tokens:
+                    if str(token).strip():
+                        key = str(token)
+                        token_counts[key] = token_counts.get(key, 0) + 1
+            if min_freq is not None:
+                config["_token_counts"] = token_counts
+                vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
+                low_freq_types = sum(
+                    1 for count in token_counts.values() if count < min_freq
+                )
+                total_types = len(token_counts)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sequence feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
+            else:
+                vocab = sorted(token_counts.keys())
             if not vocab:
                 vocab = ["<PAD>"]
             if "<UNK>" not in vocab:
@@ -308,6 +382,33 @@ class DataProcessor(FeatureSet):
             config["_unk_index"] = token_to_idx["<UNK>"]
             config["vocab_size"] = len(vocab)
         elif encode_method == "hash":
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                token_counts: Dict[str, int] = {}
+                for seq in data:
+                    tokens = self.extract_sequence_tokens(seq, separator)
+                    for token in tokens:
+                        if str(token).strip():
+                            token_counts[str(token)] = (
+                                token_counts.get(str(token), 0) + 1
+                            )
+                config["_token_counts"] = token_counts
+                config["_unk_hash"] = self.hash_string(
+                    "<UNK>", int(config["hash_size"])
+                )
+                low_freq_types = sum(
+                    1 for count in token_counts.values() if count < min_freq
+                )
+                total_types = len(token_counts)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sequence feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
             config["vocab_size"] = config["hash_size"]
 
     def process_sequence_feature_transform(
@@ -338,6 +439,12 @@ class DataProcessor(FeatureSet):
                 unk_index = 0
         hash_fn = self.hash_string
         hash_size = config.get("hash_size")
+        min_freq = config.get("min_freq")
+        token_counts = config.get("_token_counts")
+        if min_freq is not None and isinstance(token_counts, dict):
+            unk_hash = config.get("_unk_hash")
+            if unk_hash is None and hash_size is not None:
+                unk_hash = hash_fn("<UNK>", hash_size)
         for i, seq in enumerate(arr):
             # normalize sequence to a list of strings
             tokens = []
@@ -364,7 +471,13 @@ class DataProcessor(FeatureSet):
                     "[Data Processor Error] hash_size must be set for hash encoding"
                 )
             encoded = [
-                hash_fn(str(token), hash_size)
+                (
+                    unk_hash
+                    if min_freq is not None
+                    and isinstance(token_counts, dict)
+                    and token_counts.get(str(token), 0) < min_freq
+                    else hash_fn(str(token), hash_size)
+                )
                 for token in tokens
                 if str(token).strip()
             ]
@@ -472,6 +585,10 @@ class DataProcessor(FeatureSet):
                 bold=True,
             )
         )
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
         file_paths, file_type = resolve_file_paths(path)
         if not check_streaming_support(file_type):
             raise ValueError(
@@ -496,6 +613,26 @@ class DataProcessor(FeatureSet):
         seq_vocab: Dict[str, set[str]] = {
             name: set() for name in self.sequence_features.keys()
         }
+        sparse_label_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sparse_features.items()
+            if config.get("encode_method") == "label" and config.get("min_freq")
+        }
+        seq_label_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sequence_features.items()
+            if config.get("encode_method") == "label" and config.get("min_freq")
+        }
+        sparse_hash_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sparse_features.items()
+            if config.get("encode_method") == "hash" and config.get("min_freq")
+        }
+        seq_hash_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sequence_features.items()
+            if config.get("encode_method") == "hash" and config.get("min_freq")
+        }
         target_values: Dict[str, set[Any]] = {
             name: set() for name in self.target_features.keys()
         }
@@ -531,6 +668,14 @@ class DataProcessor(FeatureSet):
                     fill_na = config["fill_na"]
                     series = series.fillna(fill_na).astype(str)
                     sparse_vocab[name].update(series.tolist())
+                    if name in sparse_label_counts:
+                        counts = sparse_label_counts[name]
+                        for token in series.tolist():
+                            counts[token] = counts.get(token, 0) + 1
+                    if name in sparse_hash_counts:
+                        counts = sparse_hash_counts[name]
+                        for token in series.tolist():
+                            counts[token] = counts.get(token, 0) + 1
                 else:
                     separator = config["separator"]
                     tokens = []
@@ -539,6 +684,18 @@ class DataProcessor(FeatureSet):
                         self.extract_sequence_tokens(val, separator)
                     )
                     seq_vocab[name].update(tokens)
+                    if name in seq_label_counts:
+                        counts = seq_label_counts[name]
+                        for token in tokens:
+                            if str(token).strip():
+                                key = str(token)
+                                counts[key] = counts.get(key, 0) + 1
+                    if name in seq_hash_counts:
+                        counts = seq_hash_counts[name]
+                        for token in tokens:
+                            if str(token).strip():
+                                key = str(token)
+                                counts[key] = counts.get(key, 0) + 1
 
             # target features
             missing_features.update(self.target_features.keys() - columns)
@@ -605,7 +762,30 @@ class DataProcessor(FeatureSet):
         # finalize sparse label encoders
         for name, config in self.sparse_features.items():
             if config["encode_method"] == "label":
-                vocab = sparse_vocab[name]
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = sparse_label_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    vocab = {
+                        token
+                        for token, count in token_counts.items()
+                        if count >= min_freq
+                    }
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sparse feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
+                else:
+                    vocab = sparse_vocab[name]
                 if not vocab:
                     logger.warning(f"Sparse feature {name} has empty vocabulary")
                     continue
@@ -617,12 +797,55 @@ class DataProcessor(FeatureSet):
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
             elif config["encode_method"] == "hash":
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = sparse_hash_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    config["_unk_hash"] = self.hash_string(
+                        "<UNK>", int(config["hash_size"])
+                    )
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sparse feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]
 
         # finalize sequence vocabularies
         for name, config in self.sequence_features.items():
             if config["encode_method"] == "label":
-                vocab_set = seq_vocab[name]
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = seq_label_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    vocab_set = {
+                        token
+                        for token, count in token_counts.items()
+                        if count >= min_freq
+                    }
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sequence feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
+                else:
+                    vocab_set = seq_vocab[name]
                 vocab_list = sorted(vocab_set) if vocab_set else ["<PAD>"]
                 if "<UNK>" not in vocab_list:
                     vocab_list.append("<UNK>")
@@ -631,6 +854,26 @@ class DataProcessor(FeatureSet):
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
             elif config["encode_method"] == "hash":
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = seq_hash_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    config["_unk_hash"] = self.hash_string(
+                        "<UNK>", int(config["hash_size"])
+                    )
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sequence feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]
 
         # finalize targets
@@ -961,6 +1204,10 @@ class DataProcessor(FeatureSet):
         """
 
         logger = logging.getLogger()
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
         if isinstance(data, (str, os.PathLike)):
            path_str = str(data)
            uses_robust = any(
@@ -0,0 +1,173 @@
+"""
+Date: create on 01/01/2026 - prerelease version: need to overwrite compute_loss later
+Checkpoint: edit on 01/01/2026
+Author: Yang Zhou, zyaztec@gmail.com
+Reference:
+- [1] Xi D, Chen Z, Yan P, Zhang Y, Zhu Y, Zhuang F, Chen Y. Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising. Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (KDD ’21), 2021, pp. 3745–3755.
+  URL: https://arxiv.org/abs/2105.08489
+- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
+
+"""
+
+from __future__ import annotations
+
+import math
+import torch
+import torch.nn as nn
+
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
+from nextrec.basic.model import BaseModel
+from nextrec.utils.model import get_mlp_output_dim
+from nextrec.utils.types import TaskTypeName
+
+
+class AITMTransfer(nn.Module):
+    """Attentive information transfer from previous task to current task."""
+
+    def __init__(self, input_dim: int):
+        super().__init__()
+        self.input_dim = input_dim
+        self.prev_proj = nn.Linear(input_dim, input_dim)
+        self.value = nn.Linear(input_dim, input_dim)
+        self.key = nn.Linear(input_dim, input_dim)
+        self.query = nn.Linear(input_dim, input_dim)
+
+    def forward(self, prev_feat: torch.Tensor, curr_feat: torch.Tensor) -> torch.Tensor:
+        prev = self.prev_proj(prev_feat).unsqueeze(1)
+        curr = curr_feat.unsqueeze(1)
+        stacked = torch.cat([prev, curr], dim=1)
+        value = self.value(stacked)
+        key = self.key(stacked)
+        query = self.query(stacked)
+        attn_scores = torch.sum(key * query, dim=2, keepdim=True) / math.sqrt(
+            self.input_dim
+        )
+        attn = torch.softmax(attn_scores, dim=1)
+        return torch.sum(attn * value, dim=1)
+
+
+class AITM(BaseModel):
+    """
+    Attentive Information Transfer Multi-Task model.
+
+    AITM learns task-specific representations and transfers information from
+    task i-1 to task i via attention, enabling sequential task dependency modeling.
+    """
+
+    @property
+    def model_name(self):
+        return "AITM"
+
+    @property
+    def default_task(self):
+        nums_task = getattr(self, "nums_task", None)
+        if nums_task is not None and nums_task > 0:
+            return ["binary"] * nums_task
+        return ["binary"]
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        bottom_mlp_params: dict | list[dict] | None = None,
+        tower_mlp_params_list: list[dict] | None = None,
+        calibrator_alpha: float = 0.1,
+        target: list[str] | str | None = None,
+        task: list[TaskTypeName] | None = None,
+        **kwargs,
+    ):
+        dense_features = dense_features or []
+        sparse_features = sparse_features or []
+        sequence_features = sequence_features or []
+        bottom_mlp_params = bottom_mlp_params or {}
+        tower_mlp_params_list = tower_mlp_params_list or []
+        self.calibrator_alpha = calibrator_alpha
+
+        if target is None:
+            raise ValueError("AITM requires target names for all tasks.")
+        if isinstance(target, str):
+            target = [target]
+
+        self.nums_task = len(target)
+        if self.nums_task < 2:
+            raise ValueError("AITM requires at least 2 tasks.")
+
+        super(AITM, self).__init__(
+            dense_features=dense_features,
+            sparse_features=sparse_features,
+            sequence_features=sequence_features,
+            target=target,
+            task=task,
+            **kwargs,
+        )
+
+        if len(tower_mlp_params_list) != self.nums_task:
+            raise ValueError(
+                "Number of tower mlp params "
+                f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
+            )
+
+        bottom_mlp_params_list: list[dict]
+        if isinstance(bottom_mlp_params, list):
+            if len(bottom_mlp_params) != self.nums_task:
+                raise ValueError(
+                    "Number of bottom mlp params "
+                    f"({len(bottom_mlp_params)}) must match number of tasks ({self.nums_task})."
+                )
+            bottom_mlp_params_list = [params.copy() for params in bottom_mlp_params]
+        else:
+            bottom_mlp_params_list = [
+                bottom_mlp_params.copy() for _ in range(self.nums_task)
+            ]
+
+        self.embedding = EmbeddingLayer(features=self.all_features)
+        input_dim = self.embedding.input_dim
+
+        self.bottoms = nn.ModuleList(
+            [
+                MLP(input_dim=input_dim, output_dim=None, **params)
+                for params in bottom_mlp_params_list
+            ]
+        )
+        bottom_dims = [
+            get_mlp_output_dim(params, input_dim) for params in bottom_mlp_params_list
+        ]
+        if len(set(bottom_dims)) != 1:
+            raise ValueError(f"All bottom output dims must match, got {bottom_dims}.")
+        bottom_output_dim = bottom_dims[0]
+
+        self.transfers = nn.ModuleList(
+            [AITMTransfer(bottom_output_dim) for _ in range(self.nums_task - 1)]
+        )
+        self.grad_norm_shared_modules = ["embedding", "transfers"]
+
+        self.towers = nn.ModuleList(
+            [
+                MLP(input_dim=bottom_output_dim, output_dim=1, **params)
+                for params in tower_mlp_params_list
+            ]
+        )
+        self.prediction_layer = TaskHead(
+            task_type=self.task, task_dims=[1] * self.nums_task
+        )
+
+        self.register_regularization_weights(
+            embedding_attr="embedding",
+            include_modules=["bottoms", "transfers", "towers"],
+        )
+
+    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
+        task_feats = [bottom(input_flat) for bottom in self.bottoms]
+
+        for idx in range(1, self.nums_task):
+            task_feats[idx] = self.transfers[idx - 1](
+                task_feats[idx - 1], task_feats[idx]
+            )
+
+        task_outputs = [tower(task_feats[idx]) for idx, tower in enumerate(self.towers)]
+        logits = torch.cat(task_outputs, dim=1)
+        return self.prediction_layer(logits)
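
A quick shape walk-through of AITMTransfer as defined above (runnable given the class; batch size 4 and dimension 8 are arbitrary):

    import torch

    transfer = AITMTransfer(input_dim=8)
    prev = torch.randn(4, 8)  # task i-1 representation, shape (batch, dim)
    curr = torch.randn(4, 8)  # task i representation
    fused = transfer(prev, curr)  # attention over the 2 stacked views
    print(fused.shape)  # torch.Size([4, 8])

Each output row is a softmax-weighted mix of the projected previous-task vector and the current-task vector, so task i attends to however much of task i-1's signal is useful, matching the sequential-dependence idea from the AITM paper referenced in the module docstring.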