replay-rec 0.18.0-py3-none-any.whl → 0.18.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +27 -1
- replay/data/dataset_utils/dataset_label_encoder.py +6 -3
- replay/data/nn/schema.py +37 -16
- replay/data/nn/sequence_tokenizer.py +313 -165
- replay/data/nn/torch_sequential_dataset.py +17 -8
- replay/data/nn/utils.py +14 -7
- replay/data/schema.py +10 -6
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +1 -0
- replay/models/base_rec.py +18 -21
- replay/models/lin_ucb.py +407 -0
- replay/models/nn/sequential/bert4rec/dataset.py +17 -4
- replay/models/nn/sequential/bert4rec/lightning.py +121 -54
- replay/models/nn/sequential/bert4rec/model.py +21 -0
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
- replay/models/nn/sequential/compiled/__init__.py +5 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
- replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
- replay/models/nn/sequential/sasrec/dataset.py +17 -1
- replay/models/nn/sequential/sasrec/lightning.py +126 -50
- replay/models/nn/sequential/sasrec/model.py +3 -4
- replay/preprocessing/__init__.py +7 -1
- replay/preprocessing/discretizer.py +719 -0
- replay/preprocessing/label_encoder.py +384 -52
- replay/splitters/cold_user_random_splitter.py +1 -1
- replay/utils/__init__.py +1 -0
- replay/utils/common.py +7 -8
- replay/utils/session_handler.py +3 -4
- replay/utils/spark_utils.py +15 -1
- replay/utils/types.py +8 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
replay/preprocessing/label_encoder.py
CHANGED

@@ -7,7 +7,11 @@ Contains classes for encoding categorical data
 """

 import abc
+import json
+import os
 import warnings
+from itertools import chain
+from pathlib import Path
 from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union

 import polars as pl
@@ -22,9 +26,8 @@ from replay.utils import (
 )

 if PYSPARK_AVAILABLE:
-    from pyspark.sql import functions as sf
-    from pyspark.sql.types import LongType
-    from pyspark.storagelevel import StorageLevel
+    from pyspark.sql import Window, functions as sf  # noqa: I001
+    from pyspark.sql.types import LongType

 HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]

@@ -33,6 +36,10 @@ class LabelEncoderTransformWarning:
     """Label encoder transform warning."""


+class LabelEncoderPartialFitWarning(Warning):
+    """Label encoder partial fit warning."""
+
+
 class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover
     """
     Interface of the label encoding rule
@@ -78,7 +85,7 @@ class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover

 class LabelEncodingRule(BaseLabelEncodingRule):
     """
-    Implementation of the encoding rule for categorical variables of PySpark and
+    Implementation of the encoding rule for categorical variables of PySpark, Pandas and Polars Data Frames.
     Encodes target labels with value between 0 and n_classes-1 for the given column.
     It is recommended to use together with the LabelEncoder.
     """
@@ -163,22 +170,19 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         return inverse_mapping_list

     def _fit_spark(self, df: SparkDataFrame) -> None:
-        unique_col_values = df.select(self._col).distinct()
+        unique_col_values = df.select(self._col).distinct()
+        window_function_give_ids = Window.orderBy(self._col)

         mapping_on_spark = (
-            unique_col_values.
-
-
-            .add("_1", StructType().add(self._col, df.schema[self._col].dataType, True), True)
-            .add("_2", LongType(), True)
+            unique_col_values.withColumn(
+                self._target_col,
+                sf.row_number().over(window_function_give_ids).cast(LongType()),
             )
-            .
-            .
+            .withColumn(self._target_col, sf.col(self._target_col) - 1)
+            .select(self._col, self._target_col)
         )

         self._mapping = mapping_on_spark.rdd.collectAsMap()
-        mapping_on_spark.unpersist()
-        unique_col_values.unpersist()

     def _fit_pandas(self, df: PandasDataFrame) -> None:
         unique_col_values = df[self._col].drop_duplicates().reset_index(drop=True)
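The rewritten `_fit_spark` computes the label mapping with a window function instead of the previous RDD-based construction: `row_number()` over `Window.orderBy(self._col)` assigns 1-based ranks to the distinct values, and subtracting 1 shifts them into the 0..n_classes-1 range. A minimal standalone sketch of the same technique (illustrative column names, not code from the package):

```python
from pyspark.sql import SparkSession, Window, functions as sf
from pyspark.sql.types import LongType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("u1",), ("u3",), ("u2",), ("u1",)], schema=["user_id"])

# An unpartitioned window pulls all rows onto one partition; acceptable here
# because only the distinct values of a single column pass through it.
window = Window.orderBy("user_id")
mapping_df = (
    df.select("user_id").distinct()
    .withColumn("user_id_encoded", sf.row_number().over(window).cast(LongType()))
    .withColumn("user_id_encoded", sf.col("user_id_encoded") - 1)  # 1-based -> 0-based
)
print(mapping_df.rdd.collectAsMap())  # {'u1': 0, 'u2': 1, 'u3': 2}
```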
@@ -222,34 +226,43 @@ class LabelEncodingRule(BaseLabelEncodingRule):

     def _partial_fit_spark(self, df: SparkDataFrame) -> None:
         assert self._mapping is not None
-
         max_value = sf.lit(max(self._mapping.values()) + 1)
         already_fitted = list(self._mapping.keys())
         new_values = {x[self._col] for x in df.select(self._col).distinct().collect()} - set(already_fitted)
         new_values_list = [[x] for x in new_values]
-
-
-
-
-
-
-
-
-
+        if len(new_values_list) == 0:
+            warnings.warn(
+                "partial_fit will have no effect because "
+                f"there are no new values in the incoming dataset at '{self.column}' column",
+                LabelEncoderPartialFitWarning,
+            )
+            return
+        new_unique_values_df: SparkDataFrame = get_spark_session().createDataFrame(new_values_list, schema=[self._col])
+        window_function_give_ids = Window.orderBy(self._col)
+        new_part_of_mapping = (
+            new_unique_values_df.withColumn(
+                self._target_col,
+                sf.row_number().over(window_function_give_ids).cast(LongType()),
            )
-            .
-            .
+            .withColumn(self._target_col, sf.col(self._target_col) - 1 + max_value)
+            .select(self._col, self._target_col)
             .rdd.collectAsMap()
         )
-        self._mapping.update(
-        self._inverse_mapping.update({v: k for k, v in
-        self._inverse_mapping_list.extend(
-        new_unique_values.unpersist()
+        self._mapping.update(new_part_of_mapping)
+        self._inverse_mapping.update({v: k for k, v in new_part_of_mapping.items()})
+        self._inverse_mapping_list.extend(new_part_of_mapping.keys())

     def _partial_fit_pandas(self, df: PandasDataFrame) -> None:
         assert self._mapping is not None

         new_unique_values = set(df[self._col].tolist()) - set(self._mapping)
+        if len(new_unique_values) == 0:
+            warnings.warn(
+                "partial_fit will have no effect because "
+                f"there are no new values in the incoming dataset at '{self.column}' column",
+                LabelEncoderPartialFitWarning,
+            )
+            return
         last_mapping_value = max(self._mapping.values())
         new_data: dict = {value: last_mapping_value + i for i, value in enumerate(new_unique_values, start=1)}
         self._mapping.update(new_data)
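All three `_partial_fit_*` backends now share the same guard: when the incoming dataframe contains no unseen labels, `partial_fit` warns with `LabelEncoderPartialFitWarning` and returns early instead of touching the mapping. A usage sketch on the pandas backend, assuming `LabelEncoder.partial_fit` delegates to its rules:

```python
import warnings

import pandas as pd

from replay.preprocessing import LabelEncoder, LabelEncodingRule

encoder = LabelEncoder([LabelEncodingRule("item_id")])
encoder.fit(pd.DataFrame({"item_id": ["a", "b"]}))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    encoder.partial_fit(pd.DataFrame({"item_id": ["a"]}))  # "a" is already mapped
print(caught[0].category.__name__)  # LabelEncoderPartialFitWarning
print(encoder.mapping["item_id"])   # unchanged: {'a': 0, 'b': 1}
```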
@@ -260,6 +273,13 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         assert self._mapping is not None

         new_unique_values = set(df.select(self._col).unique().to_series().to_list()) - set(self._mapping)
+        if len(new_unique_values) == 0:
+            warnings.warn(
+                "partial_fit will have no effect because "
+                f"there are no new values in the incoming dataset at '{self.column}' column",
+                LabelEncoderPartialFitWarning,
+            )
+            return
         new_data: dict = {value: max(self._mapping.values()) + i for i, value in enumerate(new_unique_values, start=1)}
         self._mapping.update(new_data)
         self._inverse_mapping.update({v: k for k, v in new_data.items()})
@@ -484,6 +504,272 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             raise ValueError(msg)
         self._handle_unknown = handle_unknown

+    def save(
+        self,
+        path: str,
+    ) -> None:
+        encoder_rule_dict = {}
+        encoder_rule_dict["_class_name"] = self.__class__.__name__
+        encoder_rule_dict["init_args"] = {
+            "column": self._col,
+            "mapping": self._mapping,
+            "handle_unknown": self._handle_unknown,
+            "default_value": self._default_value,
+        }
+
+        column_type = str(type(next(iter(self._mapping))))
+
+        if not isinstance(column_type, (str, int, float)):  # pragma: no cover
+            msg = f"LabelEncodingRule.save() is not implemented for column type {column_type}. \
+                Convert type to string, integer, or float."
+            raise NotImplementedError(msg)
+
+        encoder_rule_dict["fitted_args"] = {
+            "target_col": self._target_col,
+            "is_fitted": self._is_fitted,
+            "column_type": column_type,
+        }
+
+        base_path = Path(path).with_suffix(".replay").resolve()
+        if os.path.exists(base_path):  # pragma: no cover
+            msg = "There is already LabelEncodingRule object saved at the given path. File will be overwrited."
+            warnings.warn(msg)
+        else:  # pragma: no cover
+            base_path.mkdir(parents=True, exist_ok=True)
+
+        with open(base_path / "init_args.json", "w+") as file:
+            json.dump(encoder_rule_dict, file)
+
+    @classmethod
+    def load(cls, path: str) -> "LabelEncodingRule":
+        base_path = Path(path).with_suffix(".replay").resolve()
+        with open(base_path / "init_args.json", "r") as file:
+            encoder_rule_dict = json.loads(file.read())
+
+        string_column_type = encoder_rule_dict["fitted_args"]["column_type"]
+        if "str" in string_column_type:
+            column_type = str
+        elif "int" in string_column_type:
+            column_type = int
+        elif "float" in string_column_type:
+            column_type = float
+
+        encoder_rule_dict["init_args"]["mapping"] = {
+            column_type(key): int(value) for key, value in encoder_rule_dict["init_args"]["mapping"].items()
+        }
+
+        encoding_rule = cls(**encoder_rule_dict["init_args"])
+        encoding_rule._target_col = encoder_rule_dict["fitted_args"]["target_col"]
+        encoding_rule._is_fitted = encoder_rule_dict["fitted_args"]["is_fitted"]
+        return encoding_rule
+
+
+class SequenceEncodingRule(LabelEncodingRule):
+    """
+    Implementation of the encoding rule for grouped categorical variables of PySpark, Pandas and Polars Data Frames.
+    Grouped means that one cell of the table contains a list with categorical values.
+    Encodes target labels with value between 0 and n_classes-1 for the given column.
+    It is recommended to use together with the LabelEncoder.
+    """
+
+    _FAKE_INDEX_COLUMN_NAME: str = "__index__"
+
+    def fit(self, df: DataFrameLike) -> "SequenceEncodingRule":
+        """
+        Fits encoder to input dataframe.
+
+        :param df: input dataframe.
+        :returns: fitted EncodingRule.
+        """
+        if self._mapping is not None:
+            return self
+
+        if isinstance(df, PandasDataFrame):
+            self._fit_pandas(df[[self.column]].explode(self.column))
+        elif isinstance(df, SparkDataFrame):
+            self._fit_spark(df.select(self.column).withColumn(self.column, sf.explode(self.column)))
+        elif isinstance(df, PolarsDataFrame):
+            self._fit_polars(df.select(self.column).explode(self.column))
+        else:
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
+        self._inverse_mapping = self._make_inverse_mapping()
+        self._inverse_mapping_list = self._make_inverse_mapping_list()
+        if self._handle_unknown == "use_default_value" and self._default_value in self._inverse_mapping:
+            msg = (
+                "The used value for default_value "
+                f"{self._default_value} is one of the "
+                "values already used for encoding the "
+                "seen labels."
+            )
+            raise ValueError(msg)
+        self._is_fitted = True
+        return self
+
+    def partial_fit(self, df: DataFrameLike) -> "SequenceEncodingRule":
+        """
+        Fits new data to already fitted encoder.
+
+        :param df: input dataframe.
+        :returns: fitted EncodingRule.
+        """
+        if self._mapping is None:
+            return self.fit(df)
+        if isinstance(df, SparkDataFrame):
+            self._partial_fit_spark(df.select(self.column).withColumn(self.column, sf.explode(self.column)))
+        elif isinstance(df, PandasDataFrame):
+            self._partial_fit_pandas(df[[self.column]].explode(self.column))
+        elif isinstance(df, PolarsDataFrame):
+            self._partial_fit_polars(df.select(self.column).explode(self.column))
+        else:
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
+
+        self._is_fitted = True
+        return self
+
+    def _transform_spark(self, df: SparkDataFrame, default_value: Optional[int]) -> SparkDataFrame:
+        map_expr = sf.create_map([sf.lit(x) for x in chain(*self.get_mapping().items())])
+        encoded_df = df.withColumn(self._target_col, sf.transform(self.column, lambda x: map_expr.getItem(x)))
+
+        if self._handle_unknown == "drop":
+            encoded_df = encoded_df.withColumn(self._target_col, sf.filter(self._target_col, lambda x: x.isNotNull()))
+            if encoded_df.select(sf.max(sf.size(self._target_col))).first()[0] == 0:
+                warnings.warn(
+                    f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                    "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                    LabelEncoderTransformWarning,
+                )
+        elif self._handle_unknown == "error":
+            if (
+                encoded_df.select(sf.sum(sf.array_contains(self._target_col, -1).isNull().cast("integer"))).first()[0]
+                != 0
+            ):
+                msg = f"Found unknown labels in column {self._col} during transform"
+                raise ValueError(msg)
+        else:
+            if default_value:
+                encoded_df = encoded_df.withColumn(
+                    self._target_col,
+                    sf.transform(self._target_col, lambda x: sf.when(x.isNull(), default_value).otherwise(x)),
+                )
+
+        result_df = encoded_df.drop(self._col).withColumnRenamed(self._target_col, self._col)
+        return result_df
+
+    def _transform_pandas(self, df: PandasDataFrame, default_value: Optional[int]) -> PandasDataFrame:
+        mapping = self.get_mapping()
+        joined_df = df.copy()
+        if self._handle_unknown == "drop":
+            max_array_len = 0
+
+            def encode_func(array_col):
+                nonlocal mapping, max_array_len
+                res = []
+                for x in array_col:
+                    cur_len = 0
+                    mapped = mapping.get(x)
+                    if mapped is not None:
+                        res.append(mapped)
+                        cur_len += 1
+                    max_array_len = max(max_array_len, cur_len)
+                return res
+
+            joined_df[self._target_col] = joined_df[self._col].apply(encode_func)
+            if max_array_len == 0:
+                warnings.warn(
+                    f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                    "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                    LabelEncoderTransformWarning,
+                )
+        elif self._handle_unknown == "error":
+            none_count = 0
+
+            def encode_func(array_col):
+                nonlocal mapping, none_count
+                res = []
+                for x in array_col:
+                    mapped = mapping.get(x)
+                    if mapped is None:
+                        none_count += 1
+                    else:
+                        res.append(mapped)
+                return res
+
+            joined_df[self._target_col] = joined_df[self._col].apply(encode_func)
+            if none_count != 0:
+                msg = f"Found unknown labels in column {self._col} during transform"
+                raise ValueError(msg)
+        else:
+
+            def encode_func(array_col):
+                nonlocal mapping
+                return [mapping.get(x, default_value) for x in array_col]
+
+            joined_df[self._target_col] = joined_df[self._col].apply(encode_func)
+
+        result_df = joined_df.drop(self._col, axis=1).rename(columns={self._target_col: self._col})
+        return result_df
+
+    def _transform_polars(self, df: PolarsDataFrame, default_value: Optional[int]) -> SparkDataFrame:
+        transformed_df = df.with_columns(
+            pl.col(self._col)
+            .list.eval(
+                pl.element().replace_strict(
+                    self.get_mapping(), default=default_value if self._handle_unknown == "use_default_value" else None
+                ),
+                parallel=True,
+            )
+            .alias(self._target_col)
+        )
+        if self._handle_unknown == "drop":
+            transformed_df = transformed_df.with_columns(pl.col(self._target_col).list.drop_nulls())
+            if (
+                transformed_df.with_columns(pl.col(self._target_col).list.len()).select(pl.sum(self._target_col)).item()
+                == 0
+            ):
+                warnings.warn(
+                    f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                    "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                    LabelEncoderTransformWarning,
+                )
+        elif self._handle_unknown == "error":
+            none_checker = transformed_df.with_columns(
+                pl.col(self._target_col).list.contains(pl.lit(None, dtype=pl.Int64)).cast(pl.Int64)
+            )
+            if none_checker.select(pl.sum(self._target_col)).item() != 0:
+                msg = f"Found unknown labels in column {self._col} during transform"
+                raise ValueError(msg)
+
+        result_df = transformed_df.drop(self._col).rename({self._target_col: self._col})
+        return result_df
+
+    def _inverse_transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
+        decoded_df = df.copy()
+
+        def decode_func(array_col):
+            return [self._inverse_mapping_list[x] for x in array_col]
+
+        decoded_df[self._col] = decoded_df[self._col].apply(decode_func)
+        return decoded_df
+
+    def _inverse_transform_polars(self, df: PolarsDataFrame) -> PolarsDataFrame:
+        mapping_size = len(self._inverse_mapping_list)
+        transformed_df = df.with_columns(
+            pl.col(self._col).list.eval(
+                pl.element().replace_strict(old=list(range(mapping_size)), new=self._inverse_mapping_list),
+                parallel=True,
+            )
+        )
+        return transformed_df
+
+    def _inverse_transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
+        array_expr = sf.array([sf.lit(x) for x in self._inverse_mapping_list])
+        decoded_df = df.withColumn(
+            self._target_col, sf.transform(self._col, lambda x: sf.element_at(array_expr, x + 1))
+        )
+        return decoded_df.drop(self._col).withColumnRenamed(self._target_col, self._col)
+

 class LabelEncoder:
     """
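The block above adds two things: JSON persistence for a fitted rule (`save`/`load` under a `<path>.replay` directory, with the mapping's key type recorded so `load` can cast keys back from JSON strings), and `SequenceEncodingRule`, which fits by exploding the list column and transforms natively per backend (`create_map` + `sf.transform` on Spark, `apply` on pandas, `list.eval` + `replace_strict` on polars). A sketch of both; it assumes `SequenceEncodingRule` is re-exported from `replay.preprocessing` (the `__init__.py` change in the file list suggests so), and the `/tmp` path is illustrative:

```python
import pandas as pd

from replay.preprocessing import LabelEncodingRule, SequenceEncodingRule

# Sequence rule: one cell holds a list of categorical values.
df = pd.DataFrame({"history": [["a", "b"], ["b", "c"]]})
rule = SequenceEncodingRule("history").fit(df)
print(rule.transform(df)["history"].tolist())  # [[0, 1], [1, 2]]

# Rule-level persistence round trip.
scalar_rule = LabelEncodingRule("item_id").fit(pd.DataFrame({"item_id": ["x", "y"]}))
scalar_rule.save("/tmp/item_id_rule")  # writes /tmp/item_id_rule.replay/init_args.json
restored = LabelEncodingRule.load("/tmp/item_id_rule")
assert restored.get_mapping() == scalar_rule.get_mapping()
```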
@@ -491,42 +777,48 @@

     >>> import pandas as pd
     >>> user_interactions = pd.DataFrame([
-    ...
-    ...
-    ...
-    ... ], columns=["user_id", "item_1", "item_2"])
+    ...     ("u1", "item_1", "item_1", [1, 2, 3]),
+    ...     ("u2", "item_2", "item_2", [3, 4, 5]),
+    ...     ("u3", "item_3", "item_3", [-1, -2, 4]),
+    ... ], columns=["user_id", "item_1", "item_2", "list"])
     >>> user_interactions
-
-    0
-    1
-    2
-    >>> encoder = LabelEncoder(
-    ...
-    ...
+      user_id  item_1  item_2         list
+    0      u1  item_1  item_1    [1, 2, 3]
+    1      u2  item_2  item_2    [3, 4, 5]
+    2      u3  item_3  item_3  [-1, -2, 4]
+    >>> encoder = LabelEncoder([
+    ...     LabelEncodingRule("user_id"),
+    ...     LabelEncodingRule("item_1"),
+    ...     LabelEncodingRule("item_2"),
+    ...     SequenceEncodingRule("list"),
+    ... ])
     >>> mapped_interactions = encoder.fit_transform(user_interactions)
     >>> mapped_interactions
-       user_id  item_1  item_2
-    0        0       0       0
-    1        1       1       1
-    2        2       2       2
+       user_id  item_1  item_2       list
+    0        0       0       0  [0, 1, 2]
+    1        1       1       1  [2, 3, 4]
+    2        2       2       2  [5, 6, 3]
     >>> encoder.mapping
     {'user_id': {'u1': 0, 'u2': 1, 'u3': 2},
     'item_1': {'item_1': 0, 'item_2': 1, 'item_3': 2},
-    'item_2': {'item_1': 0, 'item_2': 1, 'item_3': 2}
+    'item_2': {'item_1': 0, 'item_2': 1, 'item_3': 2},
+    'list': {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, -1: 5, -2: 6}}
     >>> encoder.inverse_mapping
     {'user_id': {0: 'u1', 1: 'u2', 2: 'u3'},
     'item_1': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
-    'item_2': {0: 'item_1', 1: 'item_2', 2: 'item_3'}
+    'item_2': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
+    'list': {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: -1, 6: -2}}
     >>> new_encoder = LabelEncoder([
     ...     LabelEncodingRule("user_id", encoder.mapping["user_id"]),
     ...     LabelEncodingRule("item_1", encoder.mapping["item_1"]),
-    ...     LabelEncodingRule("item_2", encoder.mapping["item_2"])
+    ...     LabelEncodingRule("item_2", encoder.mapping["item_2"]),
+    ...     SequenceEncodingRule("list", encoder.mapping["list"]),
     ... ])
     >>> new_encoder.inverse_transform(mapped_interactions)
-      user_id
-    0      u1
-    1      u2
-    2      u3
+      user_id  item_1  item_2         list
+    0      u1  item_1  item_1    [1, 2, 3]
+    1      u2  item_2  item_2    [3, 4, 5]
+    2      u3  item_3  item_3  [-1, -2, 4]
     <BLANKLINE>
     """
@@ -650,3 +942,43 @@ class LabelEncoder:
             raise ValueError(msg)
         rule = list(filter(lambda x: x.column == column, self.rules))
         rule[0].set_default_value(default_value)
+
+    def save(
+        self,
+        path: str,
+    ) -> None:
+        encoder_dict = {}
+        encoder_dict["_class_name"] = self.__class__.__name__
+
+        base_path = Path(path).with_suffix(".replay").resolve()
+        if os.path.exists(base_path):  # pragma: no cover
+            msg = "There is already LabelEncoder object saved at the given path. File will be overwrited."
+            warnings.warn(msg)
+        else:  # pragma: no cover
+            base_path.mkdir(parents=True, exist_ok=True)
+
+        encoder_dict["rule_names"] = []
+
+        for rule in self.rules:
+            path_suffix = f"{rule.__class__.__name__}_{rule.column}"
+            rule.save(str(base_path) + f"/rules/{path_suffix}")
+            encoder_dict["rule_names"].append(path_suffix)
+
+        with open(base_path / "init_args.json", "w+") as file:
+            json.dump(encoder_dict, file)
+
+    @classmethod
+    def load(cls, path: str) -> "LabelEncoder":
+        base_path = Path(path).with_suffix(".replay").resolve()
+        with open(base_path / "init_args.json", "r") as file:
+            encoder_dict = json.loads(file.read())
+        rules = []
+        for root, dirs, files in os.walk(str(base_path) + "/rules/"):
+            for d in dirs:
+                if d.split(".")[0] in encoder_dict["rule_names"]:
+                    with open(root + d + "/init_args.json", "r") as file:
+                        encoder_rule_dict = json.loads(file.read())
+                    rules.append(globals()[encoder_rule_dict["_class_name"]].load(root + d))
+
+        encoder = cls(rules=rules)
+        return encoder
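`LabelEncoder.save` mirrors the rule-level scheme one directory up: it records only the class name and the per-rule folder names in its own `init_args.json`, then delegates to each rule's `save` under `rules/`; `load` walks that folder and re-instantiates each rule class by name via `globals()`. A round-trip sketch (paths illustrative):

```python
import pandas as pd

from replay.preprocessing import LabelEncoder, LabelEncodingRule

encoder = LabelEncoder([LabelEncodingRule("user_id"), LabelEncodingRule("item_id")])
encoder.fit(pd.DataFrame({"user_id": ["u1", "u2"], "item_id": ["i1", "i2"]}))

encoder.save("/tmp/my_encoder")  # /tmp/my_encoder.replay/{init_args.json, rules/...}
restored = LabelEncoder.load("/tmp/my_encoder")
assert restored.mapping == encoder.mapping
```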
replay/splitters/cold_user_random_splitter.py
CHANGED

@@ -94,7 +94,7 @@ class ColdUserRandomSplitter(Splitter):
     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         train_users = (
             interactions.select(self.query_column)
-            .unique()
+            .unique(maintain_order=True)
             .sample(fraction=(1 - threshold), seed=self.seed)
             .with_columns(pl.lit(False).alias("is_test"))
         )
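A one-line change, but it fixes reproducibility: polars `unique()` gives no row-order guarantee, so a seeded `sample()` over its output could select different query users from run to run; `maintain_order=True` pins the order the seed operates on. Illustration:

```python
import polars as pl

interactions = pl.DataFrame({"user_id": [3, 1, 2, 1, 3]})
# Without maintain_order=True the order of unique() rows is unstable,
# so sample(..., seed=42) may pick a different subset on each run.
train_users = (
    interactions.select("user_id")
    .unique(maintain_order=True)
    .sample(fraction=0.5, seed=42)
)
print(train_users)
```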
replay/utils/__init__.py
CHANGED
replay/utils/common.py
CHANGED
@@ -7,6 +7,10 @@ from typing import Any, Callable, Union
 from polars import from_pandas as pl_from_pandas

 from replay.data.dataset import Dataset
+from replay.preprocessing import (
+    LabelEncoder,
+    LabelEncodingRule,
+)
 from replay.splitters import (
     ColdUserRandomSplitter,
     KFolds,

@@ -38,20 +42,15 @@ SavableObject = Union[
     TimeSplitter,
     TwoStageSplitter,
     Dataset,
+    LabelEncoder,
+    LabelEncodingRule,
 ]

 if TORCH_AVAILABLE:
     from replay.data.nn import PandasSequentialDataset, PolarsSequentialDataset, SequenceTokenizer

     SavableObject = Union[
-
-        KFolds,
-        LastNSplitter,
-        NewUsersSplitter,
-        RandomSplitter,
-        RatioSplitter,
-        TimeSplitter,
-        TwoStageSplitter,
+        SavableObject,
         SequenceTokenizer,
         PandasSequentialDataset,
         PolarsSequentialDataset,

replay/utils/session_handler.py
CHANGED
@@ -71,7 +71,7 @@ def get_spark_session(
         shuffle_partitions = os.cpu_count() * 3
     driver_memory = f"{spark_memory}g"
     user_home = os.environ["HOME"]
-
+    spark_session_builder = (
         SparkSession.builder.config("spark.driver.memory", driver_memory)
         .config(
             "spark.driver.extraJavaOptions",

@@ -87,10 +87,9 @@
         .config("spark.kryoserializer.buffer.max", "256m")
         .config("spark.files.overwrite", "true")
         .master(f"local[{'*' if core_count == -1 else core_count}]")
-        .enableHiveSupport()
-        .getOrCreate()
     )
-
+
+    return spark_session_builder.getOrCreate()


 def logger_with_settings() -> logging.Logger:

replay/utils/spark_utils.py
CHANGED
@@ -10,7 +10,7 @@ import pandas as pd
 from numpy.random import default_rng

 from .session_handler import State
-from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImportType, NumType, SparkDataFrame
+from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImportType, NumType, PolarsDataFrame, SparkDataFrame

 if PYSPARK_AVAILABLE:
     import pyspark.sql.types as st

@@ -27,6 +27,12 @@ else:
     Column = MissingImportType


+class PolarsConvertToSparkWarning(Warning):
+    """
+    Direct PolarsDataFrame to SparkDataFrame convertation warning.
+    """
+
+
 class SparkCollectToMasterWarning(Warning):  # pragma: no cover
     """
     Collect to master warning for Spark DataFrames.

@@ -69,7 +75,15 @@ def convert2spark(data_frame: Optional[DataFrameLike]) -> Optional[SparkDataFrame]:
         return None
     if isinstance(data_frame, SparkDataFrame):
         return data_frame
+
     spark = State().session
+    if isinstance(data_frame, PolarsDataFrame):
+        warnings.warn(
+            "Direct convertation PolarsDataFrame to SparkDataFrame currently is not supported, "
+            "converting to pandas first",
+            PolarsConvertToSparkWarning,
+        )
+        return spark.createDataFrame(data_frame.to_pandas())  # TODO: remove extra convertation to pandas
     return spark.createDataFrame(data_frame)

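`convert2spark` now accepts polars frames, but only via an intermediate pandas conversion (flagged by `PolarsConvertToSparkWarning` and a TODO), since `SparkSession.createDataFrame` has no direct polars path. A usage sketch, assuming PySpark is installed so `State().session` can supply a session:

```python
import polars as pl

from replay.utils.spark_utils import convert2spark

pl_df = pl.DataFrame({"user_id": [1, 2], "item_id": [10, 20]})
spark_df = convert2spark(pl_df)  # warns PolarsConvertToSparkWarning, converts via pandas
spark_df.show()
```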
replay/utils/types.py
CHANGED
@@ -25,6 +25,14 @@ try:
 except ImportError:
     TORCH_AVAILABLE = False

+try:
+    import onnx  # noqa: F401
+    import openvino  # noqa: F401
+
+    OPENVINO_AVAILABLE = TORCH_AVAILABLE
+except ImportError:
+    OPENVINO_AVAILABLE = False
+
 DataFrameLike = Union[PandasDataFrame, SparkDataFrame, PolarsDataFrame]
 IntOrList = Union[Iterable[int], int]
 NumType = Union[int, float]