nextrec 0.4.24__py3-none-any.whl → 0.4.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +191 -71
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +68 -66
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +16 -77
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
- nextrec-0.4.27.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.24.dist-info/RECORD +0 -86
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
CHANGED
|
@@ -45,7 +45,15 @@ from nextrec.utils.data import (
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
class DataProcessor(FeatureSet):
|
|
48
|
-
def __init__(
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
hash_cache_size: int = 200_000,
|
|
51
|
+
):
|
|
52
|
+
if not logging.getLogger().hasHandlers():
|
|
53
|
+
logging.basicConfig(
|
|
54
|
+
level=logging.INFO,
|
|
55
|
+
format="%(message)s",
|
|
56
|
+
)
|
|
49
57
|
self.numeric_features: Dict[str, Dict[str, Any]] = {}
|
|
50
58
|
self.sparse_features: Dict[str, Dict[str, Any]] = {}
|
|
51
59
|
self.sequence_features: Dict[str, Dict[str, Any]] = {}
|
|
@@ -53,9 +61,6 @@ class DataProcessor(FeatureSet):
|
|
|
53
61
|
self.version = __version__
|
|
54
62
|
|
|
55
63
|
self.is_fitted = False
|
|
56
|
-
self._transform_summary_printed = (
|
|
57
|
-
False # Track if summary has been printed during transform
|
|
58
|
-
)
|
|
59
64
|
|
|
60
65
|
self.scalers: Dict[str, Any] = {}
|
|
61
66
|
self.label_encoders: Dict[str, LabelEncoder] = {}
|
|
@@ -92,17 +97,19 @@ class DataProcessor(FeatureSet):
|
|
|
92
97
|
def add_sparse_feature(
|
|
93
98
|
self,
|
|
94
99
|
name: str,
|
|
95
|
-
encode_method: Literal["hash", "label"] = "
|
|
100
|
+
encode_method: Literal["hash", "label"] = "hash",
|
|
96
101
|
hash_size: Optional[int] = None,
|
|
102
|
+
min_freq: Optional[int] = None,
|
|
97
103
|
fill_na: str = "<UNK>",
|
|
98
104
|
):
|
|
99
105
|
"""Add a sparse feature configuration.
|
|
100
106
|
|
|
101
107
|
Args:
|
|
102
|
-
name
|
|
103
|
-
encode_method
|
|
104
|
-
hash_size
|
|
105
|
-
|
|
108
|
+
name: Feature name.
|
|
109
|
+
encode_method: Encoding method, either "hash" or "label". Defaults to "hash" because it is more scalable and much faster.
|
|
110
|
+
hash_size: Hash size for hash encoding. Required if encode_method is "hash".
|
|
111
|
+
min_freq: Minimum token frequency observed during fit; tokens seen fewer than min_freq times are treated as unknown. Applies to both "hash" and "label" encoding. Defaults to None (no frequency filtering).
|
|
112
|
+
fill_na: Fill value for missing entries. Defaults to "<UNK>".
|
|
106
113
|
"""
|
|
107
114
|
if encode_method == "hash" and hash_size is None:
|
|
108
115
|
raise ValueError(
|
|
@@ -111,6 +118,7 @@ class DataProcessor(FeatureSet):
|
|
|
111
118
|
self.sparse_features[name] = {
|
|
112
119
|
"encode_method": encode_method,
|
|
113
120
|
"hash_size": hash_size,
|
|
121
|
+
"min_freq": min_freq,
|
|
114
122
|
"fill_na": fill_na,
|
|
115
123
|
}
|
|
116
124
|
|
|
@@ -119,6 +127,7 @@ class DataProcessor(FeatureSet):
|
|
|
119
127
|
name: str,
|
|
120
128
|
encode_method: Literal["hash", "label"] = "hash",
|
|
121
129
|
hash_size: Optional[int] = None,
|
|
130
|
+
min_freq: Optional[int] = None,
|
|
122
131
|
max_len: Optional[int] = 50,
|
|
123
132
|
pad_value: int = 0,
|
|
124
133
|
truncate: Literal[
|
|
@@ -129,13 +138,14 @@ class DataProcessor(FeatureSet):
|
|
|
129
138
|
"""Add a sequence feature configuration.
|
|
130
139
|
|
|
131
140
|
Args:
|
|
132
|
-
name
|
|
133
|
-
encode_method
|
|
134
|
-
hash_size
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
141
|
+
name: Feature name.
|
|
142
|
+
encode_method: Encoding method, either "hash" or "label". Defaults to "hash".
|
|
143
|
+
hash_size: Hash size for hash encoding. Required if encode_method is "hash".
|
|
144
|
+
min_freq: Minimum token frequency observed during fit; tokens seen fewer than min_freq times are treated as unknown. Applies to both "hash" and "label" encoding. Defaults to None (no frequency filtering).
|
|
145
|
+
max_len: Maximum sequence length. Defaults to 50.
|
|
146
|
+
pad_value: Padding value for sequences shorter than max_len. Defaults to 0.
|
|
147
|
+
truncate: Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
|
|
148
|
+
separator: Separator for string sequences. Defaults to ",".
|
|
139
149
|
"""
|
|
140
150
|
if encode_method == "hash" and hash_size is None:
|
|
141
151
|
raise ValueError(
|
|
@@ -144,6 +154,7 @@ class DataProcessor(FeatureSet):
|
|
|
144
154
|
self.sequence_features[name] = {
|
|
145
155
|
"encode_method": encode_method,
|
|
146
156
|
"hash_size": hash_size,
|
|
157
|
+
"min_freq": min_freq,
|
|
147
158
|
"max_len": max_len,
|
|
148
159
|
"pad_value": pad_value,
|
|
149
160
|
"truncate": truncate,
|
|
@@ -175,17 +186,6 @@ class DataProcessor(FeatureSet):
|
|
|
175
186
|
def hash_string(self, s: str, hash_size: int) -> int:
|
|
176
187
|
return self.hash_fn(str(s), int(hash_size))
|
|
177
188
|
|
|
178
|
-
def clear_hash_cache(self) -> None:
|
|
179
|
-
cache_clear = getattr(self.hash_fn, "cache_clear", None)
|
|
180
|
-
if callable(cache_clear):
|
|
181
|
-
cache_clear()
|
|
182
|
-
|
|
183
|
-
def hash_cache_info(self):
|
|
184
|
-
cache_info = getattr(self.hash_fn, "cache_info", None)
|
|
185
|
-
if callable(cache_info):
|
|
186
|
-
return cache_info()
|
|
187
|
-
return None
|
|
188
|
-
|
|
189
189
|
def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
|
|
190
190
|
name = str(data.name)
|
|
191
191
|
scaler_type = config["scaler"]
|
|
@@ -241,12 +241,30 @@ class DataProcessor(FeatureSet):
|
|
|
241
241
|
return result
|
|
242
242
|
|
|
243
243
|
def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
|
|
244
|
-
|
|
244
|
+
logger = logging.getLogger()
|
|
245
|
+
|
|
245
246
|
encode_method = config["encode_method"]
|
|
246
247
|
fill_na = config["fill_na"] # <UNK>
|
|
247
248
|
filled_data = data.fillna(fill_na).astype(str)
|
|
248
249
|
if encode_method == "label":
|
|
249
|
-
|
|
250
|
+
min_freq = config.get("min_freq")
|
|
251
|
+
if min_freq is not None:
|
|
252
|
+
counts = filled_data.value_counts()
|
|
253
|
+
config["_token_counts"] = counts.to_dict()
|
|
254
|
+
vocab = sorted(counts[counts >= min_freq].index.tolist())
|
|
255
|
+
low_freq_types = int((counts < min_freq).sum())
|
|
256
|
+
total_types = int(counts.size)
|
|
257
|
+
kept_types = total_types - low_freq_types
|
|
258
|
+
if not config.get("_min_freq_logged"):
|
|
259
|
+
logger.info(
|
|
260
|
+
f"Sparse feature {data.name} min_freq={min_freq}: "
|
|
261
|
+
f"{total_types} token types total, "
|
|
262
|
+
f"{low_freq_types} low-frequency, "
|
|
263
|
+
f"{kept_types} kept."
|
|
264
|
+
)
|
|
265
|
+
config["_min_freq_logged"] = True
|
|
266
|
+
else:
|
|
267
|
+
vocab = sorted(set(filled_data.tolist()))
|
|
250
268
|
if "<UNK>" not in vocab:
|
|
251
269
|
vocab.append("<UNK>")
|
|
252
270
|
token_to_idx = {token: idx for idx, token in enumerate(vocab)}
|
|
@@ -254,6 +272,24 @@ class DataProcessor(FeatureSet):
|
|
|
254
272
|
config["_unk_index"] = token_to_idx["<UNK>"]
|
|
255
273
|
config["vocab_size"] = len(vocab)
|
|
256
274
|
elif encode_method == "hash":
|
|
275
|
+
min_freq = config.get("min_freq")
|
|
276
|
+
if min_freq is not None:
|
|
277
|
+
counts = filled_data.value_counts()
|
|
278
|
+
config["_token_counts"] = counts.to_dict()
|
|
279
|
+
config["_unk_hash"] = self.hash_string(
|
|
280
|
+
"<UNK>", int(config["hash_size"])
|
|
281
|
+
)
|
|
282
|
+
low_freq_types = int((counts < min_freq).sum())
|
|
283
|
+
total_types = int(counts.size)
|
|
284
|
+
kept_types = total_types - low_freq_types
|
|
285
|
+
if not config.get("_min_freq_logged"):
|
|
286
|
+
logger.info(
|
|
287
|
+
f"Sparse feature {data.name} min_freq={min_freq}: "
|
|
288
|
+
f"{total_types} token types total, "
|
|
289
|
+
f"{low_freq_types} low-frequency, "
|
|
290
|
+
f"{kept_types} kept."
|
|
291
|
+
)
|
|
292
|
+
config["_min_freq_logged"] = True
|
|
257
293
|
config["vocab_size"] = config["hash_size"]
|
|
258
294
|
|
|
259
295
|
def process_sparse_feature_transform(
|
|
@@ -283,22 +319,60 @@ class DataProcessor(FeatureSet):
|
|
|
283
319
|
if encode_method == "hash":
|
|
284
320
|
hash_size = config["hash_size"]
|
|
285
321
|
hash_fn = self.hash_string
|
|
322
|
+
min_freq = config.get("min_freq")
|
|
323
|
+
token_counts = config.get("_token_counts")
|
|
324
|
+
if min_freq is not None and isinstance(token_counts, dict):
|
|
325
|
+
unk_hash = config.get("_unk_hash")
|
|
326
|
+
if unk_hash is None:
|
|
327
|
+
unk_hash = hash_fn("<UNK>", hash_size)
|
|
286
328
|
return np.fromiter(
|
|
287
|
-
(
|
|
329
|
+
(
|
|
330
|
+
(
|
|
331
|
+
unk_hash
|
|
332
|
+
if min_freq is not None
|
|
333
|
+
and isinstance(token_counts, dict)
|
|
334
|
+
and token_counts.get(v, 0) < min_freq
|
|
335
|
+
else hash_fn(v, hash_size)
|
|
336
|
+
)
|
|
337
|
+
for v in sparse_series.to_numpy()
|
|
338
|
+
),
|
|
288
339
|
dtype=np.int64,
|
|
289
340
|
count=sparse_series.size,
|
|
290
341
|
)
|
|
291
342
|
return np.array([], dtype=np.int64)
|
|
292
343
|
|
|
293
344
|
def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
|
|
345
|
+
logger = logging.getLogger()
|
|
294
346
|
_ = str(data.name)
|
|
295
347
|
encode_method = config["encode_method"]
|
|
296
348
|
separator = config["separator"]
|
|
297
349
|
if encode_method == "label":
|
|
298
|
-
|
|
350
|
+
min_freq = config.get("min_freq")
|
|
351
|
+
token_counts: Dict[str, int] = {}
|
|
299
352
|
for seq in data:
|
|
300
|
-
|
|
301
|
-
|
|
353
|
+
tokens = self.extract_sequence_tokens(seq, separator)
|
|
354
|
+
for token in tokens:
|
|
355
|
+
if str(token).strip():
|
|
356
|
+
key = str(token)
|
|
357
|
+
token_counts[key] = token_counts.get(key, 0) + 1
|
|
358
|
+
if min_freq is not None:
|
|
359
|
+
config["_token_counts"] = token_counts
|
|
360
|
+
vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
|
|
361
|
+
low_freq_types = sum(
|
|
362
|
+
1 for count in token_counts.values() if count < min_freq
|
|
363
|
+
)
|
|
364
|
+
total_types = len(token_counts)
|
|
365
|
+
kept_types = total_types - low_freq_types
|
|
366
|
+
if not config.get("_min_freq_logged"):
|
|
367
|
+
logger.info(
|
|
368
|
+
f"Sequence feature {data.name} min_freq={min_freq}: "
|
|
369
|
+
f"{total_types} token types total, "
|
|
370
|
+
f"{low_freq_types} low-frequency, "
|
|
371
|
+
f"{kept_types} kept."
|
|
372
|
+
)
|
|
373
|
+
config["_min_freq_logged"] = True
|
|
374
|
+
else:
|
|
375
|
+
vocab = sorted(token_counts.keys())
|
|
302
376
|
if not vocab:
|
|
303
377
|
vocab = ["<PAD>"]
|
|
304
378
|
if "<UNK>" not in vocab:
|
|
@@ -308,6 +382,33 @@ class DataProcessor(FeatureSet):
|
|
|
308
382
|
config["_unk_index"] = token_to_idx["<UNK>"]
|
|
309
383
|
config["vocab_size"] = len(vocab)
|
|
310
384
|
elif encode_method == "hash":
|
|
385
|
+
min_freq = config.get("min_freq")
|
|
386
|
+
if min_freq is not None:
|
|
387
|
+
token_counts: Dict[str, int] = {}
|
|
388
|
+
for seq in data:
|
|
389
|
+
tokens = self.extract_sequence_tokens(seq, separator)
|
|
390
|
+
for token in tokens:
|
|
391
|
+
if str(token).strip():
|
|
392
|
+
token_counts[str(token)] = (
|
|
393
|
+
token_counts.get(str(token), 0) + 1
|
|
394
|
+
)
|
|
395
|
+
config["_token_counts"] = token_counts
|
|
396
|
+
config["_unk_hash"] = self.hash_string(
|
|
397
|
+
"<UNK>", int(config["hash_size"])
|
|
398
|
+
)
|
|
399
|
+
low_freq_types = sum(
|
|
400
|
+
1 for count in token_counts.values() if count < min_freq
|
|
401
|
+
)
|
|
402
|
+
total_types = len(token_counts)
|
|
403
|
+
kept_types = total_types - low_freq_types
|
|
404
|
+
if not config.get("_min_freq_logged"):
|
|
405
|
+
logger.info(
|
|
406
|
+
f"Sequence feature {data.name} min_freq={min_freq}: "
|
|
407
|
+
f"{total_types} token types total, "
|
|
408
|
+
f"{low_freq_types} low-frequency, "
|
|
409
|
+
f"{kept_types} kept."
|
|
410
|
+
)
|
|
411
|
+
config["_min_freq_logged"] = True
|
|
311
412
|
config["vocab_size"] = config["hash_size"]
|
|
312
413
|
|
|
313
414
|
def process_sequence_feature_transform(
|
|
@@ -338,6 +439,12 @@ class DataProcessor(FeatureSet):
|
|
|
338
439
|
unk_index = 0
|
|
339
440
|
hash_fn = self.hash_string
|
|
340
441
|
hash_size = config.get("hash_size")
|
|
442
|
+
min_freq = config.get("min_freq")
|
|
443
|
+
token_counts = config.get("_token_counts")
|
|
444
|
+
if min_freq is not None and isinstance(token_counts, dict):
|
|
445
|
+
unk_hash = config.get("_unk_hash")
|
|
446
|
+
if unk_hash is None and hash_size is not None:
|
|
447
|
+
unk_hash = hash_fn("<UNK>", hash_size)
|
|
341
448
|
for i, seq in enumerate(arr):
|
|
342
449
|
# normalize sequence to a list of strings
|
|
343
450
|
tokens = []
|
|
@@ -364,7 +471,13 @@ class DataProcessor(FeatureSet):
|
|
|
364
471
|
"[Data Processor Error] hash_size must be set for hash encoding"
|
|
365
472
|
)
|
|
366
473
|
encoded = [
|
|
367
|
-
|
|
474
|
+
(
|
|
475
|
+
unk_hash
|
|
476
|
+
if min_freq is not None
|
|
477
|
+
and isinstance(token_counts, dict)
|
|
478
|
+
and token_counts.get(str(token), 0) < min_freq
|
|
479
|
+
else hash_fn(str(token), hash_size)
|
|
480
|
+
)
|
|
368
481
|
for token in tokens
|
|
369
482
|
if str(token).strip()
|
|
370
483
|
]
|
|
@@ -472,6 +585,10 @@ class DataProcessor(FeatureSet):
|
|
|
472
585
|
bold=True,
|
|
473
586
|
)
|
|
474
587
|
)
|
|
588
|
+
for config in self.sparse_features.values():
|
|
589
|
+
config.pop("_min_freq_logged", None)
|
|
590
|
+
for config in self.sequence_features.values():
|
|
591
|
+
config.pop("_min_freq_logged", None)
|
|
475
592
|
file_paths, file_type = resolve_file_paths(path)
|
|
476
593
|
if not check_streaming_support(file_type):
|
|
477
594
|
raise ValueError(
|
|
@@ -496,6 +613,26 @@ class DataProcessor(FeatureSet):
|
|
|
496
613
|
seq_vocab: Dict[str, set[str]] = {
|
|
497
614
|
name: set() for name in self.sequence_features.keys()
|
|
498
615
|
}
|
|
616
|
+
sparse_label_counts: Dict[str, Dict[str, int]] = {
|
|
617
|
+
name: {}
|
|
618
|
+
for name, config in self.sparse_features.items()
|
|
619
|
+
if config.get("encode_method") == "label" and config.get("min_freq")
|
|
620
|
+
}
|
|
621
|
+
seq_label_counts: Dict[str, Dict[str, int]] = {
|
|
622
|
+
name: {}
|
|
623
|
+
for name, config in self.sequence_features.items()
|
|
624
|
+
if config.get("encode_method") == "label" and config.get("min_freq")
|
|
625
|
+
}
|
|
626
|
+
sparse_hash_counts: Dict[str, Dict[str, int]] = {
|
|
627
|
+
name: {}
|
|
628
|
+
for name, config in self.sparse_features.items()
|
|
629
|
+
if config.get("encode_method") == "hash" and config.get("min_freq")
|
|
630
|
+
}
|
|
631
|
+
seq_hash_counts: Dict[str, Dict[str, int]] = {
|
|
632
|
+
name: {}
|
|
633
|
+
for name, config in self.sequence_features.items()
|
|
634
|
+
if config.get("encode_method") == "hash" and config.get("min_freq")
|
|
635
|
+
}
|
|
499
636
|
target_values: Dict[str, set[Any]] = {
|
|
500
637
|
name: set() for name in self.target_features.keys()
|
|
501
638
|
}
|
|
@@ -531,6 +668,14 @@ class DataProcessor(FeatureSet):
|
|
|
531
668
|
fill_na = config["fill_na"]
|
|
532
669
|
series = series.fillna(fill_na).astype(str)
|
|
533
670
|
sparse_vocab[name].update(series.tolist())
|
|
671
|
+
if name in sparse_label_counts:
|
|
672
|
+
counts = sparse_label_counts[name]
|
|
673
|
+
for token in series.tolist():
|
|
674
|
+
counts[token] = counts.get(token, 0) + 1
|
|
675
|
+
if name in sparse_hash_counts:
|
|
676
|
+
counts = sparse_hash_counts[name]
|
|
677
|
+
for token in series.tolist():
|
|
678
|
+
counts[token] = counts.get(token, 0) + 1
|
|
534
679
|
else:
|
|
535
680
|
separator = config["separator"]
|
|
536
681
|
tokens = []
|
|
@@ -539,6 +684,18 @@ class DataProcessor(FeatureSet):
|
|
|
539
684
|
self.extract_sequence_tokens(val, separator)
|
|
540
685
|
)
|
|
541
686
|
seq_vocab[name].update(tokens)
|
|
687
|
+
if name in seq_label_counts:
|
|
688
|
+
counts = seq_label_counts[name]
|
|
689
|
+
for token in tokens:
|
|
690
|
+
if str(token).strip():
|
|
691
|
+
key = str(token)
|
|
692
|
+
counts[key] = counts.get(key, 0) + 1
|
|
693
|
+
if name in seq_hash_counts:
|
|
694
|
+
counts = seq_hash_counts[name]
|
|
695
|
+
for token in tokens:
|
|
696
|
+
if str(token).strip():
|
|
697
|
+
key = str(token)
|
|
698
|
+
counts[key] = counts.get(key, 0) + 1
|
|
542
699
|
|
|
543
700
|
# target features
|
|
544
701
|
missing_features.update(self.target_features.keys() - columns)
|
|
@@ -605,7 +762,30 @@ class DataProcessor(FeatureSet):
|
|
|
605
762
|
# finalize sparse label encoders
|
|
606
763
|
for name, config in self.sparse_features.items():
|
|
607
764
|
if config["encode_method"] == "label":
|
|
608
|
-
|
|
765
|
+
min_freq = config.get("min_freq")
|
|
766
|
+
if min_freq is not None:
|
|
767
|
+
token_counts = sparse_label_counts.get(name, {})
|
|
768
|
+
config["_token_counts"] = token_counts
|
|
769
|
+
vocab = {
|
|
770
|
+
token
|
|
771
|
+
for token, count in token_counts.items()
|
|
772
|
+
if count >= min_freq
|
|
773
|
+
}
|
|
774
|
+
low_freq_types = sum(
|
|
775
|
+
1 for count in token_counts.values() if count < min_freq
|
|
776
|
+
)
|
|
777
|
+
total_types = len(token_counts)
|
|
778
|
+
kept_types = total_types - low_freq_types
|
|
779
|
+
if not config.get("_min_freq_logged"):
|
|
780
|
+
logger.info(
|
|
781
|
+
f"Sparse feature {name} min_freq={min_freq}: "
|
|
782
|
+
f"{total_types} token types total, "
|
|
783
|
+
f"{low_freq_types} low-frequency, "
|
|
784
|
+
f"{kept_types} kept."
|
|
785
|
+
)
|
|
786
|
+
config["_min_freq_logged"] = True
|
|
787
|
+
else:
|
|
788
|
+
vocab = sparse_vocab[name]
|
|
609
789
|
if not vocab:
|
|
610
790
|
logger.warning(f"Sparse feature {name} has empty vocabulary")
|
|
611
791
|
continue
|
|
@@ -617,12 +797,55 @@ class DataProcessor(FeatureSet):
|
|
|
617
797
|
config["_unk_index"] = token_to_idx["<UNK>"]
|
|
618
798
|
config["vocab_size"] = len(vocab_list)
|
|
619
799
|
elif config["encode_method"] == "hash":
|
|
800
|
+
min_freq = config.get("min_freq")
|
|
801
|
+
if min_freq is not None:
|
|
802
|
+
token_counts = sparse_hash_counts.get(name, {})
|
|
803
|
+
config["_token_counts"] = token_counts
|
|
804
|
+
config["_unk_hash"] = self.hash_string(
|
|
805
|
+
"<UNK>", int(config["hash_size"])
|
|
806
|
+
)
|
|
807
|
+
low_freq_types = sum(
|
|
808
|
+
1 for count in token_counts.values() if count < min_freq
|
|
809
|
+
)
|
|
810
|
+
total_types = len(token_counts)
|
|
811
|
+
kept_types = total_types - low_freq_types
|
|
812
|
+
if not config.get("_min_freq_logged"):
|
|
813
|
+
logger.info(
|
|
814
|
+
f"Sparse feature {name} min_freq={min_freq}: "
|
|
815
|
+
f"{total_types} token types total, "
|
|
816
|
+
f"{low_freq_types} low-frequency, "
|
|
817
|
+
f"{kept_types} kept."
|
|
818
|
+
)
|
|
819
|
+
config["_min_freq_logged"] = True
|
|
620
820
|
config["vocab_size"] = config["hash_size"]
|
|
621
821
|
|
|
622
822
|
# finalize sequence vocabularies
|
|
623
823
|
for name, config in self.sequence_features.items():
|
|
624
824
|
if config["encode_method"] == "label":
|
|
625
|
-
|
|
825
|
+
min_freq = config.get("min_freq")
|
|
826
|
+
if min_freq is not None:
|
|
827
|
+
token_counts = seq_label_counts.get(name, {})
|
|
828
|
+
config["_token_counts"] = token_counts
|
|
829
|
+
vocab_set = {
|
|
830
|
+
token
|
|
831
|
+
for token, count in token_counts.items()
|
|
832
|
+
if count >= min_freq
|
|
833
|
+
}
|
|
834
|
+
low_freq_types = sum(
|
|
835
|
+
1 for count in token_counts.values() if count < min_freq
|
|
836
|
+
)
|
|
837
|
+
total_types = len(token_counts)
|
|
838
|
+
kept_types = total_types - low_freq_types
|
|
839
|
+
if not config.get("_min_freq_logged"):
|
|
840
|
+
logger.info(
|
|
841
|
+
f"Sequence feature {name} min_freq={min_freq}: "
|
|
842
|
+
f"{total_types} token types total, "
|
|
843
|
+
f"{low_freq_types} low-frequency, "
|
|
844
|
+
f"{kept_types} kept."
|
|
845
|
+
)
|
|
846
|
+
config["_min_freq_logged"] = True
|
|
847
|
+
else:
|
|
848
|
+
vocab_set = seq_vocab[name]
|
|
626
849
|
vocab_list = sorted(vocab_set) if vocab_set else ["<PAD>"]
|
|
627
850
|
if "<UNK>" not in vocab_list:
|
|
628
851
|
vocab_list.append("<UNK>")
|
|
@@ -631,6 +854,26 @@ class DataProcessor(FeatureSet):
|
|
|
631
854
|
config["_unk_index"] = token_to_idx["<UNK>"]
|
|
632
855
|
config["vocab_size"] = len(vocab_list)
|
|
633
856
|
elif config["encode_method"] == "hash":
|
|
857
|
+
min_freq = config.get("min_freq")
|
|
858
|
+
if min_freq is not None:
|
|
859
|
+
token_counts = seq_hash_counts.get(name, {})
|
|
860
|
+
config["_token_counts"] = token_counts
|
|
861
|
+
config["_unk_hash"] = self.hash_string(
|
|
862
|
+
"<UNK>", int(config["hash_size"])
|
|
863
|
+
)
|
|
864
|
+
low_freq_types = sum(
|
|
865
|
+
1 for count in token_counts.values() if count < min_freq
|
|
866
|
+
)
|
|
867
|
+
total_types = len(token_counts)
|
|
868
|
+
kept_types = total_types - low_freq_types
|
|
869
|
+
if not config.get("_min_freq_logged"):
|
|
870
|
+
logger.info(
|
|
871
|
+
f"Sequence feature {name} min_freq={min_freq}: "
|
|
872
|
+
f"{total_types} token types total, "
|
|
873
|
+
f"{low_freq_types} low-frequency, "
|
|
874
|
+
f"{kept_types} kept."
|
|
875
|
+
)
|
|
876
|
+
config["_min_freq_logged"] = True
|
|
634
877
|
config["vocab_size"] = config["hash_size"]
|
|
635
878
|
|
|
636
879
|
# finalize targets
|
|
@@ -961,6 +1204,10 @@ class DataProcessor(FeatureSet):
|
|
|
961
1204
|
"""
|
|
962
1205
|
|
|
963
1206
|
logger = logging.getLogger()
|
|
1207
|
+
for config in self.sparse_features.values():
|
|
1208
|
+
config.pop("_min_freq_logged", None)
|
|
1209
|
+
for config in self.sequence_features.values():
|
|
1210
|
+
config.pop("_min_freq_logged", None)
|
|
964
1211
|
if isinstance(data, (str, os.PathLike)):
|
|
965
1212
|
path_str = str(data)
|
|
966
1213
|
uses_robust = any(
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Date: created on 01/01/2026 - prerelease version: compute_loss still needs to be overridden later
|
|
3
|
+
Checkpoint: edit on 01/01/2026
|
|
4
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
|
+
Reference:
|
|
6
|
+
- [1] Xi D, Chen Z, Yan P, Zhang Y, Zhu Y, Zhuang F, Chen Y. Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising. Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (KDD ’21), 2021, pp. 3745–3755.
|
|
7
|
+
URL: https://arxiv.org/abs/2105.08489
|
|
8
|
+
- [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import math
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
19
|
+
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
20
|
+
from nextrec.basic.heads import TaskHead
|
|
21
|
+
from nextrec.basic.model import BaseModel
|
|
22
|
+
from nextrec.utils.model import get_mlp_output_dim
|
|
23
|
+
from nextrec.utils.types import TaskTypeName
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AITMTransfer(nn.Module):
    """Attention-based fusion of the previous task's features into the current task.

    The previous task's representation is first passed through a linear
    projection, then placed next to the current task's representation as a
    two-element candidate set. Each candidate gets a scalar score (the dot
    product of its key and query projections, scaled by ``sqrt(input_dim)``);
    the scores are softmax-normalized across the two candidates and used to
    average the candidates' value projections.
    """

    def __init__(self, input_dim: int):
        """Build the four ``input_dim -> input_dim`` linear projections.

        Args:
            input_dim: Size of the last dimension of both inputs to ``forward``.
        """
        super().__init__()
        self.input_dim = input_dim
        # Layers are created in the order prev_proj, value, key, query so the
        # parameter registration order (and seeded initialization) is stable.
        self.prev_proj = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.query = nn.Linear(input_dim, input_dim)

    def forward(self, prev_feat: torch.Tensor, curr_feat: torch.Tensor) -> torch.Tensor:
        """Fuse ``prev_feat`` into ``curr_feat`` with two-candidate attention.

        Args:
            prev_feat: Representation of the previous task; last dim is ``input_dim``.
            curr_feat: Representation of the current task; same shape as ``prev_feat``.

        Returns:
            The attention-weighted sum of the two candidates' value projections,
            with the same shape as ``curr_feat``.
        """
        # Candidate axis is inserted at dim=1: index 0 is the projected
        # previous-task feature, index 1 is the current-task feature.
        candidates = torch.stack((self.prev_proj(prev_feat), curr_feat), dim=1)

        values = self.value(candidates)
        keys = self.key(candidates)
        queries = self.query(candidates)

        # One scalar score per candidate (reduced over dim=2, kept as size 1
        # so it broadcasts against the values below).
        scale = math.sqrt(self.input_dim)
        scores = (keys * queries).sum(dim=2, keepdim=True) / scale

        # Normalize across the two candidates, then collapse the candidate axis.
        weights = torch.softmax(scores, dim=1)
        return (weights * values).sum(dim=1)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class AITM(BaseModel):
    """
    Attentive Information Transfer Multi-Task (AITM) model.

    Every task owns a bottom MLP that reads the same flattened embedding.
    From the second task onward, a task's representation is replaced by the
    output of an ``AITMTransfer`` module that attends over the (already
    transferred) representation of the preceding task and the task's own
    bottom output. Each resulting representation is fed to a per-task tower
    MLP with a single output, and the concatenated tower outputs go through
    the prediction head.
    """

    @property
    def model_name(self):
        return "AITM"

    @property
    def default_task(self):
        # ``nums_task`` is assigned before ``super().__init__`` runs, so it may
        # or may not exist yet when this property is read.
        task_count = getattr(self, "nums_task", None)
        if task_count is None or task_count <= 0:
            return ["binary"]
        return ["binary"] * task_count

    def __init__(
        self,
        dense_features: list[DenseFeature] | None = None,
        sparse_features: list[SparseFeature] | None = None,
        sequence_features: list[SequenceFeature] | None = None,
        bottom_mlp_params: dict | list[dict] | None = None,
        tower_mlp_params_list: list[dict] | None = None,
        calibrator_alpha: float = 0.1,
        target: list[str] | str | None = None,
        task: list[TaskTypeName] | None = None,
        **kwargs,
    ):
        """Construct the embedding, per-task bottoms, transfers and towers.

        Args:
            dense_features: Dense feature definitions; ``None`` means none.
            sparse_features: Sparse feature definitions; ``None`` means none.
            sequence_features: Sequence feature definitions; ``None`` means none.
            bottom_mlp_params: Keyword arguments for the bottom ``MLP``s. A
                single dict is copied once per task; a list must contain one
                dict per task.
            tower_mlp_params_list: One dict of ``MLP`` keyword arguments per task.
            calibrator_alpha: Stored on the instance as ``calibrator_alpha``;
                not otherwise used inside this class.
            target: Target name(s); a single string is wrapped into a list.
            task: Task types forwarded to ``BaseModel``.
            **kwargs: Forwarded unchanged to ``BaseModel.__init__``.

        Raises:
            ValueError: If ``target`` is ``None``, fewer than two targets are
                given, the number of tower/bottom parameter dicts does not
                match the number of tasks, or the bottoms' output dims differ.
        """
        self.calibrator_alpha = calibrator_alpha

        if target is None:
            raise ValueError("AITM requires target names for all tasks.")
        target_names = [target] if isinstance(target, str) else target

        # Set before the base constructor so ``default_task`` can see it.
        self.nums_task = len(target_names)
        if self.nums_task < 2:
            raise ValueError("AITM requires at least 2 tasks.")

        super().__init__(
            dense_features=dense_features or [],
            sparse_features=sparse_features or [],
            sequence_features=sequence_features or [],
            target=target_names,
            task=task,
            **kwargs,
        )

        tower_specs = tower_mlp_params_list or []
        if len(tower_specs) != self.nums_task:
            raise ValueError(
                "Number of tower mlp params "
                f"({len(tower_specs)}) must match number of tasks ({self.nums_task})."
            )

        # Normalize the bottom configuration to one private dict per task.
        bottom_spec = bottom_mlp_params or {}
        bottom_specs: list[dict]
        if not isinstance(bottom_spec, list):
            bottom_specs = [bottom_spec.copy() for _ in range(self.nums_task)]
        elif len(bottom_spec) == self.nums_task:
            bottom_specs = [spec.copy() for spec in bottom_spec]
        else:
            raise ValueError(
                "Number of bottom mlp params "
                f"({len(bottom_spec)}) must match number of tasks ({self.nums_task})."
            )

        self.embedding = EmbeddingLayer(features=self.all_features)
        input_dim = self.embedding.input_dim

        self.bottoms = nn.ModuleList(
            [MLP(input_dim=input_dim, output_dim=None, **spec) for spec in bottom_specs]
        )

        # The transfer modules mix representations of neighbouring tasks, so
        # all bottoms must end at the same width.
        bottom_dims = [get_mlp_output_dim(spec, input_dim) for spec in bottom_specs]
        if len(set(bottom_dims)) != 1:
            raise ValueError(f"All bottom output dims must match, got {bottom_dims}.")
        bottom_output_dim = bottom_dims[0]

        # One transfer per consecutive task pair (task i-1 -> task i).
        self.transfers = nn.ModuleList(
            [AITMTransfer(bottom_output_dim) for _ in range(self.nums_task - 1)]
        )
        self.grad_norm_shared_modules = ["embedding", "transfers"]

        self.towers = nn.ModuleList(
            [
                MLP(input_dim=bottom_output_dim, output_dim=1, **spec)
                for spec in tower_specs
            ]
        )
        self.prediction_layer = TaskHead(
            task_type=self.task, task_dims=[1] * self.nums_task
        )

        self.register_regularization_weights(
            embedding_attr="embedding",
            include_modules=["bottoms", "transfers", "towers"],
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the sequential-transfer forward pass.

        Args:
            x: Mapping from feature name to input tensor, as consumed by the
                embedding layer.

        Returns:
            The output of the prediction head applied to the per-task tower
            outputs concatenated along dim=1.
        """
        # NOTE(review): squeeze_dim=True presumably yields one flat vector per
        # sample - confirm against EmbeddingLayer.
        flat_input = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
        bottom_outputs = [bottom(flat_input) for bottom in self.bottoms]

        # The first task keeps its bottom output; every later task attends over
        # the previous task's *transferred* representation and its own output.
        task_feats = [bottom_outputs[0]]
        for transfer, own_feat in zip(self.transfers, bottom_outputs[1:]):
            task_feats.append(transfer(task_feats[-1], own_feat))

        tower_outputs = [tower(feat) for tower, feat in zip(self.towers, task_feats)]
        logits = torch.cat(tower_outputs, dim=1)
        return self.prediction_layer(logits)
|