nextrec 0.4.24__py3-none-any.whl → 0.4.25__py3-none-any.whl

This diff shows the content of publicly available package versions as published to their respective public registries; it is provided for informational purposes only and reflects the changes between the two versions.
@@ -13,6 +13,8 @@ import numpy as np
  import pandas as pd
  import torch
 
+ from nextrec.utils.torch_utils import to_numpy
+
 
  def get_column_data(data: dict | pd.DataFrame, name: str):
 
@@ -23,15 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
  return None
  return data[name].values
  else:
- if hasattr(data, name):
- return getattr(data, name)
- raise KeyError(f"Unsupported data type for extracting column {name}")
-
-
- def to_numpy(values: Any) -> np.ndarray:
- if isinstance(values, torch.Tensor):
- return values.detach().cpu().numpy()
- return np.asarray(values)
+ raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
 
 
  def get_data_length(data: Any) -> int | None:
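The to_numpy helper removed above now ships from nextrec.utils.torch_utils (see the new import in the first hunk). A minimal usage sketch, assuming the relocated helper keeps the behaviour of the removed implementation shown here:

    import numpy as np
    import torch
    from nextrec.utils.torch_utils import to_numpy  # new import in 0.4.25

    # Tensors are detached and moved to CPU; everything else goes through np.asarray.
    arr_from_tensor = to_numpy(torch.tensor([1.0, 2.0]))
    arr_from_list = to_numpy([1, 2, 3])
    assert isinstance(arr_from_tensor, np.ndarray) and isinstance(arr_from_list, np.ndarray)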
@@ -194,6 +194,7 @@ class RecDataLoader(FeatureSet):
  streaming: bool = False,
  chunk_size: int = 10000,
  num_workers: int = 0,
+ prefetch_factor: int | None = None,
  sampler=None,
  ) -> DataLoader:
  """
@@ -206,6 +207,7 @@ class RecDataLoader(FeatureSet):
  streaming: If True, use streaming mode for large files; if False, load full data into memory.
  chunk_size: Chunk size for streaming mode (number of rows per chunk).
  num_workers: Number of worker processes for data loading.
+ prefetch_factor: Number of batches loaded in advance by each worker.
  sampler: Optional sampler for DataLoader, only used for distributed training.
  Returns:
  DataLoader instance.
@@ -234,6 +236,7 @@ class RecDataLoader(FeatureSet):
  streaming=streaming,
  chunk_size=chunk_size,
  num_workers=num_workers,
+ prefetch_factor=prefetch_factor,
  )
 
  if isinstance(data, (dict, pd.DataFrame)):
@@ -242,6 +245,7 @@ class RecDataLoader(FeatureSet):
  batch_size=batch_size,
  shuffle=shuffle,
  num_workers=num_workers,
+ prefetch_factor=prefetch_factor,
  sampler=sampler,
  )
 
@@ -253,6 +257,7 @@ class RecDataLoader(FeatureSet):
  batch_size: int,
  shuffle: bool,
  num_workers: int = 0,
+ prefetch_factor: int | None = None,
  sampler=None,
  ) -> DataLoader:
  raw_data = data
@@ -275,6 +280,9 @@ class RecDataLoader(FeatureSet):
  "[RecDataLoader Error] No valid tensors could be built from the provided data."
  )
  dataset = TensorDictDataset(tensors)
+ loader_kwargs = {}
+ if num_workers > 0 and prefetch_factor is not None:
+ loader_kwargs["prefetch_factor"] = prefetch_factor
  return DataLoader(
  dataset,
  batch_size=batch_size,
@@ -284,6 +292,7 @@ class RecDataLoader(FeatureSet):
  num_workers=num_workers,
  pin_memory=torch.cuda.is_available(),
  persistent_workers=num_workers > 0,
+ **loader_kwargs,
  )
 
  def create_from_path(
@@ -294,6 +303,7 @@ class RecDataLoader(FeatureSet):
  streaming: bool,
  chunk_size: int = 10000,
  num_workers: int = 0,
+ prefetch_factor: int | None = None,
  ) -> DataLoader:
  if isinstance(path, (str, os.PathLike)):
  file_paths, file_type = resolve_file_paths(str(Path(path)))
@@ -327,6 +337,7 @@ class RecDataLoader(FeatureSet):
  chunk_size,
  shuffle,
  num_workers=num_workers,
+ prefetch_factor=prefetch_factor,
  )
 
  dfs = []
@@ -350,7 +361,11 @@ class RecDataLoader(FeatureSet):
  f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use streaming=True or reduce chunk_size."
  ) from exc
  return self.create_from_memory(
- combined_df, batch_size, shuffle, num_workers=num_workers
+ combined_df,
+ batch_size,
+ shuffle,
+ num_workers=num_workers,
+ prefetch_factor=prefetch_factor,
  )
 
  def load_files_streaming(
@@ -361,6 +376,7 @@ class RecDataLoader(FeatureSet):
  chunk_size: int,
  shuffle: bool,
  num_workers: int = 0,
+ prefetch_factor: int | None = None,
  ) -> DataLoader:
  if not check_streaming_support(file_type):
  raise ValueError(
@@ -393,8 +409,15 @@ class RecDataLoader(FeatureSet):
  file_type=file_type,
  processor=self.processor,
  )
+ loader_kwargs = {}
+ if num_workers > 0 and prefetch_factor is not None:
+ loader_kwargs["prefetch_factor"] = prefetch_factor
  return DataLoader(
- dataset, batch_size=1, collate_fn=collate_fn, num_workers=num_workers
+ dataset,
+ batch_size=1,
+ collate_fn=collate_fn,
+ num_workers=num_workers,
+ **loader_kwargs,
  )
 
 
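All loader-building paths above now accept prefetch_factor and forward it to torch's DataLoader only when num_workers > 0, since DataLoader rejects the option for in-process loading. A minimal sketch of the pattern; build_loader is an illustrative helper, not part of the package:

    from torch.utils.data import DataLoader

    def build_loader(dataset, batch_size, num_workers=0, prefetch_factor=None):
        loader_kwargs = {}
        # Only pass prefetch_factor when worker processes exist;
        # with num_workers == 0, DataLoader raises a ValueError for it.
        if num_workers > 0 and prefetch_factor is not None:
            loader_kwargs["prefetch_factor"] = prefetch_factor
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            **loader_kwargs,
        )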
@@ -45,7 +45,15 @@ from nextrec.utils.data import (
 
 
  class DataProcessor(FeatureSet):
- def __init__(self, hash_cache_size: int = 200_000):
+ def __init__(
+ self,
+ hash_cache_size: int = 200_000,
+ ):
+ if not logging.getLogger().hasHandlers():
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(message)s",
+ )
  self.numeric_features: Dict[str, Dict[str, Any]] = {}
  self.sparse_features: Dict[str, Dict[str, Any]] = {}
  self.sequence_features: Dict[str, Dict[str, Any]] = {}
@@ -53,9 +61,6 @@ class DataProcessor(FeatureSet):
  self.version = __version__
 
  self.is_fitted = False
- self._transform_summary_printed = (
- False # Track if summary has been printed during transform
- )
 
  self.scalers: Dict[str, Any] = {}
  self.label_encoders: Dict[str, LabelEncoder] = {}
@@ -92,17 +97,19 @@ class DataProcessor(FeatureSet):
  def add_sparse_feature(
  self,
  name: str,
- encode_method: Literal["hash", "label"] = "label",
+ encode_method: Literal["hash", "label"] = "hash",
  hash_size: Optional[int] = None,
+ min_freq: Optional[int] = None,
  fill_na: str = "<UNK>",
  ):
  """Add a sparse feature configuration.
 
  Args:
- name (str): Feature name.
- encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "label".
- hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
- fill_na (str, optional): Fill value for missing entries. Defaults to "<UNK>".
+ name: Feature name.
+ encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash" because it is more scalable and much faster.
+ hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+ min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+ fill_na: Fill value for missing entries. Defaults to "<UNK>".
  """
  if encode_method == "hash" and hash_size is None:
  raise ValueError(
@@ -111,6 +118,7 @@ class DataProcessor(FeatureSet):
  self.sparse_features[name] = {
  "encode_method": encode_method,
  "hash_size": hash_size,
+ "min_freq": min_freq,
  "fill_na": fill_na,
  }
 
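A minimal sketch of the new option in use; the feature name and sizes below are illustrative, not taken from the package:

    processor = DataProcessor()
    processor.add_sparse_feature(
        "item_id",
        encode_method="hash",   # now the default
        hash_size=200_000,
        min_freq=5,             # tokens seen fewer than 5 times are encoded as <UNK>
    )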
@@ -119,6 +127,7 @@ class DataProcessor(FeatureSet):
  name: str,
  encode_method: Literal["hash", "label"] = "hash",
  hash_size: Optional[int] = None,
+ min_freq: Optional[int] = None,
  max_len: Optional[int] = 50,
  pad_value: int = 0,
  truncate: Literal[
@@ -129,13 +138,14 @@ class DataProcessor(FeatureSet):
  """Add a sequence feature configuration.
 
  Args:
- name (str): Feature name.
- encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
- hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
- max_len (Optional[int], optional): Maximum sequence length. Defaults to 50.
- pad_value (int, optional): Padding value for sequences shorter than max_len. Defaults to 0.
- truncate (Literal["pre", "post"], optional): Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
- separator (str, optional): Separator for string sequences. Defaults to ",".
+ name: Feature name.
+ encode_method: Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
+ hash_size: Hash size for hash encoding. Required if encode_method is "hash".
+ min_freq: Minimum frequency for hash encoding to keep tokens; lower-frequency tokens map to unknown. Defaults to None.
+ max_len: Maximum sequence length. Defaults to 50.
+ pad_value: Padding value for sequences shorter than max_len. Defaults to 0.
+ truncate: Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
+ separator: Separator for string sequences. Defaults to ",".
  """
  if encode_method == "hash" and hash_size is None:
  raise ValueError(
@@ -144,6 +154,7 @@ class DataProcessor(FeatureSet):
  self.sequence_features[name] = {
  "encode_method": encode_method,
  "hash_size": hash_size,
+ "min_freq": min_freq,
  "max_len": max_len,
  "pad_value": pad_value,
  "truncate": truncate,
@@ -175,17 +186,6 @@ class DataProcessor(FeatureSet):
  def hash_string(self, s: str, hash_size: int) -> int:
  return self.hash_fn(str(s), int(hash_size))
 
- def clear_hash_cache(self) -> None:
- cache_clear = getattr(self.hash_fn, "cache_clear", None)
- if callable(cache_clear):
- cache_clear()
-
- def hash_cache_info(self):
- cache_info = getattr(self.hash_fn, "cache_info", None)
- if callable(cache_info):
- return cache_info()
- return None
-
  def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  scaler_type = config["scaler"]
@@ -241,12 +241,30 @@ class DataProcessor(FeatureSet):
  return result
 
  def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
- _ = str(data.name)
+ logger = logging.getLogger()
+
  encode_method = config["encode_method"]
  fill_na = config["fill_na"] # <UNK>
  filled_data = data.fillna(fill_na).astype(str)
  if encode_method == "label":
- vocab = sorted(set(filled_data.tolist()))
+ min_freq = config.get("min_freq")
+ if min_freq is not None:
+ counts = filled_data.value_counts()
+ config["_token_counts"] = counts.to_dict()
+ vocab = sorted(counts[counts >= min_freq].index.tolist())
+ low_freq_types = int((counts < min_freq).sum())
+ total_types = int(counts.size)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sparse feature {data.name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
+ else:
+ vocab = sorted(set(filled_data.tolist()))
  if "<UNK>" not in vocab:
  vocab.append("<UNK>")
  token_to_idx = {token: idx for idx, token in enumerate(vocab)}
@@ -254,6 +272,24 @@ class DataProcessor(FeatureSet):
  config["_unk_index"] = token_to_idx["<UNK>"]
  config["vocab_size"] = len(vocab)
  elif encode_method == "hash":
+ min_freq = config.get("min_freq")
+ if min_freq is not None:
+ counts = filled_data.value_counts()
+ config["_token_counts"] = counts.to_dict()
+ config["_unk_hash"] = self.hash_string(
+ "<UNK>", int(config["hash_size"])
+ )
+ low_freq_types = int((counts < min_freq).sum())
+ total_types = int(counts.size)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sparse feature {data.name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
  config["vocab_size"] = config["hash_size"]
 
  def process_sparse_feature_transform(
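The fit logic above counts token frequencies and keeps only tokens at or above min_freq in the vocabulary; the counts are stashed in the feature config so transform can route rare tokens to <UNK>. A self-contained sketch of the label-encoding branch (it mirrors the diff, it is not the library API):

    import pandas as pd

    values = pd.Series(["a", "a", "b", "a", "c"])
    min_freq = 2
    counts = values.value_counts()                             # a: 3, b: 1, c: 1
    vocab = sorted(counts[counts >= min_freq].index.tolist())  # ["a"]
    # "b" and "c" are low-frequency token types; at transform time they map to <UNK>.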
@@ -283,22 +319,60 @@ class DataProcessor(FeatureSet):
  if encode_method == "hash":
  hash_size = config["hash_size"]
  hash_fn = self.hash_string
+ min_freq = config.get("min_freq")
+ token_counts = config.get("_token_counts")
+ if min_freq is not None and isinstance(token_counts, dict):
+ unk_hash = config.get("_unk_hash")
+ if unk_hash is None:
+ unk_hash = hash_fn("<UNK>", hash_size)
  return np.fromiter(
- (hash_fn(v, hash_size) for v in sparse_series.to_numpy()),
+ (
+ (
+ unk_hash
+ if min_freq is not None
+ and isinstance(token_counts, dict)
+ and token_counts.get(v, 0) < min_freq
+ else hash_fn(v, hash_size)
+ )
+ for v in sparse_series.to_numpy()
+ ),
  dtype=np.int64,
  count=sparse_series.size,
  )
  return np.array([], dtype=np.int64)
 
  def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+ logger = logging.getLogger()
  _ = str(data.name)
  encode_method = config["encode_method"]
  separator = config["separator"]
  if encode_method == "label":
- all_tokens = set()
+ min_freq = config.get("min_freq")
+ token_counts: Dict[str, int] = {}
  for seq in data:
- all_tokens.update(self.extract_sequence_tokens(seq, separator))
- vocab = sorted(all_tokens)
+ tokens = self.extract_sequence_tokens(seq, separator)
+ for token in tokens:
+ if str(token).strip():
+ key = str(token)
+ token_counts[key] = token_counts.get(key, 0) + 1
+ if min_freq is not None:
+ config["_token_counts"] = token_counts
+ vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
+ low_freq_types = sum(
+ 1 for count in token_counts.values() if count < min_freq
+ )
+ total_types = len(token_counts)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sequence feature {data.name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
+ else:
+ vocab = sorted(token_counts.keys())
  if not vocab:
  vocab = ["<PAD>"]
  if "<UNK>" not in vocab:
@@ -308,6 +382,33 @@ class DataProcessor(FeatureSet):
  config["_unk_index"] = token_to_idx["<UNK>"]
  config["vocab_size"] = len(vocab)
  elif encode_method == "hash":
+ min_freq = config.get("min_freq")
+ if min_freq is not None:
+ token_counts: Dict[str, int] = {}
+ for seq in data:
+ tokens = self.extract_sequence_tokens(seq, separator)
+ for token in tokens:
+ if str(token).strip():
+ token_counts[str(token)] = (
+ token_counts.get(str(token), 0) + 1
+ )
+ config["_token_counts"] = token_counts
+ config["_unk_hash"] = self.hash_string(
+ "<UNK>", int(config["hash_size"])
+ )
+ low_freq_types = sum(
+ 1 for count in token_counts.values() if count < min_freq
+ )
+ total_types = len(token_counts)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sequence feature {data.name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
  config["vocab_size"] = config["hash_size"]
 
  def process_sequence_feature_transform(
@@ -338,6 +439,12 @@ class DataProcessor(FeatureSet):
  unk_index = 0
  hash_fn = self.hash_string
  hash_size = config.get("hash_size")
+ min_freq = config.get("min_freq")
+ token_counts = config.get("_token_counts")
+ if min_freq is not None and isinstance(token_counts, dict):
+ unk_hash = config.get("_unk_hash")
+ if unk_hash is None and hash_size is not None:
+ unk_hash = hash_fn("<UNK>", hash_size)
  for i, seq in enumerate(arr):
  # normalize sequence to a list of strings
  tokens = []
@@ -364,7 +471,13 @@ class DataProcessor(FeatureSet):
  "[Data Processor Error] hash_size must be set for hash encoding"
  )
  encoded = [
- hash_fn(str(token), hash_size)
+ (
+ unk_hash
+ if min_freq is not None
+ and isinstance(token_counts, dict)
+ and token_counts.get(str(token), 0) < min_freq
+ else hash_fn(str(token), hash_size)
+ )
  for token in tokens
  if str(token).strip()
  ]
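Both the sparse and sequence transform paths above fall back to the precomputed <UNK> hash for any token whose fitted count is below min_freq; unseen tokens count as 0 and fall back too. A self-contained sketch of that lookup; tiny_hash stands in for the package's own hash function:

    def tiny_hash(token: str, hash_size: int) -> int:
        # Stand-in hash for illustration; the package uses its own hash_string.
        return hash(token) % hash_size

    token_counts = {"a": 3, "b": 1}      # collected during fit
    min_freq, hash_size = 2, 1000
    unk_hash = tiny_hash("<UNK>", hash_size)
    encoded = [
        unk_hash if token_counts.get(v, 0) < min_freq else tiny_hash(v, hash_size)
        for v in ["a", "b", "zzz"]
    ]
    # "b" (rare) and "zzz" (unseen) both land in the <UNK> bucket.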
@@ -472,6 +585,10 @@ class DataProcessor(FeatureSet):
  bold=True,
  )
  )
+ for config in self.sparse_features.values():
+ config.pop("_min_freq_logged", None)
+ for config in self.sequence_features.values():
+ config.pop("_min_freq_logged", None)
  file_paths, file_type = resolve_file_paths(path)
  if not check_streaming_support(file_type):
  raise ValueError(
@@ -496,6 +613,26 @@ class DataProcessor(FeatureSet):
  seq_vocab: Dict[str, set[str]] = {
  name: set() for name in self.sequence_features.keys()
  }
+ sparse_label_counts: Dict[str, Dict[str, int]] = {
+ name: {}
+ for name, config in self.sparse_features.items()
+ if config.get("encode_method") == "label" and config.get("min_freq")
+ }
+ seq_label_counts: Dict[str, Dict[str, int]] = {
+ name: {}
+ for name, config in self.sequence_features.items()
+ if config.get("encode_method") == "label" and config.get("min_freq")
+ }
+ sparse_hash_counts: Dict[str, Dict[str, int]] = {
+ name: {}
+ for name, config in self.sparse_features.items()
+ if config.get("encode_method") == "hash" and config.get("min_freq")
+ }
+ seq_hash_counts: Dict[str, Dict[str, int]] = {
+ name: {}
+ for name, config in self.sequence_features.items()
+ if config.get("encode_method") == "hash" and config.get("min_freq")
+ }
  target_values: Dict[str, set[Any]] = {
  name: set() for name in self.target_features.keys()
  }
@@ -531,6 +668,14 @@ class DataProcessor(FeatureSet):
  fill_na = config["fill_na"]
  series = series.fillna(fill_na).astype(str)
  sparse_vocab[name].update(series.tolist())
+ if name in sparse_label_counts:
+ counts = sparse_label_counts[name]
+ for token in series.tolist():
+ counts[token] = counts.get(token, 0) + 1
+ if name in sparse_hash_counts:
+ counts = sparse_hash_counts[name]
+ for token in series.tolist():
+ counts[token] = counts.get(token, 0) + 1
  else:
  separator = config["separator"]
  tokens = []
@@ -539,6 +684,18 @@ class DataProcessor(FeatureSet):
  self.extract_sequence_tokens(val, separator)
  )
  seq_vocab[name].update(tokens)
+ if name in seq_label_counts:
+ counts = seq_label_counts[name]
+ for token in tokens:
+ if str(token).strip():
+ key = str(token)
+ counts[key] = counts.get(key, 0) + 1
+ if name in seq_hash_counts:
+ counts = seq_hash_counts[name]
+ for token in tokens:
+ if str(token).strip():
+ key = str(token)
+ counts[key] = counts.get(key, 0) + 1
 
  # target features
  missing_features.update(self.target_features.keys() - columns)
@@ -605,7 +762,30 @@ class DataProcessor(FeatureSet):
  # finalize sparse label encoders
  for name, config in self.sparse_features.items():
  if config["encode_method"] == "label":
- vocab = sparse_vocab[name]
+ min_freq = config.get("min_freq")
+ if min_freq is not None:
+ token_counts = sparse_label_counts.get(name, {})
+ config["_token_counts"] = token_counts
+ vocab = {
+ token
+ for token, count in token_counts.items()
+ if count >= min_freq
+ }
+ low_freq_types = sum(
+ 1 for count in token_counts.values() if count < min_freq
+ )
+ total_types = len(token_counts)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sparse feature {name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
+ else:
+ vocab = sparse_vocab[name]
  if not vocab:
  logger.warning(f"Sparse feature {name} has empty vocabulary")
  continue
@@ -617,12 +797,55 @@ class DataProcessor(FeatureSet):
  config["_unk_index"] = token_to_idx["<UNK>"]
  config["vocab_size"] = len(vocab_list)
  elif config["encode_method"] == "hash":
+ min_freq = config.get("min_freq")
+ if min_freq is not None:
+ token_counts = sparse_hash_counts.get(name, {})
+ config["_token_counts"] = token_counts
+ config["_unk_hash"] = self.hash_string(
+ "<UNK>", int(config["hash_size"])
+ )
+ low_freq_types = sum(
+ 1 for count in token_counts.values() if count < min_freq
+ )
+ total_types = len(token_counts)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sparse feature {name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
  config["vocab_size"] = config["hash_size"]
 
  # finalize sequence vocabularies
  for name, config in self.sequence_features.items():
  if config["encode_method"] == "label":
- vocab_set = seq_vocab[name]
+ min_freq = config.get("min_freq")
+ if min_freq is not None:
+ token_counts = seq_label_counts.get(name, {})
+ config["_token_counts"] = token_counts
+ vocab_set = {
+ token
+ for token, count in token_counts.items()
+ if count >= min_freq
+ }
+ low_freq_types = sum(
+ 1 for count in token_counts.values() if count < min_freq
+ )
+ total_types = len(token_counts)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sequence feature {name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
+ else:
+ vocab_set = seq_vocab[name]
  vocab_list = sorted(vocab_set) if vocab_set else ["<PAD>"]
  if "<UNK>" not in vocab_list:
  vocab_list.append("<UNK>")
@@ -631,6 +854,26 @@ class DataProcessor(FeatureSet):
  config["_unk_index"] = token_to_idx["<UNK>"]
  config["vocab_size"] = len(vocab_list)
  elif config["encode_method"] == "hash":
+ min_freq = config.get("min_freq")
+ if min_freq is not None:
+ token_counts = seq_hash_counts.get(name, {})
+ config["_token_counts"] = token_counts
+ config["_unk_hash"] = self.hash_string(
+ "<UNK>", int(config["hash_size"])
+ )
+ low_freq_types = sum(
+ 1 for count in token_counts.values() if count < min_freq
+ )
+ total_types = len(token_counts)
+ kept_types = total_types - low_freq_types
+ if not config.get("_min_freq_logged"):
+ logger.info(
+ f"Sequence feature {name} min_freq={min_freq}: "
+ f"{total_types} token types total, "
+ f"{low_freq_types} low-frequency, "
+ f"{kept_types} kept."
+ )
+ config["_min_freq_logged"] = True
  config["vocab_size"] = config["hash_size"]
 
  # finalize targets
@@ -961,6 +1204,10 @@ class DataProcessor(FeatureSet):
  """
 
  logger = logging.getLogger()
+ for config in self.sparse_features.values():
+ config.pop("_min_freq_logged", None)
+ for config in self.sequence_features.values():
+ config.pop("_min_freq_logged", None)
  if isinstance(data, (str, os.PathLike)):
  path_str = str(data)
  uses_robust = any(
nextrec/utils/config.py CHANGED
@@ -116,6 +116,7 @@ def register_processor_features(
  name,
  encode_method=proc_cfg.get("encode_method", "hash"),
  hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
+ min_freq=proc_cfg.get("min_freq"),
  fill_na=proc_cfg.get("fill_na", "<UNK>"),
  )
 
@@ -125,6 +126,7 @@ def register_processor_features(
  name,
  encode_method=proc_cfg.get("encode_method", "hash"),
  hash_size=proc_cfg.get("hash_size") or proc_cfg.get("vocab_size"),
+ min_freq=proc_cfg.get("min_freq"),
  max_len=proc_cfg.get("max_len", 50),
  pad_value=proc_cfg.get("pad_value", 0),
  truncate=proc_cfg.get("truncate", "post"),
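register_processor_features now also forwards min_freq from the per-feature processor config. A hedged sketch of one such config entry, using only keys read in the hunks above; the surrounding config layout is not shown in this diff:

    proc_cfg = {
        "encode_method": "hash",
        "hash_size": 100_000,
        "min_freq": 5,          # new key picked up in 0.4.25
        "fill_na": "<UNK>",
    }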