nextrec-0.5.0-py3-none-any.whl → nextrec-0.5.2-py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +288 -181
- nextrec/basic/summary.py +21 -4
- nextrec/cli.py +36 -17
- nextrec/data/__init__.py +0 -52
- nextrec/data/batch_utils.py +1 -1
- nextrec/data/data_processing.py +1 -35
- nextrec/data/data_utils.py +0 -4
- nextrec/data/dataloader.py +125 -103
- nextrec/data/preprocessor.py +141 -92
- nextrec/loss/__init__.py +0 -36
- nextrec/models/generative/__init__.py +0 -9
- nextrec/models/tree_base/__init__.py +0 -15
- nextrec/models/tree_base/base.py +14 -23
- nextrec/utils/__init__.py +0 -119
- nextrec/utils/data.py +39 -119
- nextrec/utils/model.py +5 -14
- nextrec/utils/torch_utils.py +6 -1
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/METADATA +4 -5
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/RECORD +23 -23
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/WHEEL +0 -0
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/entry_points.txt +0 -0
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
CHANGED
@@ -2,7 +2,7 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
 Date: create on 13/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 29/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -29,12 +29,8 @@ from nextrec.__version__ import __version__
 from nextrec.basic.features import FeatureSet
 from nextrec.basic.loggers import colorize
 from nextrec.basic.session import get_save_path
-from nextrec.data.data_processing import hash_md5_mod
 from nextrec.utils.console import progress
 from nextrec.utils.data import (
-    FILE_FORMAT_CONFIG,
-    check_streaming_support,
-    default_output_dir,
     resolve_file_paths,
 )
 
@@ -44,32 +40,53 @@ class DataProcessor(FeatureSet):
         self,
         hash_cache_size: int = 200_000,
     ):
+        """
+        DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
+
+        Args:
+            hash_cache_size (int, optional): Cache size for string hashing. Defaults to 200,000.
+        """
         if not logging.getLogger().hasHandlers():
             logging.basicConfig(
                 level=logging.INFO,
                 format="%(message)s",
             )
-        self.numeric_features
-        self.sparse_features
-        self.sequence_features
-        self.target_features
+        self.numeric_features = {}
+        self.sparse_features = {}
+        self.sequence_features = {}
+        self.target_features = {}
         self.version = __version__
 
         self.is_fitted = False
 
-        self.scalers
-        self.label_encoders
-        self.target_encoders
+        self.scalers = {}
+        self.label_encoders = {}
+        self.target_encoders = {}
         self.set_target_id(target=[], id_columns=[])
 
         # cache hash function
         self.hash_cache_size = int(hash_cache_size)
         if self.hash_cache_size > 0:
             self.hash_fn = functools.lru_cache(maxsize=self.hash_cache_size)(
-
+                self.hash_string
+            )
+        else:
+            self.hash_fn = self.hash_string
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # lru_cache wrappers on instance fields are not picklable under spawn
+        state.pop("hash_fn", None)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if self.hash_cache_size > 0:
+            self.hash_fn = functools.lru_cache(maxsize=self.hash_cache_size)(
+                self.hash_string
             )
         else:
-            self.hash_fn =
+            self.hash_fn = self.hash_string
 
     def add_numeric_feature(
         self,
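The new `__getstate__`/`__setstate__` pair addresses a concrete failure: a `functools.lru_cache` wrapper stored on an instance attribute is a closure that `pickle` cannot serialize, so a fitted processor could not be shipped to worker processes under the spawn start method. A minimal, self-contained sketch of the same pattern (the class and names here are illustrative, not nextrec's actual API):

    import functools
    import pickle

    class Hasher:
        def __init__(self, cache_size: int = 1024):
            self.cache_size = cache_size
            self._build_cache()

        def _build_cache(self):
            # lru_cache returns a closure; pickle cannot serialize it
            self.hash_fn = functools.lru_cache(maxsize=self.cache_size)(self.hash_string)

        @staticmethod
        def hash_string(value: str, hash_size: int) -> int:
            return hash(value) % hash_size

        def __getstate__(self):
            state = self.__dict__.copy()
            state.pop("hash_fn", None)  # drop the unpicklable wrapper
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self._build_cache()  # rebuild the cache in the new process

    h = pickle.loads(pickle.dumps(Hasher()))
    assert h.hash_fn("abc", 100) == Hasher.hash_string("abc", 100)

Dropping the wrapper and rebuilding it on unpickle also keeps the cache per-process, which is the only coherent choice since cached entries cannot be shared across workers anyway.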
@@ -178,8 +195,10 @@ class DataProcessor(FeatureSet):
         }
         self.set_target_id(list(self.target_features.keys()), [])
 
-
-
+    @staticmethod
+    def hash_string(value: str, hash_size: int) -> int:
+        hashed = pl.Series([value], dtype=pl.Utf8).hash().cast(pl.UInt64)
+        return int(hashed[0]) % int(hash_size)
 
     def polars_scan(self, file_paths: list[str], file_type: str):
         file_type = file_type.lower()
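The new `hash_string` helper replaces the removed `hash_md5_mod` import with Polars' native hashing. One caveat from the Polars documentation: hash values are not guaranteed stable across Polars versions, so fitting and transforming should happen under the same environment. A quick sketch of the bucketing arithmetic (values are illustrative):

    import polars as pl

    hashed = pl.Series(["apple"], dtype=pl.Utf8).hash().cast(pl.UInt64)
    bucket = int(hashed[0]) % 1000  # same modulo step as hash_string above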
@@ -191,9 +210,7 @@ class DataProcessor(FeatureSet):
             f"[Data Processor Error] Polars backend only supports csv/parquet, got: {file_type}"
         )
 
-    def sequence_expr(
-        self, pl, name: str, config: Dict[str, Any], schema: Dict[str, Any]
-    ):
+    def sequence_expr(self, name: str, config: Dict[str, Any], schema: Dict[str, Any]):
         """
         generate polars expression for sequence feature processing
 
@@ -222,7 +239,7 @@ class DataProcessor(FeatureSet):
         ).list.drop_nulls()
         return seq_col
 
-    def apply_transforms(self,
+    def apply_transforms(self, lazy_frame, schema: Dict[str, Any]):
         """
         Apply all transformations to a Polars LazyFrame.
 
@@ -237,20 +254,16 @@ class DataProcessor(FeatureSet):
                 return_dtype=dtype,
             )
 
-        def ensure_present(feature_name: str, label: str) -> bool:
-            if feature_name not in schema:
-                if warn_missing:
-                    logger.warning(f"{label} feature {feature_name} not found in data")
-                return False
-            return True
-
         # Numeric features
         for name, config in self.numeric_features.items():
-            if not
+            if name not in schema:
+                logger.warning(f"Numeric feature {name} not found in data")
                 continue
             scaler_type = config["scaler"]
             fill_na_value = config.get("fill_na_value", 0)
             col = pl.col(name).cast(pl.Float64).fill_null(fill_na_value)
+
+            # Apply scaling
             if scaler_type == "log":
                 col = col.clip(lower_bound=0).log1p()
             elif scaler_type == "none":
@@ -285,7 +298,8 @@ class DataProcessor(FeatureSet):
 
         # Sparse features
         for name, config in self.sparse_features.items():
-            if not
+            if name not in schema:
+                logger.warning(f"Sparse feature {name} not found in data")
                 continue
             encode_method = config["encode_method"]
             fill_na = config["fill_na"]
@@ -307,7 +321,7 @@ class DataProcessor(FeatureSet):
                 low_freq = [k for k, v in token_counts.items() if v < min_freq]
                 unk_hash = config.get("_unk_hash")
                 if unk_hash is None:
-                    unk_hash = self.
+                    unk_hash = self.hash_fn("<UNK>", int(hash_size))
                 hash_expr = (
                     pl.when(col.is_in(low_freq))
                     .then(int(unk_hash))
@@ -318,13 +332,14 @@ class DataProcessor(FeatureSet):
 
         # Sequence features
         for name, config in self.sequence_features.items():
-            if not
+            if name not in schema:
+                logger.warning(f"Sequence feature {name} not found in data")
                 continue
             encode_method = config["encode_method"]
             max_len = int(config["max_len"])
             pad_value = int(config["pad_value"])
             truncate = config["truncate"]
-            seq_col = self.sequence_expr(
+            seq_col = self.sequence_expr(name, config, schema)
 
             if encode_method == "label":
                 token_to_idx = config.get("_token_to_idx")
@@ -350,7 +365,7 @@ class DataProcessor(FeatureSet):
                 low_freq = [k for k, v in token_counts.items() if v < min_freq]
                 unk_hash = config.get("_unk_hash")
                 if unk_hash is None:
-                    unk_hash = self.
+                    unk_hash = self.hash_fn("<UNK>", int(hash_size))
                 hash_expr = (
                     pl.when(elem.is_in(low_freq))
                     .then(int(unk_hash))
@@ -368,7 +383,8 @@ class DataProcessor(FeatureSet):
 
         # Target features
         for name, config in self.target_features.items():
-            if not
+            if name not in schema:
+                logger.warning(f"Target feature {name} not found in data")
                 continue
             target_type = config.get("target_type")
             col = pl.col(name)
@@ -390,8 +406,8 @@ class DataProcessor(FeatureSet):
             expressions.append(col.alias(name))
 
         if not expressions:
-            return
-        return
+            return lazy_frame
+        return lazy_frame.with_columns(expressions)
 
     def process_target_fit(
         self, data: Iterable[Any], config: Dict[str, Any], name: str
@@ -401,22 +417,18 @@ class DataProcessor(FeatureSet):
         if target_type == "binary":
             if label_map is None:
                 unique_values = {v for v in data if v is not None}
-                # Filter out None values before sorting to avoid comparison errors
                 sorted_values = sorted(v for v in unique_values if v is not None)
-
-
-
-
-
-                label_map = {
-                    str(val): idx for idx, val in enumerate(sorted_values)
-                }
-            except (ValueError, TypeError):
+
+                int_values = [int(v) for v in sorted_values]
+                if int_values == list(range(len(int_values))):
+                    label_map = {str(val): int(val) for val in sorted_values}
+                else:
                     label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+
             config["label_map"] = label_map
             self.target_encoders[name] = label_map
 
-    def
+    def fit_from_lazy(self, lazy_frame, schema: Dict[str, Any]) -> "DataProcessor":
         logger = logging.getLogger()
 
         missing_features = set()
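The rewritten binary-target branch changes the mapping rule: when the sorted unique labels already form the sequence 0..N-1, they keep their own codes (so 0/1 targets stay 0/1); otherwise they are re-enumerated by sort order. Note that no try/except survives in the new hunk, so a non-integer label would now raise at `int(v)` unless handled upstream. A standalone sketch of the rule (hypothetical helper, not the actual method):

    def build_label_map(sorted_values):
        int_values = [int(v) for v in sorted_values]
        if int_values == list(range(len(int_values))):
            return {str(val): int(val) for val in sorted_values}  # preserved
        return {str(val): idx for idx, val in enumerate(sorted_values)}  # re-indexed

    assert build_label_map([0, 1]) == {"0": 0, "1": 1}
    assert build_label_map([1, 2]) == {"1": 0, "2": 1}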
@@ -462,7 +474,11 @@ class DataProcessor(FeatureSet):
                     col.median().alias(f"{name}__median"),
                 ]
             )
-            stats =
+            stats = (
+                lazy_frame.select(agg_exprs).collect().to_dicts()[0]
+                if agg_exprs
+                else {}
+            )
         else:
             stats = {}
 
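`fit_from_lazy` now computes every numeric statistic in one `select(...).collect()` pass over the LazyFrame instead of materializing data per feature. A self-contained sketch of the pattern (column name and aggregates are illustrative):

    import polars as pl

    lf = pl.LazyFrame({"price": [1.0, 2.0, None, 4.0]})
    agg_exprs = [
        pl.col("price").fill_null(0).mean().alias("price__mean"),
        pl.col("price").fill_null(0).std().alias("price__std"),
    ]
    # One collect evaluates all aggregates; to_dicts()[0] flattens the 1-row result.
    stats = lf.select(agg_exprs).collect().to_dicts()[0] if agg_exprs else {}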
@@ -538,7 +554,7 @@ class DataProcessor(FeatureSet):
             fill_na = config["fill_na"]
             col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
             counts_df = (
-
+                lazy_frame.select(col.alias(name))
                 .group_by(name)
                 .agg(pl.len().alias("count"))
                 .collect()
@@ -585,7 +601,7 @@ class DataProcessor(FeatureSet):
             min_freq = config.get("min_freq")
             if min_freq is not None:
                 config["_token_counts"] = counts
-                config["_unk_hash"] = self.
+                config["_unk_hash"] = self.hash_fn(
                     "<UNK>", int(config["hash_size"])
                 )
             low_freq_types = sum(
@@ -608,9 +624,9 @@ class DataProcessor(FeatureSet):
             if name not in schema:
                 continue
             encode_method = config["encode_method"]
-            seq_col = self.sequence_expr(
+            seq_col = self.sequence_expr(name, config, schema)
             tokens_df = (
-
+                lazy_frame.select(seq_col.alias("seq"))
                 .explode("seq")
                 .select(pl.col("seq").cast(pl.Utf8).alias("seq"))
                 .drop_nulls("seq")
@@ -661,7 +677,7 @@ class DataProcessor(FeatureSet):
             min_freq = config.get("min_freq")
             if min_freq is not None:
                 config["_token_counts"] = counts
-                config["_unk_hash"] = self.
+                config["_unk_hash"] = self.hash_fn(
                     "<UNK>", int(config["hash_size"])
                 )
             low_freq_types = sum(
@@ -685,7 +701,7 @@ class DataProcessor(FeatureSet):
                 continue
             if config.get("target_type") == "binary":
                 unique_vals = (
-
+                    lazy_frame.select(pl.col(name).drop_nulls().unique())
                     .collect()
                     .to_series()
                     .to_list()
@@ -715,9 +731,9 @@ class DataProcessor(FeatureSet):
             config.pop("_min_freq_logged", None)
         for config in self.sequence_features.values():
             config.pop("_min_freq_logged", None)
-
-        schema =
-        return self.
+        lazy_frame = self.polars_scan(file_paths, file_type)
+        schema = lazy_frame.collect_schema()
+        return self.fit_from_lazy(lazy_frame, schema)
 
     def fit_from_path(self, path: str) -> "DataProcessor":
         logger = logging.getLogger()
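The file-based fit path now builds a single LazyFrame via `polars_scan`, reads its schema with `collect_schema()` (which resolves names and dtypes without materializing the data), and delegates to the new `fit_from_lazy`. A sketch of that flow, assuming a plain CSV input (the path is a placeholder):

    import polars as pl

    lf = pl.scan_csv("train.csv")    # placeholder path
    schema = lf.collect_schema()     # column names and dtypes, no full read
    print(schema.names())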
@@ -742,7 +758,6 @@ class DataProcessor(FeatureSet):
         persist: bool,
         save_format: Optional[str],
         output_path: Optional[str],
-        warn_missing: bool = True,
     ):
         logger = logging.getLogger()
 
@@ -754,16 +769,16 @@ class DataProcessor(FeatureSet):
             df = data
 
         schema = df.schema
-
-
-        out_df =
+        lazy_frame = df.lazy()
+        lazy_frame = self.apply_transforms(lazy_frame, schema)
+        out_df = lazy_frame.collect()
 
         effective_format = save_format
         if persist:
             effective_format = save_format or "parquet"
 
         if persist:
-            if effective_format not in
+            if effective_format not in {"csv", "parquet"}:
                 raise ValueError(f"Unsupported save format: {effective_format}")
             if output_path is None:
                 raise ValueError(
@@ -773,14 +788,12 @@ class DataProcessor(FeatureSet):
             if output_dir.suffix:
                 output_dir = output_dir.parent
             output_dir.mkdir(parents=True, exist_ok=True)
-            suffix =
+            suffix = ".csv" if effective_format == "csv" else ".parquet"
             save_path = output_dir / f"transformed_data{suffix}"
             if effective_format == "csv":
                 out_df.write_csv(save_path)
             elif effective_format == "parquet":
                 out_df.write_parquet(save_path)
-            elif effective_format == "feather":
-                out_df.write_ipc(save_path)
             else:
                 raise ValueError(
                     f"Format '{effective_format}' is not supported by the polars-only pipeline."
@@ -814,27 +827,28 @@ class DataProcessor(FeatureSet):
         logger = logging.getLogger()
         file_paths, file_type = resolve_file_paths(input_path)
         target_format = save_format or file_type
-        if target_format not in
-            raise ValueError(f"Unsupported format: {target_format}")
-        if target_format not in {"csv", "parquet", "feather"}:
+        if target_format not in {"csv", "parquet"}:
             raise ValueError(
                 f"Format '{target_format}' is not supported by the polars-only pipeline."
             )
-        if not
+        if file_type not in {"csv", "parquet"}:
             raise ValueError(
                 f"Input format '{file_type}' does not support streaming reads. "
                 "Polars backend supports csv/parquet only."
             )
 
-        if
-
-
-
-        )
-
-
-
-
+        if output_path:
+            base_output_dir = Path(output_path)
+        else:
+            input_path_obj = Path(input_path)
+            if input_path_obj.is_file():
+                base_output_dir = (
+                    input_path_obj.parent / f"{input_path_obj.stem}_preprocessed"
+                )
+            else:
+                base_output_dir = input_path_obj.with_name(
+                    f"{input_path_obj.name}_preprocessed"
+                )
         if base_output_dir.suffix:
             base_output_dir = base_output_dir.parent
         output_root = base_output_dir / "transformed_data"
@@ -843,18 +857,18 @@ class DataProcessor(FeatureSet):
 
         for file_path in progress(file_paths, description="Transforming files"):
             source_path = Path(file_path)
-            suffix =
+            suffix = ".csv" if target_format == "csv" else ".parquet"
             target_file = output_root / f"{source_path.stem}{suffix}"
 
-
-            schema =
-
+            lazy_frame = self.polars_scan([file_path], file_type)
+            schema = lazy_frame.collect_schema()
+            lazy_frame = self.apply_transforms(lazy_frame, schema)
 
             if target_format == "parquet":
-
+                lazy_frame.sink_parquet(target_file)
             elif target_format == "csv":
                 # CSV doesn't support nested data (lists), so convert list columns to string
-                transformed_schema =
+                transformed_schema = lazy_frame.collect_schema()
                 list_cols = [
                     name
                     for name, dtype in transformed_schema.items()
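Per-file transformation now stays lazy end to end: scan, apply the expression set, then `sink_parquet`/`sink_csv`, so large files stream to disk instead of being collected into memory first. A minimal sketch of the pipeline shape (paths and the expression are placeholders):

    import polars as pl

    lf = pl.scan_parquet("part-0.parquet")
    lf = lf.with_columns(pl.col("price").cast(pl.Float64).fill_null(0.0).log1p())
    lf.sink_parquet("part-0.transformed.parquet")  # streaming write, no collect()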
@@ -875,11 +889,12 @@ class DataProcessor(FeatureSet):
                         + pl.lit("]")
                     ).alias(name)
                 )
-
-
+                lazy_frame = lazy_frame.with_columns(list_exprs)
+                lazy_frame.sink_csv(target_file)
             else:
-
-
+                raise ValueError(
+                    f"Format '{target_format}' is not supported by the polars-only pipeline."
+                )
             saved_paths.append(str(target_file.resolve()))
 
             logger.info(
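The CSV branch stringifies list columns first because CSV has no nested types; the expressions above join each list's elements and wrap them in brackets. A sketch of that conversion on hypothetical data:

    import polars as pl

    lf = pl.LazyFrame({"seq": [[1, 2], [3]]})
    lf = lf.with_columns(
        (
            pl.lit("[")
            + pl.col("seq").cast(pl.List(pl.Utf8)).list.join(",")
            + pl.lit("]")
        ).alias("seq")
    )
    lf.sink_csv("out.csv")  # rows serialize as "[1,2]" and "[3]"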
@@ -917,9 +932,9 @@ class DataProcessor(FeatureSet):
             df = pl.from_pandas(data)
         else:
             df = data
-
+        lazy_frame = df.lazy()
         schema = df.schema
-        return self.
+        return self.fit_from_lazy(lazy_frame, schema)
 
     @overload
     def transform(
@@ -985,6 +1000,33 @@ class DataProcessor(FeatureSet):
             output_path=output_path,
         )
 
+    @overload
+    def fit_transform(
+        self,
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
+        return_dict: Literal[True] = True,
+        save_format: Optional[str] = None,
+        output_path: Optional[str] = None,
+    ) -> Dict[str, np.ndarray]: ...
+
+    @overload
+    def fit_transform(
+        self,
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
+        return_dict: Literal[False] = False,
+        save_format: Optional[str] = None,
+        output_path: Optional[str] = None,
+    ) -> pl.DataFrame: ...
+
+    @overload
+    def fit_transform(
+        self,
+        data: str | os.PathLike,
+        return_dict: Literal[False] = False,
+        save_format: Optional[str] = None,
+        output_path: Optional[str] = None,
+    ) -> list[str]: ...
+
     def fit_transform(
         self,
         data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
@@ -1005,9 +1047,16 @@ class DataProcessor(FeatureSet):
         """
 
        self.fit(data)
-
-
+        if isinstance(data, (str, os.PathLike)):
+            if return_dict:
+                raise ValueError(
+                    "[Data Processor Error] Path transform writes files only; set return_dict=False when passing a path."
+                )
+            return self.transform_path(str(data), output_path, save_format)
+        return self.transform_in_memory(
+            data=data,
             return_dict=return_dict,
+            persist=output_path is not None,
             save_format=save_format,
             output_path=output_path,
         )
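Taken together, the new overloads and the path branch make the return type of `fit_transform` track its input: a dict of arrays or a DataFrame for in-memory data, a list of written file paths for a path input, and `return_dict=True` combined with a path is now an explicit error. Hypothetical usage (the feature configuration and `df` are assumed):

    processor = DataProcessor()
    # ... add_numeric_feature / add_sparse_feature calls assumed here ...

    arrays = processor.fit_transform(df, return_dict=True)   # Dict[str, np.ndarray]
    frame = processor.fit_transform(df, return_dict=False)   # pl.DataFrame

    paths = processor.fit_transform("data/train.parquet", return_dict=False)
    # -> list[str]; passing return_dict=True with a path raises ValueError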
nextrec/loss/__init__.py
CHANGED
@@ -1,36 +0,0 @@
-from nextrec.loss.listwise import (
-    ApproxNDCGLoss,
-    InfoNCELoss,
-    ListMLELoss,
-    ListNetLoss,
-    SampledSoftmaxLoss,
-)
-from nextrec.loss.grad_norm import GradNormLossWeighting
-from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
-from nextrec.loss.pointwise import (
-    ClassBalancedFocalLoss,
-    CosineContrastiveLoss,
-    FocalLoss,
-    WeightedBCELoss,
-)
-
-__all__ = [
-    # Pointwise
-    "CosineContrastiveLoss",
-    "WeightedBCELoss",
-    "FocalLoss",
-    "ClassBalancedFocalLoss",
-    # Pairwise
-    "BPRLoss",
-    "HingeLoss",
-    "TripletLoss",
-    # Listwise
-    "SampledSoftmaxLoss",
-    "InfoNCELoss",
-    "ListNetLoss",
-    "ListMLELoss",
-    "ApproxNDCGLoss",
-    # Multi-task weighting
-    "GradNormLossWeighting",
-    # Utilities
-]
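With the re-exports deleted, `from nextrec.loss import BPRLoss`-style imports presumably stop working in 0.5.2. The module paths in the removed file suggest the direct form (unverified against the released wheel):

    from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
    from nextrec.loss.listwise import InfoNCELoss, SampledSoftmaxLoss
    from nextrec.loss.pointwise import FocalLoss, WeightedBCELoss
    from nextrec.loss.grad_norm import GradNormLossWeighting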
nextrec/models/tree_base/__init__.py
CHANGED
@@ -1,15 +0,0 @@
-"""
-Tree-based models for NextRec.
-"""
-
-from nextrec.models.tree_base.base import TreeBaseModel
-from nextrec.models.tree_base.catboost import Catboost
-from nextrec.models.tree_base.lightgbm import Lightgbm
-from nextrec.models.tree_base.xgboost import Xgboost
-
-__all__ = [
-    "TreeBaseModel",
-    "Xgboost",
-    "Lightgbm",
-    "Catboost",
-]
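The tree-model re-exports are removed the same way, so imports presumably target the defining modules directly (paths taken from the removed file, unverified against 0.5.2):

    from nextrec.models.tree_base.base import TreeBaseModel
    from nextrec.models.tree_base.xgboost import Xgboost
    from nextrec.models.tree_base.lightgbm import Lightgbm
    from nextrec.models.tree_base.catboost import Catboost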