nextrec 0.5.1-py3-none-any.whl → 0.5.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,7 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.

 Date: create on 13/11/2025
-Checkpoint: edit on 28/01/2026
+Checkpoint: edit on 29/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """

@@ -29,12 +29,8 @@ from nextrec.__version__ import __version__
 from nextrec.basic.features import FeatureSet
 from nextrec.basic.loggers import colorize
 from nextrec.basic.session import get_save_path
-from nextrec.data.data_processing import hash_md5_mod
 from nextrec.utils.console import progress
 from nextrec.utils.data import (
-    FILE_FORMAT_CONFIG,
-    check_streaming_support,
-    default_output_dir,
     resolve_file_paths,
 )

@@ -44,32 +40,53 @@ class DataProcessor(FeatureSet):
         self,
         hash_cache_size: int = 200_000,
     ):
+        """
+        DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
+
+        Args:
+            hash_cache_size (int, optional): Cache size for string hashing. Defaults to 200,000.
+        """
         if not logging.getLogger().hasHandlers():
             logging.basicConfig(
                 level=logging.INFO,
                 format="%(message)s",
             )
-        self.numeric_features: Dict[str, Dict[str, Any]] = {}
-        self.sparse_features: Dict[str, Dict[str, Any]] = {}
-        self.sequence_features: Dict[str, Dict[str, Any]] = {}
-        self.target_features: Dict[str, Dict[str, Any]] = {}
+        self.numeric_features = {}
+        self.sparse_features = {}
+        self.sequence_features = {}
+        self.target_features = {}
         self.version = __version__

         self.is_fitted = False

-        self.scalers: Dict[str, Any] = {}
-        self.label_encoders: Dict[str, Any] = {}
-        self.target_encoders: Dict[str, Dict[str, int]] = {}
+        self.scalers = {}
+        self.label_encoders = {}
+        self.target_encoders = {}
         self.set_target_id(target=[], id_columns=[])

         # cache hash function
         self.hash_cache_size = int(hash_cache_size)
         if self.hash_cache_size > 0:
             self.hash_fn = functools.lru_cache(maxsize=self.hash_cache_size)(
-                hash_md5_mod
+                self.hash_string
+            )
+        else:
+            self.hash_fn = self.hash_string
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # lru_cache wrappers on instance fields are not picklable under spawn
+        state.pop("hash_fn", None)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if self.hash_cache_size > 0:
+            self.hash_fn = functools.lru_cache(maxsize=self.hash_cache_size)(
+                self.hash_string
             )
         else:
-            self.hash_fn = hash_md5_mod
+            self.hash_fn = self.hash_string

     def add_numeric_feature(
         self,
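Note: the new `__getstate__`/`__setstate__` pair exists because a `functools.lru_cache` wrapper stored on the instance cannot be pickled, which breaks multiprocessing with the spawn start method. The sketch below is a minimal standalone illustration of the same drop-and-rebuild pattern (hypothetical class and names, not the package's code):

```python
import functools
import pickle


class HashingState:
    """Hypothetical sketch of the drop-and-rebuild pattern used in the diff above."""

    def __init__(self, cache_size: int = 1_000):
        self.cache_size = cache_size
        self._build_hash_fn()

    def _build_hash_fn(self):
        # lru_cache returns a closure; closures stored on instances are not picklable.
        if self.cache_size > 0:
            self.hash_fn = functools.lru_cache(maxsize=self.cache_size)(self.hash_value)
        else:
            self.hash_fn = self.hash_value

    @staticmethod
    def hash_value(value: str, buckets: int) -> int:
        # hash() is process-seeded; fine for illustration, not for reproducible bucketing.
        return hash(value) % buckets

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("hash_fn", None)  # drop the unpicklable wrapper before pickling
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._build_hash_fn()  # rebuild the cache in the receiving process


restored = pickle.loads(pickle.dumps(HashingState()))
print(restored.hash_fn("user_42", 100))  # works after the pickle round-trip
```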
@@ -178,8 +195,10 @@ class DataProcessor(FeatureSet):
         }
         self.set_target_id(list(self.target_features.keys()), [])

-    def hash_string(self, s: str, hash_size: int) -> int:
-        return self.hash_fn(str(s), int(hash_size))
+    @staticmethod
+    def hash_string(value: str, hash_size: int) -> int:
+        hashed = pl.Series([value], dtype=pl.Utf8).hash().cast(pl.UInt64)
+        return int(hashed[0]) % int(hash_size)

     def polars_scan(self, file_paths: list[str], file_type: str):
         file_type = file_type.lower()
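Note: `hash_string` now buckets tokens with Polars' built-in `Series.hash` instead of the removed `hash_md5_mod` helper. A small sketch of how such a bucketizer behaves (the hash value itself depends on the Polars version and seed, so only within-run consistency is assumed here):

```python
import polars as pl


def bucketize(value: str, hash_size: int) -> int:
    # Same shape as the new hash_string: 1-element Utf8 Series -> UInt64 hash -> modulo bucket.
    hashed = pl.Series([value], dtype=pl.Utf8).hash().cast(pl.UInt64)
    return int(hashed[0]) % int(hash_size)


# Identical inputs land in the same bucket within a run, and buckets stay in [0, hash_size).
assert bucketize("item_123", 1000) == bucketize("item_123", 1000)
assert 0 <= bucketize("item_456", 1000) < 1000
```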
@@ -191,9 +210,7 @@
            f"[Data Processor Error] Polars backend only supports csv/parquet, got: {file_type}"
        )

-    def sequence_expr(
-        self, pl, name: str, config: Dict[str, Any], schema: Dict[str, Any]
-    ):
+    def sequence_expr(self, name: str, config: Dict[str, Any], schema: Dict[str, Any]):
         """
         generate polars expression for sequence feature processing

@@ -222,7 +239,7 @@
         ).list.drop_nulls()
         return seq_col

-    def apply_transforms(self, lf, schema: Dict[str, Any], warn_missing: bool):
+    def apply_transforms(self, lazy_frame, schema: Dict[str, Any]):
         """
         Apply all transformations to a Polars LazyFrame.

@@ -237,20 +254,16 @@
                 return_dtype=dtype,
             )

-        def ensure_present(feature_name: str, label: str) -> bool:
-            if feature_name not in schema:
-                if warn_missing:
-                    logger.warning(f"{label} feature {feature_name} not found in data")
-                return False
-            return True
-
         # Numeric features
         for name, config in self.numeric_features.items():
-            if not ensure_present(name, "Numeric"):
+            if name not in schema:
+                logger.warning(f"Numeric feature {name} not found in data")
                 continue
             scaler_type = config["scaler"]
             fill_na_value = config.get("fill_na_value", 0)
             col = pl.col(name).cast(pl.Float64).fill_null(fill_na_value)
+
+            # Apply scaling
             if scaler_type == "log":
                 col = col.clip(lower_bound=0).log1p()
             elif scaler_type == "none":
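Note: the numeric branch builds one Polars expression per feature. Below is a minimal sketch of the "log" scaler path shown in this hunk, with an illustrative column; the clip guards log1p against negative inputs:

```python
import polars as pl

df = pl.DataFrame({"price": [None, -2.0, 0.0, 9.0]})
# Same expression shape as the hunk: fill nulls, clip negatives to 0, then log1p.
expr = pl.col("price").cast(pl.Float64).fill_null(0).clip(lower_bound=0).log1p()
print(df.select(expr.alias("price")))
# nulls and negatives become log1p(0) = 0.0; 9.0 becomes log1p(9) ≈ 2.30
```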
@@ -285,7 +298,8 @@

         # Sparse features
         for name, config in self.sparse_features.items():
-            if not ensure_present(name, "Sparse"):
+            if name not in schema:
+                logger.warning(f"Sparse feature {name} not found in data")
                 continue
             encode_method = config["encode_method"]
             fill_na = config["fill_na"]
@@ -307,7 +321,7 @@
                 low_freq = [k for k, v in token_counts.items() if v < min_freq]
                 unk_hash = config.get("_unk_hash")
                 if unk_hash is None:
-                    unk_hash = self.hash_string("<UNK>", int(hash_size))
+                    unk_hash = self.hash_fn("<UNK>", int(hash_size))
                 hash_expr = (
                     pl.when(col.is_in(low_freq))
                     .then(int(unk_hash))
@@ -318,13 +332,14 @@

         # Sequence features
         for name, config in self.sequence_features.items():
-            if not ensure_present(name, "Sequence"):
+            if name not in schema:
+                logger.warning(f"Sequence feature {name} not found in data")
                 continue
             encode_method = config["encode_method"]
             max_len = int(config["max_len"])
             pad_value = int(config["pad_value"])
             truncate = config["truncate"]
-            seq_col = self.sequence_expr(pl, name, config, schema)
+            seq_col = self.sequence_expr(name, config, schema)

             if encode_method == "label":
                 token_to_idx = config.get("_token_to_idx")
@@ -350,7 +365,7 @@
                 low_freq = [k for k, v in token_counts.items() if v < min_freq]
                 unk_hash = config.get("_unk_hash")
                 if unk_hash is None:
-                    unk_hash = self.hash_string("<UNK>", int(hash_size))
+                    unk_hash = self.hash_fn("<UNK>", int(hash_size))
                 hash_expr = (
                     pl.when(elem.is_in(low_freq))
                     .then(int(unk_hash))
@@ -368,7 +383,8 @@

         # Target features
         for name, config in self.target_features.items():
-            if not ensure_present(name, "Target"):
+            if name not in schema:
+                logger.warning(f"Target feature {name} not found in data")
                 continue
             target_type = config.get("target_type")
             col = pl.col(name)
@@ -390,8 +406,8 @@
             expressions.append(col.alias(name))

         if not expressions:
-            return lf
-        return lf.with_columns(expressions)
+            return lazy_frame
+        return lazy_frame.with_columns(expressions)

     def process_target_fit(
         self, data: Iterable[Any], config: Dict[str, Any], name: str
@@ -401,22 +417,18 @@
         if target_type == "binary":
             if label_map is None:
                 unique_values = {v for v in data if v is not None}
-                # Filter out None values before sorting to avoid comparison errors
                 sorted_values = sorted(v for v in unique_values if v is not None)
-                try:
-                    int_values = [int(v) for v in sorted_values]
-                    if int_values == list(range(len(int_values))):
-                        label_map = {str(val): int(val) for val in sorted_values}
-                    else:
-                        label_map = {
-                            str(val): idx for idx, val in enumerate(sorted_values)
-                        }
-                except (ValueError, TypeError):
+
+                int_values = [int(v) for v in sorted_values]
+                if int_values == list(range(len(int_values))):
+                    label_map = {str(val): int(val) for val in sorted_values}
+                else:
                     label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+
             config["label_map"] = label_map
             self.target_encoders[name] = label_map

-    def polars_fit_from_lazy(self, lf, schema: Dict[str, Any]) -> "DataProcessor":
+    def fit_from_lazy(self, lazy_frame, schema: Dict[str, Any]) -> "DataProcessor":
         logger = logging.getLogger()

         missing_features = set()
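Note: the fit path now assumes binary target labels are integer-like: if the sorted unique values already form 0..N-1 they are kept, otherwise they are re-indexed; non-numeric labels would now raise where the old try/except silently fell back. A standalone sketch of the two outcomes (illustrative values):

```python
def build_label_map(values):
    # Mirrors the new logic; assumes every label can be cast to int.
    sorted_values = sorted(v for v in set(values) if v is not None)
    int_values = [int(v) for v in sorted_values]
    if int_values == list(range(len(int_values))):
        return {str(val): int(val) for val in sorted_values}
    return {str(val): idx for idx, val in enumerate(sorted_values)}


print(build_label_map([0, 1, 1, 0]))  # {'0': 0, '1': 1} -> labels kept as-is
print(build_label_map([1, 5, 5, 1]))  # {'1': 0, '5': 1} -> labels re-indexed
```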
@@ -462,7 +474,11 @@
                        col.median().alias(f"{name}__median"),
                    ]
                )
-            stats = lf.select(agg_exprs).collect().to_dicts()[0] if agg_exprs else {}
+            stats = (
+                lazy_frame.select(agg_exprs).collect().to_dicts()[0]
+                if agg_exprs
+                else {}
+            )
         else:
             stats = {}

@@ -538,7 +554,7 @@
             fill_na = config["fill_na"]
             col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
             counts_df = (
-                lf.select(col.alias(name))
+                lazy_frame.select(col.alias(name))
                 .group_by(name)
                 .agg(pl.len().alias("count"))
                 .collect()
@@ -585,7 +601,7 @@
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
                     config["_token_counts"] = counts
-                    config["_unk_hash"] = self.hash_string(
+                    config["_unk_hash"] = self.hash_fn(
                         "<UNK>", int(config["hash_size"])
                     )
                     low_freq_types = sum(
@@ -608,9 +624,9 @@
             if name not in schema:
                 continue
             encode_method = config["encode_method"]
-            seq_col = self.sequence_expr(pl, name, config, schema)
+            seq_col = self.sequence_expr(name, config, schema)
             tokens_df = (
-                lf.select(seq_col.alias("seq"))
+                lazy_frame.select(seq_col.alias("seq"))
                 .explode("seq")
                 .select(pl.col("seq").cast(pl.Utf8).alias("seq"))
                 .drop_nulls("seq")
@@ -661,7 +677,7 @@
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
                     config["_token_counts"] = counts
-                    config["_unk_hash"] = self.hash_string(
+                    config["_unk_hash"] = self.hash_fn(
                         "<UNK>", int(config["hash_size"])
                     )
                     low_freq_types = sum(
@@ -685,7 +701,7 @@
                 continue
             if config.get("target_type") == "binary":
                 unique_vals = (
-                    lf.select(pl.col(name).drop_nulls().unique())
+                    lazy_frame.select(pl.col(name).drop_nulls().unique())
                     .collect()
                     .to_series()
                     .to_list()
@@ -715,9 +731,9 @@
             config.pop("_min_freq_logged", None)
         for config in self.sequence_features.values():
             config.pop("_min_freq_logged", None)
-        lf = self.polars_scan(file_paths, file_type)
-        schema = lf.collect_schema()
-        return self.polars_fit_from_lazy(lf, schema)
+        lazy_frame = self.polars_scan(file_paths, file_type)
+        schema = lazy_frame.collect_schema()
+        return self.fit_from_lazy(lazy_frame, schema)

     def fit_from_path(self, path: str) -> "DataProcessor":
         logger = logging.getLogger()
@@ -742,7 +758,6 @@
         persist: bool,
         save_format: Optional[str],
         output_path: Optional[str],
-        warn_missing: bool = True,
     ):
         logger = logging.getLogger()

@@ -754,16 +769,16 @@
            df = data

         schema = df.schema
-        lf = df.lazy()
-        lf = self.apply_transforms(lf, schema, warn_missing=warn_missing)
-        out_df = lf.collect()
+        lazy_frame = df.lazy()
+        lazy_frame = self.apply_transforms(lazy_frame, schema)
+        out_df = lazy_frame.collect()

         effective_format = save_format
         if persist:
             effective_format = save_format or "parquet"

         if persist:
-            if effective_format not in FILE_FORMAT_CONFIG:
+            if effective_format not in {"csv", "parquet"}:
                 raise ValueError(f"Unsupported save format: {effective_format}")
             if output_path is None:
                 raise ValueError(
@@ -773,14 +788,12 @@
             if output_dir.suffix:
                 output_dir = output_dir.parent
             output_dir.mkdir(parents=True, exist_ok=True)
-            suffix = FILE_FORMAT_CONFIG[effective_format]["extension"][0]
+            suffix = ".csv" if effective_format == "csv" else ".parquet"
             save_path = output_dir / f"transformed_data{suffix}"
             if effective_format == "csv":
                 out_df.write_csv(save_path)
             elif effective_format == "parquet":
                 out_df.write_parquet(save_path)
-            elif effective_format == "feather":
-                out_df.write_ipc(save_path)
             else:
                 raise ValueError(
                     f"Format '{effective_format}' is not supported by the polars-only pipeline."
@@ -814,27 +827,28 @@
         logger = logging.getLogger()
         file_paths, file_type = resolve_file_paths(input_path)
         target_format = save_format or file_type
-        if target_format not in FILE_FORMAT_CONFIG:
-            raise ValueError(f"Unsupported format: {target_format}")
-        if target_format not in {"csv", "parquet", "feather"}:
+        if target_format not in {"csv", "parquet"}:
             raise ValueError(
                 f"Format '{target_format}' is not supported by the polars-only pipeline."
             )
-        if not check_streaming_support(file_type):
+        if file_type not in {"csv", "parquet"}:
             raise ValueError(
                 f"Input format '{file_type}' does not support streaming reads. "
                 "Polars backend supports csv/parquet only."
             )

-        if not check_streaming_support(target_format):
-            logger.warning(
-                f"[Data Processor Warning] Format '{target_format}' does not support streaming writes. "
-                "Data will be collected in memory before saving."
-            )
-
-        base_output_dir = (
-            Path(output_path) if output_path else default_output_dir(input_path)
-        )
+        if output_path:
+            base_output_dir = Path(output_path)
+        else:
+            input_path_obj = Path(input_path)
+            if input_path_obj.is_file():
+                base_output_dir = (
+                    input_path_obj.parent / f"{input_path_obj.stem}_preprocessed"
+                )
+            else:
+                base_output_dir = input_path_obj.with_name(
+                    f"{input_path_obj.name}_preprocessed"
+                )
         if base_output_dir.suffix:
             base_output_dir = base_output_dir.parent
         output_root = base_output_dir / "transformed_data"
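Note: with `default_output_dir` removed, the default output location is now derived inline from the input path. A small sketch of the resulting behaviour (paths are illustrative):

```python
from pathlib import Path


def derive_output_dir(input_path: str, output_path: str | None = None) -> Path:
    # Mirrors the inlined default: an explicit output_path wins; otherwise a
    # "<name>_preprocessed" sibling of the input file or directory is used.
    if output_path:
        return Path(output_path)
    p = Path(input_path)
    if p.is_file():
        return p.parent / f"{p.stem}_preprocessed"
    return p.with_name(f"{p.name}_preprocessed")


# e.g. an existing file data/train.csv   -> data/train_preprocessed
#      a directory     data/interactions -> data/interactions_preprocessed
print(derive_output_dir("data/interactions"))
```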
@@ -843,18 +857,18 @@

         for file_path in progress(file_paths, description="Transforming files"):
             source_path = Path(file_path)
-            suffix = FILE_FORMAT_CONFIG[target_format]["extension"][0]
+            suffix = ".csv" if target_format == "csv" else ".parquet"
             target_file = output_root / f"{source_path.stem}{suffix}"

-            lf = self.polars_scan([file_path], file_type)
-            schema = lf.collect_schema()
-            lf = self.apply_transforms(lf, schema, warn_missing=True)
+            lazy_frame = self.polars_scan([file_path], file_type)
+            schema = lazy_frame.collect_schema()
+            lazy_frame = self.apply_transforms(lazy_frame, schema)

             if target_format == "parquet":
-                lf.sink_parquet(target_file)
+                lazy_frame.sink_parquet(target_file)
             elif target_format == "csv":
                 # CSV doesn't support nested data (lists), so convert list columns to string
-                transformed_schema = lf.collect_schema()
+                transformed_schema = lazy_frame.collect_schema()
                 list_cols = [
                     name
                     for name, dtype in transformed_schema.items()
@@ -875,11 +889,12 @@
                            + pl.lit("]")
                        ).alias(name)
                    )
-                lf = lf.with_columns(list_exprs)
-                lf.sink_csv(target_file)
+                lazy_frame = lazy_frame.with_columns(list_exprs)
+                lazy_frame.sink_csv(target_file)
             else:
-                df = lf.collect()
-                df.write_ipc(target_file)
+                raise ValueError(
+                    f"Format '{target_format}' is not supported by the polars-only pipeline."
+                )
             saved_paths.append(str(target_file.resolve()))

         logger.info(
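Note: CSV cannot store nested list columns, so the CSV branch stringifies them before `sink_csv`. The hunk only shows the tail of that expression; the sketch below reproduces the same idea under the assumption that elements are cast to strings and joined with ", " (the exact separator in the package is not visible here):

```python
import polars as pl

lf = pl.LazyFrame({"hist": [[1, 2, 3], [7]]})
# Render each list column as a "[...]"-style string so it survives a CSV round-trip.
stringified = (
    pl.lit("[")
    + pl.col("hist").cast(pl.List(pl.Utf8)).list.join(", ")
    + pl.lit("]")
).alias("hist")
print(lf.with_columns(stringified).collect())  # "hist" becomes "[1, 2, 3]" and "[7]"
```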
@@ -917,9 +932,9 @@
             df = pl.from_pandas(data)
         else:
             df = data
-        lf = df.lazy()
+        lazy_frame = df.lazy()
         schema = df.schema
-        return self.polars_fit_from_lazy(lf, schema)
+        return self.fit_from_lazy(lazy_frame, schema)

     @overload
     def transform(
@@ -985,6 +1000,33 @@
             output_path=output_path,
         )

+    @overload
+    def fit_transform(
+        self,
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
+        return_dict: Literal[True] = True,
+        save_format: Optional[str] = None,
+        output_path: Optional[str] = None,
+    ) -> Dict[str, np.ndarray]: ...
+
+    @overload
+    def fit_transform(
+        self,
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
+        return_dict: Literal[False] = False,
+        save_format: Optional[str] = None,
+        output_path: Optional[str] = None,
+    ) -> pl.DataFrame: ...
+
+    @overload
+    def fit_transform(
+        self,
+        data: str | os.PathLike,
+        return_dict: Literal[False] = False,
+        save_format: Optional[str] = None,
+        output_path: Optional[str] = None,
+    ) -> list[str]: ...
+
     def fit_transform(
         self,
         data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
@@ -1005,9 +1047,16 @@
         """

         self.fit(data)
-        return self.transform(
-            data,
+        if isinstance(data, (str, os.PathLike)):
+            if return_dict:
+                raise ValueError(
+                    "[Data Processor Error] Path transform writes files only; set return_dict=False when passing a path."
+                )
+            return self.transform_path(str(data), output_path, save_format)
+        return self.transform_in_memory(
+            data=data,
             return_dict=return_dict,
+            persist=output_path is not None,
             save_format=save_format,
             output_path=output_path,
         )
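Note: together with the new overloads, the dispatch above makes the return type of `fit_transform` depend on the input: in-memory data goes through `transform_in_memory`, while a path goes through `transform_path` and returns the written file paths. A hedged usage sketch; the import path and feature setup are not shown in this diff and are assumed here:

```python
import polars as pl

# assumption: DataProcessor imported from its nextrec module (exact path not shown in this diff)
processor = DataProcessor()

# In-memory input -> transform_in_memory; return_dict picks dict-of-arrays vs pl.DataFrame.
frame_out = processor.fit_transform(pl.DataFrame({"label": [0, 1]}), return_dict=False)

# Path input -> transform_path; transformed files are written and their paths returned.
# Passing return_dict=True with a path raises, per the new guard above.
saved_paths = processor.fit_transform("data/train.csv", return_dict=False, output_path="out/")
```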
nextrec/loss/__init__.py CHANGED
@@ -1,36 +0,0 @@
-from nextrec.loss.listwise import (
-    ApproxNDCGLoss,
-    InfoNCELoss,
-    ListMLELoss,
-    ListNetLoss,
-    SampledSoftmaxLoss,
-)
-from nextrec.loss.grad_norm import GradNormLossWeighting
-from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
-from nextrec.loss.pointwise import (
-    ClassBalancedFocalLoss,
-    CosineContrastiveLoss,
-    FocalLoss,
-    WeightedBCELoss,
-)
-
-__all__ = [
-    # Pointwise
-    "CosineContrastiveLoss",
-    "WeightedBCELoss",
-    "FocalLoss",
-    "ClassBalancedFocalLoss",
-    # Pairwise
-    "BPRLoss",
-    "HingeLoss",
-    "TripletLoss",
-    # Listwise
-    "SampledSoftmaxLoss",
-    "InfoNCELoss",
-    "ListNetLoss",
-    "ListMLELoss",
-    "ApproxNDCGLoss",
-    # Multi-task weighting
-    "GradNormLossWeighting",
-    # Utilities
-]
@@ -1,9 +0,0 @@
-"""
-Generative Recommendation Models
-
-This module contains generative models for recommendation tasks.
-"""
-
-from nextrec.models.sequential.hstu import HSTU
-
-__all__ = ["HSTU"]
@@ -1,15 +0,0 @@
-"""
-Tree-based models for NextRec.
-"""
-
-from nextrec.models.tree_base.base import TreeBaseModel
-from nextrec.models.tree_base.catboost import Catboost
-from nextrec.models.tree_base.lightgbm import Lightgbm
-from nextrec.models.tree_base.xgboost import Xgboost
-
-__all__ = [
-    "TreeBaseModel",
-    "Xgboost",
-    "Lightgbm",
-    "Catboost",
-]