nextrec 0.4.24__py3-none-any.whl → 0.4.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +175 -58
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +14 -70
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/METADATA +4 -4
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/RECORD +15 -15
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/licenses/LICENSE +0 -0
nextrec/data/data_processing.py
CHANGED

```diff
@@ -13,6 +13,8 @@ import numpy as np
 import pandas as pd
 import torch
 
+from nextrec.utils.torch_utils import to_numpy
+
 
 def get_column_data(data: dict | pd.DataFrame, name: str):
 
@@ -23,15 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
             return None
         return data[name].values
     else:
-
-        return getattr(data, name)
-    raise KeyError(f"Unsupported data type for extracting column {name}")
-
-
-def to_numpy(values: Any) -> np.ndarray:
-    if isinstance(values, torch.Tensor):
-        return values.detach().cpu().numpy()
-    return np.asarray(values)
+        raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
 
 
 def get_data_length(data: Any) -> int | None:
```
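The removed `to_numpy` helper now lives in `nextrec.utils.torch_utils` and is imported from there (matching the +11/-0 change to nextrec/utils/torch_utils.py listed above). A minimal sketch of the relocated helper, assuming it mirrors the implementation deleted here; the actual body of torch_utils.py is not shown in this diff:

```python
# Hypothetical reconstruction of nextrec.utils.torch_utils.to_numpy,
# mirroring the code removed from data_processing.py above.
from typing import Any

import numpy as np
import torch


def to_numpy(values: Any) -> np.ndarray:
    # Tensors are detached from autograd and moved to CPU before conversion;
    # anything else falls through to np.asarray.
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)
```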
nextrec/data/dataloader.py
CHANGED

```diff
@@ -194,6 +194,7 @@ class RecDataLoader(FeatureSet):
         streaming: bool = False,
         chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
     ) -> DataLoader:
         """
@@ -206,6 +207,7 @@ class RecDataLoader(FeatureSet):
             streaming: If True, use streaming mode for large files; if False, load full data into memory.
             chunk_size: Chunk size for streaming mode (number of rows per chunk).
             num_workers: Number of worker processes for data loading.
+            prefetch_factor: Number of batches loaded in advance by each worker.
             sampler: Optional sampler for DataLoader, only used for distributed training.
         Returns:
             DataLoader instance.
@@ -234,6 +236,7 @@ class RecDataLoader(FeatureSet):
                 streaming=streaming,
                 chunk_size=chunk_size,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
             )
 
         if isinstance(data, (dict, pd.DataFrame)):
@@ -242,6 +245,7 @@ class RecDataLoader(FeatureSet):
                 batch_size=batch_size,
                 shuffle=shuffle,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
                 sampler=sampler,
             )
 
@@ -253,6 +257,7 @@ class RecDataLoader(FeatureSet):
         batch_size: int,
         shuffle: bool,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
     ) -> DataLoader:
         raw_data = data
@@ -275,6 +280,9 @@ class RecDataLoader(FeatureSet):
                 "[RecDataLoader Error] No valid tensors could be built from the provided data."
             )
         dataset = TensorDictDataset(tensors)
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
         return DataLoader(
             dataset,
             batch_size=batch_size,
@@ -284,6 +292,7 @@ class RecDataLoader(FeatureSet):
             num_workers=num_workers,
             pin_memory=torch.cuda.is_available(),
             persistent_workers=num_workers > 0,
+            **loader_kwargs,
         )
 
     def create_from_path(
@@ -294,6 +303,7 @@ class RecDataLoader(FeatureSet):
         streaming: bool,
         chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
     ) -> DataLoader:
         if isinstance(path, (str, os.PathLike)):
             file_paths, file_type = resolve_file_paths(str(Path(path)))
@@ -327,6 +337,7 @@ class RecDataLoader(FeatureSet):
                 chunk_size,
                 shuffle,
                 num_workers=num_workers,
+                prefetch_factor=prefetch_factor,
             )
 
         dfs = []
@@ -350,7 +361,11 @@ class RecDataLoader(FeatureSet):
                 f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use streaming=True or reduce chunk_size."
             ) from exc
         return self.create_from_memory(
-            combined_df,
+            combined_df,
+            batch_size,
+            shuffle,
+            num_workers=num_workers,
+            prefetch_factor=prefetch_factor,
         )
 
     def load_files_streaming(
@@ -361,6 +376,7 @@ class RecDataLoader(FeatureSet):
         chunk_size: int,
         shuffle: bool,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
     ) -> DataLoader:
         if not check_streaming_support(file_type):
             raise ValueError(
@@ -393,8 +409,15 @@ class RecDataLoader(FeatureSet):
             file_type=file_type,
             processor=self.processor,
         )
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
         return DataLoader(
-            dataset,
+            dataset,
+            batch_size=1,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+            **loader_kwargs,
         )
 
 
```
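The new `prefetch_factor` argument is threaded through every loader-construction path and only forwarded to `torch.utils.data.DataLoader` when `num_workers > 0`, because PyTorch rejects the option for single-process loading. A minimal sketch of that guard outside nextrec, with `TensorDataset` standing in for the library's `TensorDictDataset` (dataset contents and batch size are made up):

```python
# Conditional forwarding of prefetch_factor, as added in this release:
# the kwarg is only valid when worker processes are in use.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(1000, dtype=torch.float32).unsqueeze(1))

num_workers = 2        # worker processes, as in the diff's num_workers argument
prefetch_factor = 4    # batches each worker loads ahead; None keeps the default

loader_kwargs = {}
if num_workers > 0 and prefetch_factor is not None:
    loader_kwargs["prefetch_factor"] = prefetch_factor

loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=num_workers > 0,
    **loader_kwargs,
)
```

With `num_workers=0` the kwarg is simply dropped, so the same call site works for both single-process and multi-process loading.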
nextrec/data/preprocessor.py
CHANGED

```diff
@@ -45,7 +45,15 @@ from nextrec.utils.data import (
 
 
 class DataProcessor(FeatureSet):
-    def __init__(
+    def __init__(
+        self,
+        hash_cache_size: int = 200_000,
+    ):
+        if not logging.getLogger().hasHandlers():
+            logging.basicConfig(
+                level=logging.INFO,
+                format="%(message)s",
+            )
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}
         self.sequence_features: Dict[str, Dict[str, Any]] = {}
@@ -53,9 +61,6 @@ class DataProcessor(FeatureSet):
         self.version = __version__
 
         self.is_fitted = False
-        self._transform_summary_printed = (
-            False  # Track if summary has been printed during transform
-        )
 
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
@@ -92,17 +97,19 @@ class DataProcessor(FeatureSet):
     def add_sparse_feature(
         self,
         name: str,
-        encode_method: Literal["hash", "label"] = "
+        encode_method: Literal["hash", "label"] = "hash",
         hash_size: Optional[int] = None,
+        min_freq: Optional[int] = None,
         fill_na: str = "<UNK>",
     ):
         """Add a sparse feature configuration.
 
         Args:
-            name
-            encode_method
-            hash_size
-
+            name: Feature name.
+            encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash" because it is more scalable and much faster.
+            hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+            min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+            fill_na: Fill value for missing entries. Defaults to "<UNK>".
         """
         if encode_method == "hash" and hash_size is None:
             raise ValueError(
@@ -111,6 +118,7 @@ class DataProcessor(FeatureSet):
         self.sparse_features[name] = {
             "encode_method": encode_method,
             "hash_size": hash_size,
+            "min_freq": min_freq,
             "fill_na": fill_na,
         }
 
@@ -119,6 +127,7 @@ class DataProcessor(FeatureSet):
         name: str,
         encode_method: Literal["hash", "label"] = "hash",
         hash_size: Optional[int] = None,
+        min_freq: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
         truncate: Literal[
@@ -129,13 +138,14 @@ class DataProcessor(FeatureSet):
         """Add a sequence feature configuration.
 
         Args:
-            name
-            encode_method
-            hash_size
-
-
-
-
+            name: Feature name.
+            encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
+            hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+            min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+            max_len: Maximum sequence length. Defaults to 50.
+            pad_value: Padding value for sequences shorter than max_len. Defaults to 0.
+            truncate: Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
+            separator: Separator for string sequences. Defaults to ",".
         """
         if encode_method == "hash" and hash_size is None:
             raise ValueError(
@@ -144,6 +154,7 @@ class DataProcessor(FeatureSet):
         self.sequence_features[name] = {
             "encode_method": encode_method,
             "hash_size": hash_size,
+            "min_freq": min_freq,
             "max_len": max_len,
             "pad_value": pad_value,
             "truncate": truncate,
```
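Based on the updated docstrings above, registering features with the new `min_freq` threshold would look roughly like this. The import path follows the file location in this wheel, and the column names are purely illustrative:

```python
# Illustrative use of the min_freq parameter added to add_sparse_feature /
# add_sequence_feature in 0.4.25; "user_id" and "recent_items" are made-up columns.
from nextrec.data.preprocessor import DataProcessor

processor = DataProcessor()

# Hash-encoded sparse feature: tokens seen fewer than 5 times during fit
# are mapped to the <UNK> bucket instead of their own hash value.
processor.add_sparse_feature(
    "user_id",
    encode_method="hash",
    hash_size=100_000,
    min_freq=5,
)

# Label-encoded sequence feature: tokens below the threshold are dropped from
# the vocabulary and fall back to the <UNK> index at transform time.
processor.add_sequence_feature(
    "recent_items",
    encode_method="label",
    min_freq=3,
    max_len=50,
)
```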
```diff
@@ -175,17 +186,6 @@ class DataProcessor(FeatureSet):
     def hash_string(self, s: str, hash_size: int) -> int:
         return self.hash_fn(str(s), int(hash_size))
 
-    def clear_hash_cache(self) -> None:
-        cache_clear = getattr(self.hash_fn, "cache_clear", None)
-        if callable(cache_clear):
-            cache_clear()
-
-    def hash_cache_info(self):
-        cache_info = getattr(self.hash_fn, "cache_info", None)
-        if callable(cache_info):
-            return cache_info()
-        return None
-
     def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
         scaler_type = config["scaler"]
@@ -241,12 +241,30 @@ class DataProcessor(FeatureSet):
         return result
 
     def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
-
+        logger = logging.getLogger()
+
         encode_method = config["encode_method"]
         fill_na = config["fill_na"]  # <UNK>
         filled_data = data.fillna(fill_na).astype(str)
         if encode_method == "label":
-
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                counts = filled_data.value_counts()
+                config["_token_counts"] = counts.to_dict()
+                vocab = sorted(counts[counts >= min_freq].index.tolist())
+                low_freq_types = int((counts < min_freq).sum())
+                total_types = int(counts.size)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sparse feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
+            else:
+                vocab = sorted(set(filled_data.tolist()))
             if "<UNK>" not in vocab:
                 vocab.append("<UNK>")
             token_to_idx = {token: idx for idx, token in enumerate(vocab)}
@@ -254,6 +272,24 @@ class DataProcessor(FeatureSet):
             config["_unk_index"] = token_to_idx["<UNK>"]
             config["vocab_size"] = len(vocab)
         elif encode_method == "hash":
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                counts = filled_data.value_counts()
+                config["_token_counts"] = counts.to_dict()
+                config["_unk_hash"] = self.hash_string(
+                    "<UNK>", int(config["hash_size"])
+                )
+                low_freq_types = int((counts < min_freq).sum())
+                total_types = int(counts.size)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sparse feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
             config["vocab_size"] = config["hash_size"]
 
     def process_sparse_feature_transform(
@@ -283,22 +319,60 @@ class DataProcessor(FeatureSet):
         if encode_method == "hash":
             hash_size = config["hash_size"]
             hash_fn = self.hash_string
+            min_freq = config.get("min_freq")
+            token_counts = config.get("_token_counts")
+            if min_freq is not None and isinstance(token_counts, dict):
+                unk_hash = config.get("_unk_hash")
+                if unk_hash is None:
+                    unk_hash = hash_fn("<UNK>", hash_size)
             return np.fromiter(
-                (
+                (
+                    (
+                        unk_hash
+                        if min_freq is not None
+                        and isinstance(token_counts, dict)
+                        and token_counts.get(v, 0) < min_freq
+                        else hash_fn(v, hash_size)
+                    )
+                    for v in sparse_series.to_numpy()
+                ),
                 dtype=np.int64,
                 count=sparse_series.size,
             )
         return np.array([], dtype=np.int64)
 
     def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+        logger = logging.getLogger()
         _ = str(data.name)
         encode_method = config["encode_method"]
         separator = config["separator"]
         if encode_method == "label":
-
+            min_freq = config.get("min_freq")
+            token_counts: Dict[str, int] = {}
             for seq in data:
-
-
+                tokens = self.extract_sequence_tokens(seq, separator)
+                for token in tokens:
+                    if str(token).strip():
+                        key = str(token)
+                        token_counts[key] = token_counts.get(key, 0) + 1
+            if min_freq is not None:
+                config["_token_counts"] = token_counts
+                vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
+                low_freq_types = sum(
+                    1 for count in token_counts.values() if count < min_freq
+                )
+                total_types = len(token_counts)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sequence feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
+            else:
+                vocab = sorted(token_counts.keys())
             if not vocab:
                 vocab = ["<PAD>"]
             if "<UNK>" not in vocab:
@@ -308,6 +382,33 @@ class DataProcessor(FeatureSet):
             config["_unk_index"] = token_to_idx["<UNK>"]
             config["vocab_size"] = len(vocab)
         elif encode_method == "hash":
+            min_freq = config.get("min_freq")
+            if min_freq is not None:
+                token_counts: Dict[str, int] = {}
+                for seq in data:
+                    tokens = self.extract_sequence_tokens(seq, separator)
+                    for token in tokens:
+                        if str(token).strip():
+                            token_counts[str(token)] = (
+                                token_counts.get(str(token), 0) + 1
+                            )
+                config["_token_counts"] = token_counts
+                config["_unk_hash"] = self.hash_string(
+                    "<UNK>", int(config["hash_size"])
+                )
+                low_freq_types = sum(
+                    1 for count in token_counts.values() if count < min_freq
+                )
+                total_types = len(token_counts)
+                kept_types = total_types - low_freq_types
+                if not config.get("_min_freq_logged"):
+                    logger.info(
+                        f"Sequence feature {data.name} min_freq={min_freq}: "
+                        f"{total_types} token types total, "
+                        f"{low_freq_types} low-frequency, "
+                        f"{kept_types} kept."
+                    )
+                    config["_min_freq_logged"] = True
             config["vocab_size"] = config["hash_size"]
 
     def process_sequence_feature_transform(
@@ -338,6 +439,12 @@ class DataProcessor(FeatureSet):
         unk_index = 0
         hash_fn = self.hash_string
         hash_size = config.get("hash_size")
+        min_freq = config.get("min_freq")
+        token_counts = config.get("_token_counts")
+        if min_freq is not None and isinstance(token_counts, dict):
+            unk_hash = config.get("_unk_hash")
+            if unk_hash is None and hash_size is not None:
+                unk_hash = hash_fn("<UNK>", hash_size)
         for i, seq in enumerate(arr):
             # normalize sequence to a list of strings
             tokens = []
@@ -364,7 +471,13 @@ class DataProcessor(FeatureSet):
                     "[Data Processor Error] hash_size must be set for hash encoding"
                 )
                 encoded = [
-
+                    (
+                        unk_hash
+                        if min_freq is not None
+                        and isinstance(token_counts, dict)
+                        and token_counts.get(str(token), 0) < min_freq
+                        else hash_fn(str(token), hash_size)
+                    )
                     for token in tokens
                     if str(token).strip()
                 ]
```
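The transform-time rule added above is the same for sparse and sequence features: once `min_freq` is set, any token whose fitted count falls below the threshold is replaced by the `<UNK>` representation (the `<UNK>` index for label encoding, the hash of `"<UNK>"` for hash encoding). A standalone illustration of that rule, with `hash_token` as a stand-in for the processor's `hash_fn`, which is not shown in this diff:

```python
# Standalone illustration of the min_freq rule: tokens below the fitted
# frequency threshold collapse onto the <UNK> bucket.
import hashlib


def hash_token(token: str, hash_size: int) -> int:
    # Stand-in hash; nextrec's actual hash_fn may differ.
    digest = hashlib.md5(token.encode("utf-8")).hexdigest()
    return int(digest, 16) % hash_size


hash_size = 1000
min_freq = 3
token_counts = {"apple": 10, "banana": 2, "cherry": 1}  # gathered during fit
unk_hash = hash_token("<UNK>", hash_size)

encoded = [
    unk_hash if token_counts.get(t, 0) < min_freq else hash_token(t, hash_size)
    for t in ["apple", "banana", "durian"]
]
# "apple" keeps its own bucket; "banana" (count 2) and the unseen "durian"
# (count 0) both map to the <UNK> bucket.
print(encoded)
```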
```diff
@@ -472,6 +585,10 @@ class DataProcessor(FeatureSet):
                 bold=True,
             )
         )
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
         file_paths, file_type = resolve_file_paths(path)
         if not check_streaming_support(file_type):
             raise ValueError(
@@ -496,6 +613,26 @@ class DataProcessor(FeatureSet):
         seq_vocab: Dict[str, set[str]] = {
             name: set() for name in self.sequence_features.keys()
         }
+        sparse_label_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sparse_features.items()
+            if config.get("encode_method") == "label" and config.get("min_freq")
+        }
+        seq_label_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sequence_features.items()
+            if config.get("encode_method") == "label" and config.get("min_freq")
+        }
+        sparse_hash_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sparse_features.items()
+            if config.get("encode_method") == "hash" and config.get("min_freq")
+        }
+        seq_hash_counts: Dict[str, Dict[str, int]] = {
+            name: {}
+            for name, config in self.sequence_features.items()
+            if config.get("encode_method") == "hash" and config.get("min_freq")
+        }
         target_values: Dict[str, set[Any]] = {
             name: set() for name in self.target_features.keys()
         }
@@ -531,6 +668,14 @@ class DataProcessor(FeatureSet):
                     fill_na = config["fill_na"]
                     series = series.fillna(fill_na).astype(str)
                     sparse_vocab[name].update(series.tolist())
+                    if name in sparse_label_counts:
+                        counts = sparse_label_counts[name]
+                        for token in series.tolist():
+                            counts[token] = counts.get(token, 0) + 1
+                    if name in sparse_hash_counts:
+                        counts = sparse_hash_counts[name]
+                        for token in series.tolist():
+                            counts[token] = counts.get(token, 0) + 1
                 else:
                     separator = config["separator"]
                     tokens = []
@@ -539,6 +684,18 @@ class DataProcessor(FeatureSet):
                             self.extract_sequence_tokens(val, separator)
                         )
                     seq_vocab[name].update(tokens)
+                    if name in seq_label_counts:
+                        counts = seq_label_counts[name]
+                        for token in tokens:
+                            if str(token).strip():
+                                key = str(token)
+                                counts[key] = counts.get(key, 0) + 1
+                    if name in seq_hash_counts:
+                        counts = seq_hash_counts[name]
+                        for token in tokens:
+                            if str(token).strip():
+                                key = str(token)
+                                counts[key] = counts.get(key, 0) + 1
 
         # target features
         missing_features.update(self.target_features.keys() - columns)
@@ -605,7 +762,30 @@ class DataProcessor(FeatureSet):
         # finalize sparse label encoders
         for name, config in self.sparse_features.items():
             if config["encode_method"] == "label":
-
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = sparse_label_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    vocab = {
+                        token
+                        for token, count in token_counts.items()
+                        if count >= min_freq
+                    }
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sparse feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
+                else:
+                    vocab = sparse_vocab[name]
                 if not vocab:
                     logger.warning(f"Sparse feature {name} has empty vocabulary")
                     continue
@@ -617,12 +797,55 @@ class DataProcessor(FeatureSet):
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
             elif config["encode_method"] == "hash":
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = sparse_hash_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    config["_unk_hash"] = self.hash_string(
+                        "<UNK>", int(config["hash_size"])
+                    )
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sparse feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]
 
         # finalize sequence vocabularies
         for name, config in self.sequence_features.items():
             if config["encode_method"] == "label":
-
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = seq_label_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    vocab_set = {
+                        token
+                        for token, count in token_counts.items()
+                        if count >= min_freq
+                    }
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sequence feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
+                else:
+                    vocab_set = seq_vocab[name]
                 vocab_list = sorted(vocab_set) if vocab_set else ["<PAD>"]
                 if "<UNK>" not in vocab_list:
                     vocab_list.append("<UNK>")
@@ -631,6 +854,26 @@ class DataProcessor(FeatureSet):
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
             elif config["encode_method"] == "hash":
+                min_freq = config.get("min_freq")
+                if min_freq is not None:
+                    token_counts = seq_hash_counts.get(name, {})
+                    config["_token_counts"] = token_counts
+                    config["_unk_hash"] = self.hash_string(
+                        "<UNK>", int(config["hash_size"])
+                    )
+                    low_freq_types = sum(
+                        1 for count in token_counts.values() if count < min_freq
+                    )
+                    total_types = len(token_counts)
+                    kept_types = total_types - low_freq_types
+                    if not config.get("_min_freq_logged"):
+                        logger.info(
+                            f"Sequence feature {name} min_freq={min_freq}: "
+                            f"{total_types} token types total, "
+                            f"{low_freq_types} low-frequency, "
+                            f"{kept_types} kept."
+                        )
+                        config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]
 
         # finalize targets
@@ -961,6 +1204,10 @@ class DataProcessor(FeatureSet):
         """
 
         logger = logging.getLogger()
+        for config in self.sparse_features.values():
+            config.pop("_min_freq_logged", None)
+        for config in self.sequence_features.values():
+            config.pop("_min_freq_logged", None)
         if isinstance(data, (str, os.PathLike)):
             path_str = str(data)
             uses_robust = any(
```
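The streaming fit path above mirrors the in-memory one: per-chunk token counts are accumulated into plain dicts (sparse_label_counts, seq_hash_counts, and so on) and the `min_freq` threshold is applied only after all chunks have been read, so a token that is rare in one chunk but frequent overall is still kept. A compact sketch of that accumulate-then-threshold pattern with made-up chunk contents:

```python
# Accumulate-then-threshold pattern used by the streaming fit: count across
# all chunks first, apply min_freq only when finalizing the vocabulary.
from typing import Dict, Iterable, List


def fit_vocab_streaming(chunks: Iterable[List[str]], min_freq: int) -> List[str]:
    counts: Dict[str, int] = {}
    for chunk in chunks:              # each chunk is one streamed batch of tokens
        for token in chunk:
            counts[token] = counts.get(token, 0) + 1
    vocab = sorted(t for t, c in counts.items() if c >= min_freq)
    if "<UNK>" not in vocab:          # always reserve an unknown token
        vocab.append("<UNK>")
    return vocab


chunks = [["a", "b", "a"], ["b", "c"], ["a"]]
print(fit_vocab_streaming(chunks, min_freq=2))  # ['a', 'b', '<UNK>']
```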
nextrec/utils/config.py
CHANGED

```diff
@@ -116,6 +116,7 @@ def register_processor_features(
             name,
             encode_method=proc_cfg.get("encode_method", "hash"),
             hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
+            min_freq=proc_cfg.get("min_freq"),
             fill_na=proc_cfg.get("fill_na", "<UNK>"),
         )
 
@@ -125,6 +126,7 @@ def register_processor_features(
             name,
             encode_method=proc_cfg.get("encode_method", "hash"),
             hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
+            min_freq=proc_cfg.get("min_freq"),
             max_len=proc_cfg.get("max_len", 50),
             pad_value=proc_cfg.get("pad_value", 0),
             truncate=proc_cfg.get("truncate", "post"),
```
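`register_processor_features` reads each feature's settings from a per-feature mapping (`proc_cfg`), so `min_freq` can now be supplied from a training config alongside the existing keys. A hedged sketch of what such entries could look like; only the keys read via `proc_cfg.get(...)` in the hunks above are grounded in this diff, and the surrounding config schema is assumed:

```python
# Hypothetical per-feature processor config entries consumed by
# register_processor_features; feature names are illustrative.
features = {
    "item_id": {               # sparse feature
        "encode_method": "hash",
        "hash_size": 200_000,
        "min_freq": 5,          # new in 0.4.25: rare tokens collapse to <UNK>
        "fill_na": "<UNK>",
    },
    "click_history": {         # sequence feature
        "encode_method": "hash",
        "hash_size": 200_000,
        "min_freq": 3,
        "max_len": 50,
        "pad_value": 0,
        "truncate": "post",
    },
}
```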