replay-rec 0.18.0-py3-none-any.whl → 0.18.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +27 -1
  3. replay/data/dataset_utils/dataset_label_encoder.py +6 -3
  4. replay/data/nn/schema.py +37 -16
  5. replay/data/nn/sequence_tokenizer.py +313 -165
  6. replay/data/nn/torch_sequential_dataset.py +17 -8
  7. replay/data/nn/utils.py +14 -7
  8. replay/data/schema.py +10 -6
  9. replay/metrics/offline_metrics.py +2 -2
  10. replay/models/__init__.py +1 -0
  11. replay/models/base_rec.py +18 -21
  12. replay/models/lin_ucb.py +407 -0
  13. replay/models/nn/sequential/bert4rec/dataset.py +17 -4
  14. replay/models/nn/sequential/bert4rec/lightning.py +121 -54
  15. replay/models/nn/sequential/bert4rec/model.py +21 -0
  16. replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
  17. replay/models/nn/sequential/compiled/__init__.py +5 -0
  18. replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
  19. replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
  20. replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
  21. replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
  22. replay/models/nn/sequential/sasrec/dataset.py +17 -1
  23. replay/models/nn/sequential/sasrec/lightning.py +126 -50
  24. replay/models/nn/sequential/sasrec/model.py +3 -4
  25. replay/preprocessing/__init__.py +7 -1
  26. replay/preprocessing/discretizer.py +719 -0
  27. replay/preprocessing/label_encoder.py +384 -52
  28. replay/splitters/cold_user_random_splitter.py +1 -1
  29. replay/utils/__init__.py +1 -0
  30. replay/utils/common.py +7 -8
  31. replay/utils/session_handler.py +3 -4
  32. replay/utils/spark_utils.py +15 -1
  33. replay/utils/types.py +8 -0
  34. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
  35. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
  36. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
  37. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
replay/preprocessing/label_encoder.py CHANGED
@@ -7,7 +7,11 @@ Contains classes for encoding categorical data
 """
 
 import abc
+import json
+import os
 import warnings
+from itertools import chain
+from pathlib import Path
 from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union
 
 import polars as pl
@@ -22,9 +26,8 @@ from replay.utils import (
 )
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import functions as sf
-    from pyspark.sql.types import LongType, StructType
-    from pyspark.storagelevel import StorageLevel
+    from pyspark.sql import Window, functions as sf  # noqa: I001
+    from pyspark.sql.types import LongType
 
 HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]
 
@@ -33,6 +36,10 @@ class LabelEncoderTransformWarning(Warning):
     """Label encoder transform warning."""
 
 
+class LabelEncoderPartialFitWarning(Warning):
+    """Label encoder partial fit warning."""
+
+
 class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover
     """
     Interface of the label encoding rule
@@ -78,7 +85,7 @@ class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover
 
 class LabelEncodingRule(BaseLabelEncodingRule):
     """
-    Implementation of the encoding rule for categorical variables of PySpark and Pandas Data Frames.
+    Implementation of the encoding rule for categorical variables of PySpark, Pandas and Polars Data Frames.
     Encodes target labels with value between 0 and n_classes-1 for the given column.
     It is recommended to use together with the LabelEncoder.
     """
@@ -163,22 +170,19 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         return inverse_mapping_list
 
     def _fit_spark(self, df: SparkDataFrame) -> None:
-        unique_col_values = df.select(self._col).distinct().persist(StorageLevel.MEMORY_ONLY)
+        unique_col_values = df.select(self._col).distinct()
+        window_function_give_ids = Window.orderBy(self._col)
 
         mapping_on_spark = (
-            unique_col_values.rdd.zipWithIndex()
-            .toDF(
-                StructType()
-                .add("_1", StructType().add(self._col, df.schema[self._col].dataType, True), True)
-                .add("_2", LongType(), True)
+            unique_col_values.withColumn(
+                self._target_col,
+                sf.row_number().over(window_function_give_ids).cast(LongType()),
             )
-            .select(sf.col(f"_1.{self._col}").alias(self._col), sf.col("_2").alias(self._target_col))
-            .persist(StorageLevel.MEMORY_ONLY)
+            .withColumn(self._target_col, sf.col(self._target_col) - 1)
+            .select(self._col, self._target_col)
        )
 
         self._mapping = mapping_on_spark.rdd.collectAsMap()
-        mapping_on_spark.unpersist()
-        unique_col_values.unpersist()
 
     def _fit_pandas(self, df: PandasDataFrame) -> None:
         unique_col_values = df[self._col].drop_duplicates().reset_index(drop=True)
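The `_fit_spark` rewrite above drops the `rdd.zipWithIndex()` round-trip (and its persist/unpersist bookkeeping) in favour of `row_number()` over an ordered window. A minimal standalone sketch of the new id-assignment pattern, with an illustrative column name:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as sf
from pyspark.sql.types import LongType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("b",), ("a",), ("b",), ("c",)], schema=["item_id"])

window = Window.orderBy("item_id")
mapping = (
    df.select("item_id").distinct()
    .withColumn("code", sf.row_number().over(window).cast(LongType()))
    .withColumn("code", sf.col("code") - 1)  # row_number() is 1-based; shift to 0-based codes
    .rdd.collectAsMap()
)
print(mapping)  # {'a': 0, 'b': 1, 'c': 2}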
@@ -222,34 +226,43 @@ class LabelEncodingRule(BaseLabelEncodingRule):
 
     def _partial_fit_spark(self, df: SparkDataFrame) -> None:
         assert self._mapping is not None
-
         max_value = sf.lit(max(self._mapping.values()) + 1)
         already_fitted = list(self._mapping.keys())
         new_values = {x[self._col] for x in df.select(self._col).distinct().collect()} - set(already_fitted)
         new_values_list = [[x] for x in new_values]
-        new_values_df: SparkDataFrame = get_spark_session().createDataFrame(new_values_list, schema=[self._col])
-        new_unique_values = new_values_df.join(df, on=self._col, how="left").select(self._col)
-
-        new_data: dict = (
-            new_unique_values.rdd.zipWithIndex()
-            .toDF(
-                StructType()
-                .add("_1", StructType().add(self._col, df.schema[self._col].dataType), True)
-                .add("_2", LongType(), True)
+        if len(new_values_list) == 0:
+            warnings.warn(
+                "partial_fit will have no effect because "
+                f"there are no new values in the incoming dataset at '{self.column}' column",
+                LabelEncoderPartialFitWarning,
+            )
+            return
+        new_unique_values_df: SparkDataFrame = get_spark_session().createDataFrame(new_values_list, schema=[self._col])
+        window_function_give_ids = Window.orderBy(self._col)
+        new_part_of_mapping = (
+            new_unique_values_df.withColumn(
+                self._target_col,
+                sf.row_number().over(window_function_give_ids).cast(LongType()),
             )
-            .select(sf.col(f"_1.{self._col}").alias(self._col), sf.col("_2").alias(self._target_col))
-            .withColumn(self._target_col, sf.col(self._target_col) + max_value)
+            .withColumn(self._target_col, sf.col(self._target_col) - 1 + max_value)
+            .select(self._col, self._target_col)
             .rdd.collectAsMap()
         )
-        self._mapping.update(new_data)
-        self._inverse_mapping.update({v: k for k, v in new_data.items()})
-        self._inverse_mapping_list.extend(new_data.keys())
-        new_unique_values.unpersist()
+        self._mapping.update(new_part_of_mapping)
+        self._inverse_mapping.update({v: k for k, v in new_part_of_mapping.items()})
+        self._inverse_mapping_list.extend(new_part_of_mapping.keys())
 
     def _partial_fit_pandas(self, df: PandasDataFrame) -> None:
         assert self._mapping is not None
 
         new_unique_values = set(df[self._col].tolist()) - set(self._mapping)
+        if len(new_unique_values) == 0:
+            warnings.warn(
+                "partial_fit will have no effect because "
+                f"there are no new values in the incoming dataset at '{self.column}' column",
+                LabelEncoderPartialFitWarning,
+            )
+            return
         last_mapping_value = max(self._mapping.values())
         new_data: dict = {value: last_mapping_value + i for i, value in enumerate(new_unique_values, start=1)}
         self._mapping.update(new_data)
@@ -260,6 +273,13 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         assert self._mapping is not None
 
         new_unique_values = set(df.select(self._col).unique().to_series().to_list()) - set(self._mapping)
+        if len(new_unique_values) == 0:
+            warnings.warn(
+                "partial_fit will have no effect because "
+                f"there are no new values in the incoming dataset at '{self.column}' column",
+                LabelEncoderPartialFitWarning,
+            )
+            return
         new_data: dict = {value: max(self._mapping.values()) + i for i, value in enumerate(new_unique_values, start=1)}
         self._mapping.update(new_data)
         self._inverse_mapping.update({v: k for k, v in new_data.items()})
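All three `_partial_fit_*` backends now short-circuit with a `LabelEncoderPartialFitWarning` when the incoming data holds no unseen labels, instead of silently recomputing. A hedged sketch of the observable behaviour (assuming `LabelEncodingRule` is importable from `replay.preprocessing`, as the `replay/utils/common.py` hunk below does, and that `fit` returns the rule itself):

import warnings

import pandas as pd
from replay.preprocessing import LabelEncodingRule

rule = LabelEncodingRule("item_id")
rule.fit(pd.DataFrame({"item_id": ["a", "b"]}))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    rule.partial_fit(pd.DataFrame({"item_id": ["a"]}))  # no unseen labels here

print(caught[0].category.__name__)  # LabelEncoderPartialFitWarning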
@@ -484,6 +504,272 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             raise ValueError(msg)
         self._handle_unknown = handle_unknown
 
+    def save(
+        self,
+        path: str,
+    ) -> None:
+        encoder_rule_dict = {}
+        encoder_rule_dict["_class_name"] = self.__class__.__name__
+        encoder_rule_dict["init_args"] = {
+            "column": self._col,
+            "mapping": self._mapping,
+            "handle_unknown": self._handle_unknown,
+            "default_value": self._default_value,
+        }
+
+        column_type = str(type(next(iter(self._mapping))))
+
+        if not isinstance(column_type, (str, int, float)):  # pragma: no cover
+            msg = f"LabelEncodingRule.save() is not implemented for column type {column_type}. \
+                Convert type to string, integer, or float."
+            raise NotImplementedError(msg)
+
+        encoder_rule_dict["fitted_args"] = {
+            "target_col": self._target_col,
+            "is_fitted": self._is_fitted,
+            "column_type": column_type,
+        }
+
+        base_path = Path(path).with_suffix(".replay").resolve()
+        if os.path.exists(base_path):  # pragma: no cover
+            msg = "There is already LabelEncodingRule object saved at the given path. File will be overwrited."
+            warnings.warn(msg)
+        else:  # pragma: no cover
+            base_path.mkdir(parents=True, exist_ok=True)
+
+        with open(base_path / "init_args.json", "w+") as file:
+            json.dump(encoder_rule_dict, file)
+
+    @classmethod
+    def load(cls, path: str) -> "LabelEncodingRule":
+        base_path = Path(path).with_suffix(".replay").resolve()
+        with open(base_path / "init_args.json", "r") as file:
+            encoder_rule_dict = json.loads(file.read())
+
+        string_column_type = encoder_rule_dict["fitted_args"]["column_type"]
+        if "str" in string_column_type:
+            column_type = str
+        elif "int" in string_column_type:
+            column_type = int
+        elif "float" in string_column_type:
+            column_type = float
+
+        encoder_rule_dict["init_args"]["mapping"] = {
+            column_type(key): int(value) for key, value in encoder_rule_dict["init_args"]["mapping"].items()
+        }
+
+        encoding_rule = cls(**encoder_rule_dict["init_args"])
+        encoding_rule._target_col = encoder_rule_dict["fitted_args"]["target_col"]
+        encoding_rule._is_fitted = encoder_rule_dict["fitted_args"]["is_fitted"]
+        return encoding_rule
+
+
+class SequenceEncodingRule(LabelEncodingRule):
+    """
+    Implementation of the encoding rule for grouped categorical variables of PySpark, Pandas and Polars Data Frames.
+    Grouped means that one cell of the table contains a list with categorical values.
+    Encodes target labels with value between 0 and n_classes-1 for the given column.
+    It is recommended to use together with the LabelEncoder.
+    """
+
+    _FAKE_INDEX_COLUMN_NAME: str = "__index__"
+
+    def fit(self, df: DataFrameLike) -> "SequenceEncodingRule":
+        """
+        Fits encoder to input dataframe.
+
+        :param df: input dataframe.
+        :returns: fitted EncodingRule.
+        """
+        if self._mapping is not None:
+            return self
+
+        if isinstance(df, PandasDataFrame):
+            self._fit_pandas(df[[self.column]].explode(self.column))
+        elif isinstance(df, SparkDataFrame):
+            self._fit_spark(df.select(self.column).withColumn(self.column, sf.explode(self.column)))
+        elif isinstance(df, PolarsDataFrame):
+            self._fit_polars(df.select(self.column).explode(self.column))
+        else:
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
+        self._inverse_mapping = self._make_inverse_mapping()
+        self._inverse_mapping_list = self._make_inverse_mapping_list()
+        if self._handle_unknown == "use_default_value" and self._default_value in self._inverse_mapping:
+            msg = (
+                "The used value for default_value "
+                f"{self._default_value} is one of the "
+                "values already used for encoding the "
+                "seen labels."
+            )
+            raise ValueError(msg)
+        self._is_fitted = True
+        return self
+
+    def partial_fit(self, df: DataFrameLike) -> "SequenceEncodingRule":
+        """
+        Fits new data to already fitted encoder.
+
+        :param df: input dataframe.
+        :returns: fitted EncodingRule.
+        """
+        if self._mapping is None:
+            return self.fit(df)
+        if isinstance(df, SparkDataFrame):
+            self._partial_fit_spark(df.select(self.column).withColumn(self.column, sf.explode(self.column)))
+        elif isinstance(df, PandasDataFrame):
+            self._partial_fit_pandas(df[[self.column]].explode(self.column))
+        elif isinstance(df, PolarsDataFrame):
+            self._partial_fit_polars(df.select(self.column).explode(self.column))
+        else:
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
+
+        self._is_fitted = True
+        return self
+
+    def _transform_spark(self, df: SparkDataFrame, default_value: Optional[int]) -> SparkDataFrame:
+        map_expr = sf.create_map([sf.lit(x) for x in chain(*self.get_mapping().items())])
+        encoded_df = df.withColumn(self._target_col, sf.transform(self.column, lambda x: map_expr.getItem(x)))
+
+        if self._handle_unknown == "drop":
+            encoded_df = encoded_df.withColumn(self._target_col, sf.filter(self._target_col, lambda x: x.isNotNull()))
+            if encoded_df.select(sf.max(sf.size(self._target_col))).first()[0] == 0:
+                warnings.warn(
+                    f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                    "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                    LabelEncoderTransformWarning,
+                )
+        elif self._handle_unknown == "error":
+            if (
+                encoded_df.select(sf.sum(sf.array_contains(self._target_col, -1).isNull().cast("integer"))).first()[0]
+                != 0
+            ):
+                msg = f"Found unknown labels in column {self._col} during transform"
+                raise ValueError(msg)
+        else:
+            if default_value:
+                encoded_df = encoded_df.withColumn(
+                    self._target_col,
+                    sf.transform(self._target_col, lambda x: sf.when(x.isNull(), default_value).otherwise(x)),
+                )
+
+        result_df = encoded_df.drop(self._col).withColumnRenamed(self._target_col, self._col)
+        return result_df
+
+    def _transform_pandas(self, df: PandasDataFrame, default_value: Optional[int]) -> PandasDataFrame:
+        mapping = self.get_mapping()
+        joined_df = df.copy()
+        if self._handle_unknown == "drop":
+            max_array_len = 0
+
+            def encode_func(array_col):
+                nonlocal mapping, max_array_len
+                res = []
+                for x in array_col:
+                    cur_len = 0
+                    mapped = mapping.get(x)
+                    if mapped is not None:
+                        res.append(mapped)
+                        cur_len += 1
+                    max_array_len = max(max_array_len, cur_len)
+                return res
+
+            joined_df[self._target_col] = joined_df[self._col].apply(encode_func)
+            if max_array_len == 0:
+                warnings.warn(
+                    f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                    "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                    LabelEncoderTransformWarning,
+                )
+        elif self._handle_unknown == "error":
+            none_count = 0
+
+            def encode_func(array_col):
+                nonlocal mapping, none_count
+                res = []
+                for x in array_col:
+                    mapped = mapping.get(x)
+                    if mapped is None:
+                        none_count += 1
+                    else:
+                        res.append(mapped)
+                return res
+
+            joined_df[self._target_col] = joined_df[self._col].apply(encode_func)
+            if none_count != 0:
+                msg = f"Found unknown labels in column {self._col} during transform"
+                raise ValueError(msg)
+        else:
+
+            def encode_func(array_col):
+                nonlocal mapping
+                return [mapping.get(x, default_value) for x in array_col]
+
+            joined_df[self._target_col] = joined_df[self._col].apply(encode_func)
+
+        result_df = joined_df.drop(self._col, axis=1).rename(columns={self._target_col: self._col})
+        return result_df
+
+    def _transform_polars(self, df: PolarsDataFrame, default_value: Optional[int]) -> SparkDataFrame:
+        transformed_df = df.with_columns(
+            pl.col(self._col)
+            .list.eval(
+                pl.element().replace_strict(
+                    self.get_mapping(), default=default_value if self._handle_unknown == "use_default_value" else None
+                ),
+                parallel=True,
+            )
+            .alias(self._target_col)
+        )
+        if self._handle_unknown == "drop":
+            transformed_df = transformed_df.with_columns(pl.col(self._target_col).list.drop_nulls())
+            if (
+                transformed_df.with_columns(pl.col(self._target_col).list.len()).select(pl.sum(self._target_col)).item()
+                == 0
+            ):
+                warnings.warn(
+                    f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                    "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                    LabelEncoderTransformWarning,
+                )
+        elif self._handle_unknown == "error":
+            none_checker = transformed_df.with_columns(
+                pl.col(self._target_col).list.contains(pl.lit(None, dtype=pl.Int64)).cast(pl.Int64)
+            )
+            if none_checker.select(pl.sum(self._target_col)).item() != 0:
+                msg = f"Found unknown labels in column {self._col} during transform"
+                raise ValueError(msg)
+
+        result_df = transformed_df.drop(self._col).rename({self._target_col: self._col})
+        return result_df
+
+    def _inverse_transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
+        decoded_df = df.copy()
+
+        def decode_func(array_col):
+            return [self._inverse_mapping_list[x] for x in array_col]
+
+        decoded_df[self._col] = decoded_df[self._col].apply(decode_func)
+        return decoded_df
+
+    def _inverse_transform_polars(self, df: PolarsDataFrame) -> PolarsDataFrame:
+        mapping_size = len(self._inverse_mapping_list)
+        transformed_df = df.with_columns(
+            pl.col(self._col).list.eval(
+                pl.element().replace_strict(old=list(range(mapping_size)), new=self._inverse_mapping_list),
+                parallel=True,
+            )
+        )
+        return transformed_df
+
+    def _inverse_transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
+        array_expr = sf.array([sf.lit(x) for x in self._inverse_mapping_list])
+        decoded_df = df.withColumn(
+            self._target_col, sf.transform(self._col, lambda x: sf.element_at(array_expr, x + 1))
+        )
+        return decoded_df.drop(self._col).withColumnRenamed(self._target_col, self._col)
+
 
 class LabelEncoder:
     """
@@ -491,42 +777,48 @@ class LabelEncoder:
 
     >>> import pandas as pd
     >>> user_interactions = pd.DataFrame([
-    ...     ("u1", "item_1", "item_1"),
-    ...     ("u2", "item_2", "item_2"),
-    ...     ("u3", "item_3", "item_3"),
-    ... ], columns=["user_id", "item_1", "item_2"])
+    ...     ("u1", "item_1", "item_1", [1, 2, 3]),
+    ...     ("u2", "item_2", "item_2", [3, 4, 5]),
+    ...     ("u3", "item_3", "item_3", [-1, -2, 4]),
+    ... ], columns=["user_id", "item_1", "item_2", "list"])
     >>> user_interactions
-      user_id  item_1  item_2
-    0      u1  item_1  item_1
-    1      u2  item_2  item_2
-    2      u3  item_3  item_3
-    >>> encoder = LabelEncoder(
-    ...     [LabelEncodingRule("user_id"), LabelEncodingRule("item_1"), LabelEncodingRule("item_2")]
-    ... )
+      user_id  item_1  item_2         list
+    0      u1  item_1  item_1    [1, 2, 3]
+    1      u2  item_2  item_2    [3, 4, 5]
+    2      u3  item_3  item_3  [-1, -2, 4]
+    >>> encoder = LabelEncoder([
+    ...     LabelEncodingRule("user_id"),
+    ...     LabelEncodingRule("item_1"),
+    ...     LabelEncodingRule("item_2"),
+    ...     SequenceEncodingRule("list"),
+    ... ])
     >>> mapped_interactions = encoder.fit_transform(user_interactions)
     >>> mapped_interactions
-       user_id  item_1  item_2
-    0        0       0       0
-    1        1       1       1
-    2        2       2       2
+       user_id  item_1  item_2       list
+    0        0       0       0  [0, 1, 2]
+    1        1       1       1  [2, 3, 4]
+    2        2       2       2  [5, 6, 3]
     >>> encoder.mapping
     {'user_id': {'u1': 0, 'u2': 1, 'u3': 2},
     'item_1': {'item_1': 0, 'item_2': 1, 'item_3': 2},
-    'item_2': {'item_1': 0, 'item_2': 1, 'item_3': 2}}
+    'item_2': {'item_1': 0, 'item_2': 1, 'item_3': 2},
+    'list': {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, -1: 5, -2: 6}}
     >>> encoder.inverse_mapping
     {'user_id': {0: 'u1', 1: 'u2', 2: 'u3'},
     'item_1': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
-    'item_2': {0: 'item_1', 1: 'item_2', 2: 'item_3'}}
+    'item_2': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
+    'list': {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: -1, 6: -2}}
     >>> new_encoder = LabelEncoder([
     ...     LabelEncodingRule("user_id", encoder.mapping["user_id"]),
     ...     LabelEncodingRule("item_1", encoder.mapping["item_1"]),
-    ...     LabelEncodingRule("item_2", encoder.mapping["item_2"])
+    ...     LabelEncodingRule("item_2", encoder.mapping["item_2"]),
+    ...     SequenceEncodingRule("list", encoder.mapping["list"]),
     ... ])
     >>> new_encoder.inverse_transform(mapped_interactions)
-      user_id  item_1  item_2
-    0      u1  item_1  item_1
-    1      u2  item_2  item_2
-    2      u3  item_3  item_3
+      user_id  item_1  item_2         list
+    0      u1  item_1  item_1    [1, 2, 3]
+    1      u2  item_2  item_2    [3, 4, 5]
+    2      u3  item_3  item_3  [-1, -2, 4]
     <BLANKLINE>
     """
 
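The docstring above exercises `SequenceEncodingRule` through pandas; the same rule also covers Polars list columns via `list.eval` + `replace_strict`, per the `_transform_polars` implementation. A hedged sketch (the import path follows this file's location; the `transform(df)` signature is assumed from the base rule interface):

import polars as pl
from replay.preprocessing.label_encoder import SequenceEncodingRule

df = pl.DataFrame({"list": [[1, 2], [3, 1]]})
rule = SequenceEncodingRule("list").fit(df)  # fit returns the rule itself
print(rule.transform(df))  # list column re-coded to 0-based ids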
@@ -650,3 +942,43 @@ class LabelEncoder:
             raise ValueError(msg)
         rule = list(filter(lambda x: x.column == column, self.rules))
         rule[0].set_default_value(default_value)
+
+    def save(
+        self,
+        path: str,
+    ) -> None:
+        encoder_dict = {}
+        encoder_dict["_class_name"] = self.__class__.__name__
+
+        base_path = Path(path).with_suffix(".replay").resolve()
+        if os.path.exists(base_path):  # pragma: no cover
+            msg = "There is already LabelEncoder object saved at the given path. File will be overwrited."
+            warnings.warn(msg)
+        else:  # pragma: no cover
+            base_path.mkdir(parents=True, exist_ok=True)
+
+        encoder_dict["rule_names"] = []
+
+        for rule in self.rules:
+            path_suffix = f"{rule.__class__.__name__}_{rule.column}"
+            rule.save(str(base_path) + f"/rules/{path_suffix}")
+            encoder_dict["rule_names"].append(path_suffix)
+
+        with open(base_path / "init_args.json", "w+") as file:
+            json.dump(encoder_dict, file)
+
+    @classmethod
+    def load(cls, path: str) -> "LabelEncoder":
+        base_path = Path(path).with_suffix(".replay").resolve()
+        with open(base_path / "init_args.json", "r") as file:
+            encoder_dict = json.loads(file.read())
+        rules = []
+        for root, dirs, files in os.walk(str(base_path) + "/rules/"):
+            for d in dirs:
+                if d.split(".")[0] in encoder_dict["rule_names"]:
+                    with open(root + d + "/init_args.json", "r") as file:
+                        encoder_rule_dict = json.loads(file.read())
+                    rules.append(globals()[encoder_rule_dict["_class_name"]].load(root + d))
+
+        encoder = cls(rules=rules)
+        return encoder
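`LabelEncoder.save` delegates to each rule, writing one `<RuleClass>_<column>.replay` entry under `rules/`; `load` walks that tree and rebuilds each rule by its stored `_class_name`. A hedged round-trip sketch (only `fit_transform`, `mapping`, `save`, and `load` are confirmed by this diff):

import pandas as pd
from replay.preprocessing import LabelEncoder, LabelEncodingRule

encoder = LabelEncoder([LabelEncodingRule("user_id"), LabelEncodingRule("item_id")])
encoder.fit_transform(pd.DataFrame({"user_id": ["u1", "u2"], "item_id": ["i1", "i2"]}))

encoder.save("/tmp/encoder")  # /tmp/encoder.replay/{init_args.json, rules/...}
restored = LabelEncoder.load("/tmp/encoder")
assert restored.mapping == encoder.mapping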
replay/splitters/cold_user_random_splitter.py CHANGED
@@ -94,7 +94,7 @@ class ColdUserRandomSplitter(Splitter):
     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         train_users = (
             interactions.select(self.query_column)
-            .unique()
+            .unique(maintain_order=True)
             .sample(fraction=(1 - threshold), seed=self.seed)
             .with_columns(pl.lit(False).alias("is_test"))
         )
replay/utils/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from .session_handler import State, get_spark_session
 from .types import (
+    OPENVINO_AVAILABLE,
     PYSPARK_AVAILABLE,
     TORCH_AVAILABLE,
     DataFrameLike,
replay/utils/common.py CHANGED
@@ -7,6 +7,10 @@ from typing import Any, Callable, Union
 from polars import from_pandas as pl_from_pandas
 
 from replay.data.dataset import Dataset
+from replay.preprocessing import (
+    LabelEncoder,
+    LabelEncodingRule,
+)
 from replay.splitters import (
     ColdUserRandomSplitter,
     KFolds,
@@ -38,20 +42,15 @@ SavableObject = Union[
     TimeSplitter,
     TwoStageSplitter,
     Dataset,
+    LabelEncoder,
+    LabelEncodingRule,
 ]
 
 if TORCH_AVAILABLE:
     from replay.data.nn import PandasSequentialDataset, PolarsSequentialDataset, SequenceTokenizer
 
     SavableObject = Union[
-        ColdUserRandomSplitter,
-        KFolds,
-        LastNSplitter,
-        NewUsersSplitter,
-        RandomSplitter,
-        RatioSplitter,
-        TimeSplitter,
-        TwoStageSplitter,
+        SavableObject,
         SequenceTokenizer,
         PandasSequentialDataset,
         PolarsSequentialDataset,
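The `Union[SavableObject, ...]` re-assignment works because `typing` flattens nested unions, so the torch-only savables are appended without restating the base list. A tiny standalone illustration:

from typing import Union

Base = Union[int, str]
Extended = Union[Base, float]  # nested unions flatten automatically
print(Extended)  # typing.Union[int, str, float]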
replay/utils/session_handler.py CHANGED
@@ -71,7 +71,7 @@
     shuffle_partitions = os.cpu_count() * 3
     driver_memory = f"{spark_memory}g"
     user_home = os.environ["HOME"]
-    spark = (
+    spark_session_builder = (
         SparkSession.builder.config("spark.driver.memory", driver_memory)
         .config(
             "spark.driver.extraJavaOptions",
@@ -87,10 +87,9 @@
         .config("spark.kryoserializer.buffer.max", "256m")
         .config("spark.files.overwrite", "true")
         .master(f"local[{'*' if core_count == -1 else core_count}]")
-        .enableHiveSupport()
-        .getOrCreate()
     )
-    return spark
+
+    return spark_session_builder.getOrCreate()
 
 
 def logger_with_settings() -> logging.Logger:
replay/utils/spark_utils.py CHANGED
@@ -10,7 +10,7 @@ import pandas as pd
 from numpy.random import default_rng
 
 from .session_handler import State
-from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImportType, NumType, SparkDataFrame
+from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImportType, NumType, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     import pyspark.sql.types as st
@@ -27,6 +27,12 @@ else:
     Column = MissingImportType
 
 
+class PolarsConvertToSparkWarning(Warning):
+    """
+    Direct PolarsDataFrame to SparkDataFrame convertation warning.
+    """
+
+
 class SparkCollectToMasterWarning(Warning):  # pragma: no cover
     """
     Collect to master warning for Spark DataFrames.
@@ -69,7 +75,15 @@ def convert2spark(data_frame: Optional[DataFrameLike]) -> Optional[SparkDataFram
         return None
     if isinstance(data_frame, SparkDataFrame):
         return data_frame
+
     spark = State().session
+    if isinstance(data_frame, PolarsDataFrame):
+        warnings.warn(
+            "Direct convertation PolarsDataFrame to SparkDataFrame currently is not supported, "
+            "converting to pandas first",
+            PolarsConvertToSparkWarning,
+        )
+        return spark.createDataFrame(data_frame.to_pandas())  # TODO: remove extra convertation to pandas
     return spark.createDataFrame(data_frame)
 
 
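`convert2spark` now accepts a `PolarsDataFrame`, warning that it currently hops through pandas on the way. A hedged usage sketch (an active Spark session is assumed):

import polars as pl
from replay.utils.spark_utils import convert2spark

pl_df = pl.DataFrame({"user_id": [1, 2], "item_id": [10, 20]})
spark_df = convert2spark(pl_df)  # emits PolarsConvertToSparkWarning
spark_df.show()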
replay/utils/types.py CHANGED
@@ -25,6 +25,14 @@ try:
 except ImportError:
     TORCH_AVAILABLE = False
 
+try:
+    import onnx  # noqa: F401
+    import openvino  # noqa: F401
+
+    OPENVINO_AVAILABLE = TORCH_AVAILABLE
+except ImportError:
+    OPENVINO_AVAILABLE = False
+
 DataFrameLike = Union[PandasDataFrame, SparkDataFrame, PolarsDataFrame]
 IntOrList = Union[Iterable[int], int]
 NumType = Union[int, float]
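`OPENVINO_AVAILABLE` is tied to `TORCH_AVAILABLE` rather than set to `True` outright, so the flag never outruns torch support even when `onnx` and `openvino` import cleanly. A hedged sketch of gating on it (the backend labels are illustrative; the compiled model classes under `replay/models/nn/sequential/compiled/` are not shown in this diff):

from replay.utils import OPENVINO_AVAILABLE

def pick_inference_backend() -> str:
    # Prefer the openvino-compiled path when the optional deps are present.
    if OPENVINO_AVAILABLE:
        return "openvino"
    return "torch"

print(pick_inference_backend())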