replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/transform/token_mask.py ADDED
@@ -0,0 +1,69 @@
+ from typing import Optional
+
+ import torch
+
+
+ class TokenMaskTransform(torch.nn.Module):
+     """
+     For the feature tensor specified by ``token_field``, randomly masks items
+     in the sequence based on a uniform distribution with the specified masking probability.
+     In effect, this transform builds the mask for a recommendation analogue of the Masked Language Modeling (MLM) task.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> _ = torch.manual_seed(0)
+         >>> input_tensor = {"padding_id": torch.BoolTensor([0, 1, 1])}
+         >>> transform = TokenMaskTransform("padding_id")
+         >>> output_tensor = transform(input_tensor)
+         >>> output_tensor
+         {'padding_id': tensor([False, True, True]),
+          'token_mask': tensor([False, True, False])}
+
+     """
+
+     def __init__(
+         self,
+         token_field: str,
+         out_feature_name: str = "token_mask",
+         mask_prob: float = 0.15,
+         generator: Optional[torch.Generator] = None,
+     ) -> None:
+         """
+         :param token_field: Name of the column containing the unmasked tokens.
+         :param out_feature_name: Name of the resulting mask column. Default: ``token_mask``.
+         :param mask_prob: Probability of masking the item, i.e. setting it to ``0``. Default: ``0.15``.
+         :param generator: Random number generator to be used for generating
+             the uniform distribution. Default: ``None``.
+         """
+         super().__init__()
+         self.token_field = token_field
+         self.out_feature_name = out_feature_name
+         self.mask_prob = mask_prob
+         self.generator = generator
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         output_batch = dict(batch.items())
+
+         paddings = batch[self.token_field]
+
+         assert paddings.dtype == torch.bool, "Source tensor for token mask should be boolean."
+
+         mask_prob = torch.rand(paddings.size(-1), dtype=torch.float32, generator=self.generator).to(
+             device=paddings.device
+         )
+
+         # mask[i]: 0 ~ masked (probability mask_prob), 1 ~ kept (probability 1 - mask_prob)
+         mask = (mask_prob * paddings) >= self.mask_prob
+
+         # Fix corner cases in mask
+         # 1. If no tokens are masked, mask the last one
+         if mask.all() or mask[paddings].all():
+             mask[-1] = 0
+         # 2. If all tokens are masked, keep the one before the last unmasked
+         elif (not mask.any()) and (len(mask) > 1):
+             mask[-2] = 1
+
+         output_batch[self.out_feature_name] = mask
+         return output_batch
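Not part of the diff above: a minimal usage sketch for the new transform. The import path is assumed from the added file replay/nn/transform/token_mask.py; the package-level export may differ.

    import torch

    # Assumed import path based on the added file; the public re-export may differ.
    from replay.nn.transform.token_mask import TokenMaskTransform

    transform = TokenMaskTransform(
        token_field="padding_mask",
        mask_prob=0.2,
        generator=torch.Generator().manual_seed(42),
    )

    # The source tensor must be boolean: True marks real tokens, False marks padding.
    batch = {
        "item_id": torch.tensor([3, 7, 2, 9, 5]),
        "padding_mask": torch.tensor([True, True, True, True, True]),
    }
    out = transform(batch)
    # out["token_mask"] is a boolean tensor of the same length; False marks positions
    # that are either padding or selected for MLM-style masking.
    print(out["token_mask"])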
replay/nn/transform/trim.py ADDED
@@ -0,0 +1,51 @@
+ from typing import List, Union
+
+ import torch
+
+
+ class TrimTransform(torch.nn.Module):
+     """
+     Trims the sequences named in ``feature_names``, keeping the last ``seq_len`` elements of each (the right side).
+
+     Example:
+
+     .. code-block:: python
+
+         >>> input_batch = {
+         ...     "user_id": torch.LongTensor([111]),
+         ...     "item_id": torch.LongTensor([[5, 4, 0, 7, 4]]),
+         ...     "seen_ids": torch.LongTensor([[5, 4, 0, 7, 4]]),
+         ... }
+         >>> transform = TrimTransform(seq_len=3, feature_names="item_id")
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'user_id': tensor([111]),
+          'item_id': tensor([[0, 7, 4]]),
+          'seen_ids': tensor([[5, 4, 0, 7, 4]])}
+
+     """
+
+     def __init__(
+         self,
+         seq_len: int,
+         feature_names: Union[List[str], str],
+     ) -> None:
+         """
+         :param seq_len: max sequence length used in the model. Must be positive.
+         :param feature_names: name(s) of the features in the batch to be trimmed.
+         """
+         super().__init__()
+         assert seq_len > 0
+         self.seq_len = seq_len
+         self.feature_names = [feature_names] if isinstance(feature_names, str) else feature_names
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         output_batch = dict(batch.items())
+
+         for name in self.feature_names:
+             assert output_batch[name].shape[1] >= self.seq_len
+
+             trimmed_seq = output_batch[name][:, -self.seq_len :, ...].clone()
+             output_batch[name] = trimmed_seq
+
+         return output_batch
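For reference, a short sketch of trimming batched sequences down to a model's maximum length with this transform (import path assumed from the added file replay/nn/transform/trim.py):

    import torch

    # Assumed import path based on the added file; the public re-export may differ.
    from replay.nn.transform.trim import TrimTransform

    trim = TrimTransform(seq_len=3, feature_names=["item_id", "timestamp"])

    batch = {
        "user_id": torch.tensor([7]),
        "item_id": torch.tensor([[5, 4, 0, 7, 4]]),
        "timestamp": torch.tensor([[1, 2, 3, 4, 5]]),
    }
    out = trim(batch)
    # Only the listed features are trimmed to their last 3 positions; "user_id" is untouched.
    assert out["item_id"].tolist() == [[0, 7, 4]]
    assert out["timestamp"].tolist() == [[3, 4, 5]]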
replay/nn/utils.py ADDED
@@ -0,0 +1,28 @@
+ import warnings
+ from typing import Callable, Literal, Tuple
+
+ import torch
+
+
+ def warning_is_not_none(msg: str) -> Callable:
+     def checker(value: Tuple[torch.Tensor, str]) -> bool:
+         if value[0] is not None:
+             warnings.warn(msg.format(value[1]), RuntimeWarning, stacklevel=2)
+             return False
+         return True
+
+     return checker
+
+
+ def create_activation(
+     activation: Literal["relu", "gelu", "sigmoid"],
+ ) -> torch.nn.Module:
+     """Create an activation module from its name."""
+     if activation == "relu":
+         return torch.nn.ReLU()
+     if activation == "gelu":
+         return torch.nn.GELU()
+     if activation == "sigmoid":
+         return torch.nn.Sigmoid()
+     msg = "Expected to get activation relu/gelu/sigmoid"
+     raise ValueError(msg)
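A small sketch of how these helpers could be used (module path taken from the file header above; whether they are re-exported elsewhere is an assumption):

    import torch

    from replay.nn.utils import create_activation, warning_is_not_none

    # create_activation maps a name to a torch.nn module and rejects anything else.
    act = create_activation("gelu")
    print(act(torch.tensor([-1.0, 0.0, 1.0])))

    # warning_is_not_none builds a checker that warns (and returns False)
    # when the tensor in a (tensor, name) pair is not None.
    checker = warning_is_not_none("Field '{}' is ignored in this mode")
    checker((torch.zeros(1), "bias"))  # emits a RuntimeWarning, returns False
    checker((None, "bias"))            # returns True, no warning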
replay/preprocessing/filters.py CHANGED
@@ -1090,3 +1090,131 @@ class ConsecutiveDuplicatesFilter(_BaseFilter):
              .where((sf.col(self.item_column) != sf.col(self.temporary_column)) | sf.col(self.temporary_column).isNull())
              .drop(self.temporary_column)
          )
+
+
+ def _check_col_present(
+     target: DataFrameLike,
+     reference: DataFrameLike,
+     columns_to_process: list[str],
+ ) -> None:
+     target_columns = set(target.columns)
+     reference_columns = set(reference.columns)
+     for column in columns_to_process:
+         if column not in target_columns or column not in reference_columns:
+             msg = f"Column '{column}' must be in both dataframes"
+             raise KeyError(msg)
+
+
+ def _filter_cold_pandas(
+     target: PandasDataFrame,
+     reference: PandasDataFrame,
+     columns_to_process: list[str],
+ ) -> PandasDataFrame:
+     for column in columns_to_process:
+         allowed_values = reference[column].unique()
+         target = target[target[column].isin(allowed_values)]
+     return target
+
+
+ def _filter_cold_polars(
+     target: PolarsDataFrame,
+     reference: PolarsDataFrame,
+     columns_to_process: list[str],
+ ) -> PolarsDataFrame:
+     for column in columns_to_process:
+         allowed_values = reference.select(column).unique()
+         target = target.join(allowed_values, on=column, how="semi")
+     return target
+
+
+ def _filter_cold_spark(
+     target: SparkDataFrame,
+     reference: SparkDataFrame,
+     columns_to_process: list[str],
+ ) -> SparkDataFrame:
+     for column in columns_to_process:
+         allowed_values = reference.select(column).distinct()
+         target = target.join(allowed_values, on=column, how="left_semi")
+     return target
+
+
+ def filter_cold(
+     target: DataFrameLike,
+     reference: DataFrameLike,
+     mode: Literal["items", "users", "both"] = "items",
+     query_column: str = "query_id",
+     item_column: str = "item_id",
+ ) -> DataFrameLike:
+     """
+     Filter rows in ``target`` keeping only users/items that exist in ``reference``.
+
+     This function works with pandas, Polars and Spark DataFrames. ``target`` and
+     ``reference`` must be of the same backend type. Depending on ``mode``, it
+     removes rows whose ``item_column`` and/or ``query_column`` values are not
+     present in the corresponding columns of ``reference``.
+
+     Parameters
+     ----------
+     target : DataFrameLike
+         Dataset to be filtered (pandas/Polars/Spark).
+     reference : DataFrameLike
+         Dataset that defines the allowed universe of users/items.
+     mode : {"items", "users", "both"}, default "items"
+         What to filter: only items, only users, or both.
+     query_column : str, default "query_id"
+         Name of the user (query) column.
+     item_column : str, default "item_id"
+         Name of the item column.
+
+     Returns
+     -------
+     DataFrameLike
+         Filtered ``target`` of the same backend type as the input.
+
+     Raises
+     ------
+     ValueError
+         If ``mode`` is not one of {"items", "users", "both"}.
+     TypeError
+         If ``target`` and ``reference`` are of different backend types.
+     KeyError
+         If required columns are missing in either dataset.
+     NotImplementedError
+         If the input dataframe type is not supported.
+     """
+     if mode not in {"items", "users", "both"}:
+         msg = "mode must be 'items' | 'users' | 'both'"
+         raise ValueError(msg)
+     if not isinstance(target, type(reference)):
+         msg = "Target and reference must be of the same type"
+         raise TypeError(msg)
+
+     if mode == "both":
+         columns_to_process = [query_column, item_column]
+     elif mode == "items":
+         columns_to_process = [item_column]
+     elif mode == "users":
+         columns_to_process = [query_column]
+
+     _check_col_present(target, reference, columns_to_process)
+
+     if isinstance(target, PandasDataFrame):
+         return _filter_cold_pandas(
+             target,
+             reference,
+             columns_to_process,
+         )
+     if isinstance(target, PolarsDataFrame):
+         return _filter_cold_polars(
+             target,
+             reference,
+             columns_to_process,
+         )
+     if isinstance(target, SparkDataFrame):
+         return _filter_cold_spark(
+             target,
+             reference,
+             columns_to_process,
+         )
+     msg = f"Unsupported data frame type: {type(target)}"
+     raise NotImplementedError(msg)
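A hedged example of the new filter_cold helper with pandas inputs (the function is defined in replay/preprocessing/filters.py per this diff; the column names below are the defaults):

    import pandas as pd

    from replay.preprocessing.filters import filter_cold

    train = pd.DataFrame({"query_id": [1, 1, 2], "item_id": [10, 20, 30]})
    test = pd.DataFrame({"query_id": [1, 2, 3], "item_id": [10, 40, 30]})

    # Keep only test rows whose item_id also occurs in train ("cold" items are dropped).
    warm_items = filter_cold(test, train, mode="items")
    # Keep only rows whose user AND item are both known from train.
    warm_both = filter_cold(test, train, mode="both")
    print(warm_items)
    print(warm_both)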
replay/preprocessing/label_encoder.py CHANGED
@@ -26,7 +26,7 @@ from replay.utils import (
 
  if PYSPARK_AVAILABLE:
      from pyspark.sql import Window, functions as sf  # noqa: I001
-     from pyspark.sql.types import LongType, IntegerType, ArrayType
+     from pyspark.sql.types import LongType
      from replay.utils.session_handler import get_spark_session
 
  HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]
@@ -185,11 +185,11 @@ class LabelEncodingRule(BaseLabelEncodingRule):
          self._mapping = mapping_on_spark.rdd.collectAsMap()
 
      def _fit_pandas(self, df: PandasDataFrame) -> None:
-         unique_col_values = df[self._col].drop_duplicates().reset_index(drop=True)
+         unique_col_values = df[self._col].sort_values().drop_duplicates().reset_index(drop=True)
          self._mapping = {val: key for key, val in unique_col_values.to_dict().items()}
 
      def _fit_polars(self, df: PolarsDataFrame) -> None:
-         unique_col_values = df.select(self._col).unique()
+         unique_col_values = df.sort(self._col).select(self._col).unique()
          self._mapping = {key: val for val, key in enumerate(unique_col_values.to_series().to_list())}
 
      def fit(self, df: DataFrameLike) -> "LabelEncodingRule":
@@ -630,37 +630,40 @@ class SequenceEncodingRule(LabelEncodingRule):
          return self
 
      def _transform_spark(self, df: SparkDataFrame, default_value: Optional[int]) -> SparkDataFrame:
-         def mapper_udf(x):
-             return [mapping.get(value) for value in x]  # pragma: no cover
+         other_columns = [col for col in df.columns if col != self._col]
 
-         mapping = self.get_mapping()
-         call_mapper_udf = sf.udf(mapper_udf, ArrayType(IntegerType()))
-         encoded_df = df.withColumn(self._target_col, call_mapper_udf(sf.col(self.column)))
+         mapping_on_spark = get_spark_session().createDataFrame(
+             data=list(self.get_mapping().items()), schema=[self._col, self._target_col]
+         )
+         encoded_df = (
+             df.select(*other_columns, sf.posexplode(self._col))
+             .withColumnRenamed("col", self._col)
+             .join(mapping_on_spark, on=self._col, how="left")
+         )
+
+         if self._handle_unknown == "error":
+             if encoded_df.filter(sf.col(self._target_col).isNull()).count() > 0:
+                 msg = f"Found unknown labels in column {self._col} during transform"
+                 raise ValueError(msg)
+         else:
+             if default_value is not None:
+                 encoded_df = encoded_df.fillna(default_value, subset=[self._target_col])
 
+         result = encoded_df.groupBy(other_columns).agg(
+             sf.sort_array(sf.collect_list(sf.struct("pos", self._target_col)))
+             .getItem(self._target_col)
+             .alias(self._col)
+         )
          if self._handle_unknown == "drop":
-             encoded_df = encoded_df.withColumn(self._target_col, sf.filter(self._target_col, lambda x: x.isNotNull()))
-             if encoded_df.select(sf.max(sf.size(self._target_col))).first()[0] == 0:
+             result = result.withColumn(self._col, sf.filter(self._col, lambda x: x.isNotNull()))
+             if result.select(sf.max(sf.size(self._col))).first()[0] == 0:
                  warnings.warn(
                      f"You are trying to transform dataframe with all values are unknown for {self._col}, "
                      "with `handle_unknown_strategy=drop` leads to empty dataframe",
                      LabelEncoderTransformWarning,
                  )
-         elif self._handle_unknown == "error":
-             if (
-                 encoded_df.select(sf.sum(sf.array_contains(self._target_col, -1).isNull().cast("integer"))).first()[0]
-                 != 0
-             ):
-                 msg = f"Found unknown labels in column {self._col} during transform"
-                 raise ValueError(msg)
-         else:
-             if default_value:
-                 encoded_df = encoded_df.withColumn(
-                     self._target_col,
-                     sf.transform(self._target_col, lambda x: sf.when(x.isNull(), default_value).otherwise(x)),
-                 )
 
-         result_df = encoded_df.drop(self._col).withColumnRenamed(self._target_col, self._col)
-         return result_df
+         return result
 
      def _transform_pandas(self, df: PandasDataFrame, default_value: Optional[int]) -> PandasDataFrame:
          mapping = self.get_mapping()
@@ -771,7 +774,7 @@ class SequenceEncodingRule(LabelEncodingRule):
      def _inverse_transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
          array_expr = sf.array([sf.lit(x) for x in self._inverse_mapping_list])
          decoded_df = df.withColumn(
-             self._target_col, sf.transform(self._col, lambda x: sf.element_at(array_expr, x + 1))
+             self._target_col, sf.transform(self._col, lambda x: sf.element_at(array_expr, x.cast("int") + 1))
          )
          return decoded_df.drop(self._col).withColumnRenamed(self._target_col, self._col)
 
@@ -800,19 +803,19 @@ class LabelEncoder:
      >>> mapped_interactions = encoder.fit_transform(user_interactions)
      >>> mapped_interactions
         user_id  item_1  item_2       list
-     0        0       0       0  [0, 1, 2]
-     1        1       1       1  [2, 3, 4]
-     2        2       2       2  [5, 6, 3]
+     0        0       0       0  [2, 3, 4]
+     1        1       1       1  [4, 5, 6]
+     2        2       2       2  [1, 0, 5]
      >>> encoder.mapping
      {'user_id': {'u1': 0, 'u2': 1, 'u3': 2},
       'item_1': {'item_1': 0, 'item_2': 1, 'item_3': 2},
       'item_2': {'item_1': 0, 'item_2': 1, 'item_3': 2},
-      'list': {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, -1: 5, -2: 6}}
+      'list': {-2: 0, -1: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6}}
      >>> encoder.inverse_mapping
      {'user_id': {0: 'u1', 1: 'u2', 2: 'u3'},
       'item_1': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
       'item_2': {0: 'item_1', 1: 'item_2', 2: 'item_3'},
-      'list': {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: -1, 6: -2}}
+      'list': {0: -2, 1: -1, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5}}
      >>> new_encoder = LabelEncoder([
      ...     LabelEncodingRule("user_id", encoder.mapping["user_id"]),
      ...     LabelEncodingRule("item_1", encoder.mapping["item_1"]),
@@ -834,14 +837,14 @@ class LabelEncoder:
          self.rules = rules
 
      @property
-     def mapping(self) -> Mapping[str, Mapping]:
+     def mapping(self) -> dict[str, Mapping]:
          """
          Returns mapping of each column in given rules.
          """
          return {r.column: r.get_mapping() for r in self.rules}
 
      @property
-     def inverse_mapping(self) -> Mapping[str, Mapping]:
+     def inverse_mapping(self) -> dict[str, Mapping]:
          """
          Returns inverse mapping of each column in given rules.
          """
replay/preprocessing/utils.py ADDED
@@ -0,0 +1,209 @@
+ import logging
+ from typing import (
+     Literal,
+     Optional,
+     Sequence,
+ )
+
+ import pandas as pd
+ import polars as pl
+
+ from replay.utils import (
+     PYSPARK_AVAILABLE,
+     DataFrameLike,
+     PandasDataFrame,
+     PolarsDataFrame,
+     SparkDataFrame,
+ )
+
+ if PYSPARK_AVAILABLE:
+     import pyspark.sql.functions as sf
+
+
+ def _ensure_columns_match(df, ref_cols, index: int, check_columns: bool) -> None:
+     if check_columns and set(df.columns) != set(ref_cols):
+         msg = f"Columns mismatch in dataframe #{index}: {sorted(df.columns)} != {sorted(ref_cols)}"
+         raise ValueError(msg)
+
+
+ def _merge_subsets_pandas(
+     dfs: Sequence[PandasDataFrame],
+     columns: Optional[Sequence[str]],
+     check_columns: bool,
+     subset_for_duplicates: Optional[Sequence[str]],
+     on_duplicate: Literal["error", "drop", "ignore"],
+ ) -> PandasDataFrame:
+     ref_cols = list(dfs[0].columns) if columns is None else list(columns)
+
+     aligned: list[PandasDataFrame] = []
+     for i, df in enumerate(dfs):
+         _ensure_columns_match(df, ref_cols, i, check_columns)
+         aligned.append(df[ref_cols])
+
+     merged = pd.concat(aligned, axis=0, ignore_index=True)
+
+     if on_duplicate == "ignore":
+         return merged
+
+     dup_subset = ref_cols if subset_for_duplicates is None else list(subset_for_duplicates)
+     dup_mask = merged.duplicated(subset=dup_subset, keep="first")
+     dup_count = int(dup_mask.sum())
+
+     if dup_count > 0:
+         if on_duplicate == "error":
+             msg = f"Found {dup_count} duplicate rows on subset {dup_subset}"
+             raise ValueError(msg)
+         if on_duplicate == "drop":
+             merged = merged.drop_duplicates(subset=dup_subset, keep="first").reset_index(drop=True)
+             logging.getLogger("replay").warning(
+                 f"Found {dup_count} duplicate rows on subset {dup_subset} and dropped them"
+             )
+
+     return merged
+
+
+ def _merge_subsets_polars(
+     dfs: Sequence[PolarsDataFrame],
+     columns: Optional[Sequence[str]],
+     check_columns: bool,
+     subset_for_duplicates: Optional[Sequence[str]],
+     on_duplicate: Literal["error", "drop", "ignore"],
+ ) -> PolarsDataFrame:
+     ref_cols = list(dfs[0].columns) if columns is None else list(columns)
+
+     aligned: list[PolarsDataFrame] = []
+     for i, df in enumerate(dfs):
+         _ensure_columns_match(df, ref_cols, i, check_columns)
+         aligned.append(df.select(ref_cols))
+
+     merged = pl.concat(aligned, how="vertical")
+
+     if on_duplicate == "ignore":
+         return merged
+
+     dup_subset = ref_cols if subset_for_duplicates is None else list(subset_for_duplicates)
+     dup_mask = merged.select(dup_subset).is_duplicated()
+     dup_count = int(dup_mask.sum())
+
+     if dup_count > 0:
+         if on_duplicate == "error":
+             msg = f"Found {dup_count} duplicate rows on subset {dup_subset}"
+             raise ValueError(msg)
+         if on_duplicate == "drop":
+             merged = merged.unique(subset=dup_subset, keep="first", maintain_order=True)
+             logging.getLogger("replay").warning(
+                 f"Found {dup_count} duplicate rows on subset {dup_subset} and dropped them"
+             )
+
+     return merged
+
+
+ def _merge_subsets_spark(
+     dfs: Sequence[SparkDataFrame],
+     columns: Optional[Sequence[str]],
+     check_columns: bool,
+     subset_for_duplicates: Optional[Sequence[str]],
+     on_duplicate: Literal["error", "drop", "ignore"],
+ ) -> SparkDataFrame:
+     ref_cols = list(dfs[0].columns) if columns is None else list(columns)
+
+     merged = None
+     for i, df in enumerate(dfs):
+         _ensure_columns_match(df, ref_cols, i, check_columns)
+         part = df.select(*ref_cols)
+         merged = part if merged is None else merged.unionByName(part)
+
+     if on_duplicate == "ignore":
+         return merged
+
+     dup_subset = ref_cols if subset_for_duplicates is None else list(subset_for_duplicates)
+     if on_duplicate == "error" and merged.groupBy(*dup_subset).count().filter(sf.col("count") > 1).limit(1).count() > 0:
+         msg = f"Found duplicate rows on subset {dup_subset}"
+         raise ValueError(msg)
+     if on_duplicate == "drop":
+         unique = merged.dropDuplicates(dup_subset)
+         logging.getLogger("replay").warning(
+             f"Found {merged.count() - unique.count()} duplicate rows on subset {dup_subset} and dropped them"
+         )
+         merged = unique
+
+     return merged
+
+
+ def merge_subsets(
+     dfs: Sequence[DataFrameLike],
+     columns: Optional[Sequence[str]] = None,
+     check_columns: bool = True,
+     subset_for_duplicates: Optional[Sequence[str]] = None,
+     on_duplicate: Literal["error", "drop", "ignore"] = "error",
+ ) -> DataFrameLike:
+     """Merge multiple dataframes of the same backend into a single one.
+
+     All inputs must be of the same dataframe type (pandas/Polars/Spark). Before
+     concatenation, each dataframe is aligned to a common set of columns: either
+     the provided ``columns`` or the columns of the first dataframe. Duplicate
+     rows are handled according to ``on_duplicate``.
+
+     Parameters
+     ----------
+     dfs : Sequence[DataFrameLike]
+         Dataframes to merge.
+     columns : Optional[Sequence[str]]
+         Columns to align to. If ``None``, columns of the first dataframe are used.
+     check_columns : bool
+         Whether to validate that all inputs have the same column set.
+     subset_for_duplicates : Optional[Sequence[str]]
+         Columns subset used to detect duplicates. If ``None``, all aligned columns
+         are used.
+     on_duplicate : {"error", "drop", "ignore"}
+         How to handle duplicates: raise an error, drop them, or ignore.
+
+     Returns
+     -------
+     DataFrameLike
+         Merged dataframe of the same backend as the inputs.
+
+     Raises
+     ------
+     ValueError
+         If ``dfs`` is empty, if duplicates are found with ``on_duplicate='error'``,
+         or if column sets differ when validation is enabled.
+     TypeError
+         If inputs are of different dataframe types.
+     """
+     if not dfs:
+         msg = "At least one dataframe is required"
+         raise ValueError(msg)
+
+     first = dfs[0]
+     if any(not isinstance(df, type(first)) for df in dfs):
+         msg = "All input dataframes must be of the same type"
+         raise TypeError(msg)
+
+     if isinstance(first, PandasDataFrame):
+         return _merge_subsets_pandas(
+             dfs,
+             columns=columns,
+             check_columns=check_columns,
+             subset_for_duplicates=subset_for_duplicates,
+             on_duplicate=on_duplicate,
+         )
+     if isinstance(first, PolarsDataFrame):
+         return _merge_subsets_polars(
+             dfs,
+             columns=columns,
+             check_columns=check_columns,
+             subset_for_duplicates=subset_for_duplicates,
+             on_duplicate=on_duplicate,
+         )
+     if isinstance(first, SparkDataFrame):
+         return _merge_subsets_spark(
+             dfs,
+             columns=columns,
+             check_columns=check_columns,
+             subset_for_duplicates=subset_for_duplicates,
+             on_duplicate=on_duplicate,
+         )
+
+     msg = f"Unsupported data frame type: {type(first)}"
+     raise NotImplementedError(msg)
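A brief pandas sketch of merge_subsets (imported from the new module path above; duplicate handling shown for the "drop" mode):

    import pandas as pd

    from replay.preprocessing.utils import merge_subsets

    a = pd.DataFrame({"query_id": [1, 2], "item_id": [10, 20]})
    b = pd.DataFrame({"item_id": [20, 30], "query_id": [2, 3]})  # same columns, different order

    # Columns are aligned to the first dataframe; the duplicate (2, 20) row is dropped
    # and a warning is logged instead of raising, because on_duplicate="drop".
    merged = merge_subsets([a, b], on_duplicate="drop")
    print(merged)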
replay/splitters/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .cold_user_random_splitter import ColdUserRandomSplitter
  from .k_folds import KFolds
  from .last_n_splitter import LastNSplitter
  from .new_users_splitter import NewUsersSplitter
+ from .random_next_n_splitter import RandomNextNSplitter
  from .random_splitter import RandomSplitter
  from .ratio_splitter import RatioSplitter
  from .time_splitter import TimeSplitter