replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/splitters/ratio_splitter.py CHANGED
@@ -1,16 +1,17 @@
- import polars as pl
  from typing import List, Optional, Tuple
 
- from .base_splitter import Splitter
+ import polars as pl
+
  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+ from .base_splitter import Splitter
+
  if PYSPARK_AVAILABLE:
  import pyspark.sql.functions as sf
  from pyspark.sql import Window
  from pyspark.sql.types import IntegerType
 
 
- # pylint: disable=too-few-public-methods, too-many-instance-attributes
  class RatioSplitter(Splitter):
  """
  Split interactions into train and test by ratio. Split is made for each user separately.
@@ -82,6 +83,7 @@ class RatioSplitter(Splitter):
  14 3 2 2020-01-05
  <BLANKLINE>
  """
+
  _init_arg_names = [
  "test_size",
  "divide_column",
@@ -96,7 +98,6 @@ class RatioSplitter(Splitter):
  "session_id_processing_strategy",
  ]
 
- # pylint: disable=too-many-arguments
  def __init__(
  self,
  test_size: float,
@@ -160,7 +161,8 @@ class RatioSplitter(Splitter):
  self.min_interactions_per_group = min_interactions_per_group
  self.split_by_fractions = split_by_fractions
  if test_size < 0 or test_size > 1:
- raise ValueError("test_size must between 0 and 1")
+ msg = "test_size must between 0 and 1"
+ raise ValueError(msg)
  self.test_size = test_size
 
  def _add_time_partition(self, interactions: DataFrameLike) -> DataFrameLike:
@@ -171,7 +173,8 @@ class RatioSplitter(Splitter):
  if isinstance(interactions, PolarsDataFrame):
  return self._add_time_partition_to_polars(interactions)
 
- raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+ msg = f"{self} is not implemented for {type(interactions)}"
+ raise NotImplementedError(msg)
 
  def _add_time_partition_to_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
  res = interactions.copy(deep=True)
@@ -189,14 +192,8 @@ class RatioSplitter(Splitter):
  return res
 
  def _add_time_partition_to_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
- res = (
- interactions
- .sort(self.timestamp_column)
- .with_columns(
- pl.cum_count(self.divide_column)
- .over(self.divide_column)
- .alias("row_num")
- )
+ res = interactions.sort(self.timestamp_column).with_columns(
+ pl.cum_count(self.divide_column).over(self.divide_column).alias("row_num")
  )
 
  return res
@@ -262,8 +259,7 @@ class RatioSplitter(Splitter):
  self, interactions: PolarsDataFrame, train_size: float
  ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
  interactions = interactions.with_columns(
- pl.count(self.timestamp_column).over(pl.col(self.divide_column))
- .alias("count")
+ pl.count(self.timestamp_column).over(pl.col(self.divide_column)).alias("count")
  )
  if self.min_interactions_per_group is not None:
  interactions = interactions.with_columns(
@@ -274,18 +270,14 @@ class RatioSplitter(Splitter):
  )
  else:
  interactions = interactions.with_columns(
- (pl.col("row_num") / pl.col("count")).round(self._precision)
- .alias("frac")
+ (pl.col("row_num") / pl.col("count")).round(self._precision).alias("frac")
  )
 
- interactions = interactions.with_columns(
- (pl.col("frac") > train_size)
- .alias("is_test")
- )
+ interactions = interactions.with_columns((pl.col("frac") > train_size).alias("is_test"))
  if self.session_id_column:
  interactions = self._recalculate_with_session_id_column(interactions)
 
- train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "frac", "is_test") # pylint: disable=invalid-unary-operand-type
+ train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "frac", "is_test")
  test = interactions.filter(pl.col("is_test")).drop("row_num", "count", "frac", "is_test")
 
  return train, test
@@ -316,7 +308,7 @@ class RatioSplitter(Splitter):
  "train_size",
  ] = (
  interactions["train_size"] - 1
- ) # pylint: disable=C0325
+ )
 
  interactions["is_test"] = interactions["row_num"] > interactions["train_size"]
  if self.session_id_column:
@@ -327,9 +319,7 @@ class RatioSplitter(Splitter):
 
  return train, test
 
- def _partial_split_spark(
- self, interactions: SparkDataFrame, ratio: float
- ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+ def _partial_split_spark(self, interactions: SparkDataFrame, ratio: float) -> Tuple[SparkDataFrame, SparkDataFrame]:
  interactions = interactions.withColumn(
  "count", sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column))
  )
@@ -364,51 +354,37 @@ class RatioSplitter(Splitter):
  self, interactions: PolarsDataFrame, ratio: float
  ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
  interactions = interactions.with_columns(
- pl.count(self.timestamp_column).over(self.divide_column)
- .alias("count")
+ pl.count(self.timestamp_column).over(self.divide_column).alias("count")
  )
  if self.min_interactions_per_group is not None:
  interactions = interactions.with_columns(
- pl.when(
- pl.col("count") >= self.min_interactions_per_group
- )
- .then(
- pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype)
- )
+ pl.when(pl.col("count") >= self.min_interactions_per_group)
+ .then(pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype))
  .otherwise(pl.col("count"))
  .alias("train_size")
  )
  else:
- interactions = (
- interactions
- .with_columns(
- (pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype))
- .alias("train_size")
- )
- .with_columns(
- pl.when(
- (pl.col("count") * ratio > 0) & (pl.col("count") * ratio < 1) & (pl.col("train_size") > 1)
- )
- .then(pl.col("train_size") - 1)
- .otherwise(pl.col("train_size"))
- .alias("train_size")
+ interactions = interactions.with_columns(
+ (pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype)).alias(
+ "train_size"
  )
+ ).with_columns(
+ pl.when((pl.col("count") * ratio > 0) & (pl.col("count") * ratio < 1) & (pl.col("train_size") > 1))
+ .then(pl.col("train_size") - 1)
+ .otherwise(pl.col("train_size"))
+ .alias("train_size")
  )
 
- interactions = interactions.with_columns(
- (pl.col("row_num") > pl.col("train_size"))
- .alias("is_test")
- )
+ interactions = interactions.with_columns((pl.col("row_num") > pl.col("train_size")).alias("is_test"))
 
  if self.session_id_column:
  interactions = self._recalculate_with_session_id_column(interactions)
 
- train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "train_size", "is_test") # pylint: disable=invalid-unary-operand-type
+ train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "train_size", "is_test")
  test = interactions.filter(pl.col("is_test")).drop("row_num", "count", "train_size", "is_test")
 
  return train, test
 
- # pylint: disable=invalid-name
  def _core_split(self, interactions: DataFrameLike) -> List[DataFrameLike]:
  if self.split_by_fractions:
  return self._partial_split_fractions(interactions, self.test_size)
replay/splitters/time_splitter.py CHANGED
@@ -3,21 +3,21 @@ from typing import List, Optional, Tuple, Union
 
  import polars as pl
 
- from .base_splitter import Splitter
  from replay.utils import (
  PYSPARK_AVAILABLE,
  DataFrameLike,
  PandasDataFrame,
- SparkDataFrame,
  PolarsDataFrame,
+ SparkDataFrame,
  )
 
+ from .base_splitter import Splitter
+
  if PYSPARK_AVAILABLE:
  import pyspark.sql.functions as sf
  from pyspark.sql import Window
 
 
- # pylint: disable=too-few-public-methods
  class TimeSplitter(Splitter):
  """
  Split interactions by time.
@@ -85,6 +85,7 @@ class TimeSplitter(Splitter):
  14 3 2 2020-01-05
  <BLANKLINE>
  """
+
  _init_arg_names = [
  "time_threshold",
  "drop_cold_users",
@@ -97,10 +98,9 @@ class TimeSplitter(Splitter):
  "time_column_format",
  ]
 
- # pylint: disable=too-many-arguments
  def __init__(
  self,
- time_threshold: Union[datetime, str, int, float],
+ time_threshold: Union[datetime, str, float],
  query_column: str = "query_id",
  drop_cold_users: bool = False,
  drop_cold_items: bool = False,
@@ -144,7 +144,8 @@ class TimeSplitter(Splitter):
  self._precision = 3
  self.time_column_format = time_column_format
  if isinstance(time_threshold, float) and (time_threshold < 0 or time_threshold > 1):
- raise ValueError("time_threshold must be between 0 and 1")
+ msg = "time_threshold must be between 0 and 1"
+ raise ValueError(msg)
  self.time_threshold = time_threshold
 
  def _partial_split(
@@ -160,7 +161,8 @@ class TimeSplitter(Splitter):
  if isinstance(interactions, PolarsDataFrame):
  return self._partial_split_polars(interactions, threshold)
 
- raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+ msg = f"{self} is not implemented for {type(interactions)}"
+ raise NotImplementedError(msg)
 
  def _partial_split_pandas(
  self, interactions: PandasDataFrame, threshold: Union[datetime, str, int]
@@ -191,9 +193,7 @@ class TimeSplitter(Splitter):
  )
  test_start = int(dates.count() * (1 - threshold)) + 1
  test_start = (
- dates.filter(sf.col("_row_number_by_ts") == test_start)
- .select(self.timestamp_column)
- .collect()[0][0]
+ dates.filter(sf.col("_row_number_by_ts") == test_start).select(self.timestamp_column).collect()[0][0]
  )
  res = interactions.withColumn("is_test", sf.col(self.timestamp_column) >= test_start)
  else:
@@ -212,20 +212,15 @@ class TimeSplitter(Splitter):
  if isinstance(threshold, float):
  test_start = int(len(interactions) * (1 - threshold)) + 1
 
- res = (
- interactions
- .sort(self.timestamp_column)
- .with_columns(
- (pl.col(self.timestamp_column).cum_count() >= test_start)
- .alias("is_test")
- )
+ res = interactions.sort(self.timestamp_column).with_columns(
+ (pl.col(self.timestamp_column).cum_count() >= test_start).alias("is_test")
  )
  else:
  res = interactions.with_columns((pl.col(self.timestamp_column) >= threshold).alias("is_test"))
 
  if self.session_id_column:
  res = self._recalculate_with_session_id_column(res)
- train = res.filter(~pl.col("is_test")).drop("is_test") # pylint: disable=invalid-unary-operand-type
+ train = res.filter(~pl.col("is_test")).drop("is_test")
  test = res.filter("is_test").drop("is_test")
 
  return train, test
replay/splitters/two_stage_splitter.py CHANGED
@@ -1,18 +1,19 @@
  """
  This splitter split data by two columns.
  """
- from typing import Optional, Union
+ from typing import Optional, Tuple
+
  import polars as pl
 
- from .base_splitter import Splitter, SplitterReturnType
  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+ from .base_splitter import Splitter, SplitterReturnType
+
  if PYSPARK_AVAILABLE:
  import pyspark.sql.functions as sf
  from pyspark.sql import Window
 
 
- # pylint: disable=too-few-public-methods
  class TwoStageSplitter(Splitter):
  """
  Split data by two columns.
@@ -73,11 +74,10 @@ class TwoStageSplitter(Splitter):
  "timestamp_column",
  ]
 
- # pylint: disable=too-many-arguments
  def __init__(
  self,
- first_divide_size: Union[float, int],
- second_divide_size: Union[float, int],
+ first_divide_size: float,
+ second_divide_size: float,
  first_divide_column: str = "query_id",
  second_divide_column: str = "item_id",
  shuffle=False,
@@ -147,17 +147,12 @@ class TwoStageSplitter(Splitter):
  else:
  value_error = True
  if value_error:
- raise ValueError(
- f"""
- Invalid value for user_test_size: {self.first_divide_size}
- """
- )
+ msg = f"Invalid value for user_test_size: {self.first_divide_size}"
+ raise ValueError(msg)
  if isinstance(interactions, SparkDataFrame):
  test_users = (
  all_values.withColumn("_rand", sf.rand(self.seed))
- .withColumn(
- "_row_num", sf.row_number().over(Window.orderBy("_rand"))
- )
+ .withColumn("_row_num", sf.row_number().over(Window.orderBy("_rand")))
  .filter(f"_row_num <= {test_user_count}")
  .drop("_rand", "_row_num")
  )
@@ -168,11 +163,9 @@ class TwoStageSplitter(Splitter):
 
  return test_users
 
- def _split_proportion_spark(self, interactions: SparkDataFrame) -> Union[SparkDataFrame, SparkDataFrame]:
+ def _split_proportion_spark(self, interactions: SparkDataFrame) -> Tuple[SparkDataFrame, SparkDataFrame]:
  counts = interactions.groupBy(self.first_divide_column).count()
- test_users = self._get_test_values(interactions).withColumn(
- "is_test", sf.lit(True)
- )
+ test_users = self._get_test_values(interactions).withColumn("is_test", sf.lit(True))
  if self.shuffle:
  res = self._add_random_partition_spark(
  interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -202,10 +195,10 @@ class TwoStageSplitter(Splitter):
 
  return train, test
 
- def _split_proportion_pandas(self, interactions: PandasDataFrame) -> Union[PandasDataFrame, PandasDataFrame]:
- counts = interactions.groupby(self.first_divide_column).agg(
- count=(self.first_divide_column, "count")
- ).reset_index()
+ def _split_proportion_pandas(self, interactions: PandasDataFrame) -> Tuple[PandasDataFrame, PandasDataFrame]:
+ counts = (
+ interactions.groupby(self.first_divide_column).agg(count=(self.first_divide_column, "count")).reset_index()
+ )
  test_users = self._get_test_values(interactions)
  test_users["is_test"] = True
  if self.shuffle:
@@ -229,11 +222,9 @@ class TwoStageSplitter(Splitter):
 
  return train, test
 
- def _split_proportion_polars(self, interactions: PolarsDataFrame) -> Union[PolarsDataFrame, PolarsDataFrame]:
+ def _split_proportion_polars(self, interactions: PolarsDataFrame) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
  counts = interactions.group_by(self.first_divide_column).count()
- test_users = self._get_test_values(interactions).with_columns(
- pl.lit(True).alias("is_test")
- )
+ test_users = self._get_test_values(interactions).with_columns(pl.lit(True).alias("is_test"))
  if self.shuffle:
  res = self._add_random_partition_polars(
  interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -245,18 +236,15 @@ class TwoStageSplitter(Splitter):
  )
 
  res = res.join(counts, on=self.first_divide_column, how="left")
- res = res.with_columns(
- (pl.col("_row_num") / pl.col("count"))
- .alias("_frac")
- )
+ res = res.with_columns((pl.col("_row_num") / pl.col("count")).alias("_frac"))
  res = res.fill_null(False)
 
- train = res.filter(
- (pl.col("_frac") > self.second_divide_size) | (~pl.col("is_test")) # pylint: disable=invalid-unary-operand-type
- ).drop("_rand", "_row_num", "count", "_frac", "is_test")
- test = res.filter(
- (pl.col("_frac") <= self.second_divide_size) & pl.col("is_test")
- ).drop("_rand", "_row_num", "count", "_frac", "is_test")
+ train = res.filter((pl.col("_frac") > self.second_divide_size) | (~pl.col("is_test"))).drop(
+ "_rand", "_row_num", "count", "_frac", "is_test"
+ )
+ test = res.filter((pl.col("_frac") <= self.second_divide_size) & pl.col("is_test")).drop(
+ "_rand", "_row_num", "count", "_frac", "is_test"
+ )
 
  return train, test
 
@@ -274,12 +262,11 @@ class TwoStageSplitter(Splitter):
  if isinstance(interactions, PolarsDataFrame):
  return self._split_proportion_polars(interactions)
 
- raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+ msg = f"{self} is not implemented for {type(interactions)}"
+ raise NotImplementedError(msg)
 
  def _split_quantity_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
- test_users = self._get_test_values(interactions).withColumn(
- "is_test", sf.lit(True)
- )
+ test_users = self._get_test_values(interactions).withColumn("is_test", sf.lit(True))
  if self.shuffle:
  res = self._add_random_partition_spark(
  interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -328,9 +315,7 @@ class TwoStageSplitter(Splitter):
  return train, test
 
  def _split_quantity_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
- test_users = self._get_test_values(interactions).with_columns(
- pl.lit(True).alias("is_test")
- )
+ test_users = self._get_test_values(interactions).with_columns(pl.lit(True).alias("is_test"))
  if self.shuffle:
  res = self._add_random_partition_polars(
  interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -342,12 +327,12 @@ class TwoStageSplitter(Splitter):
  )
 
  res = res.fill_null(False)
- train = res.filter(
- (pl.col("_row_num") > self.second_divide_size) | (~pl.col("is_test")) # pylint: disable=invalid-unary-operand-type
- ).drop("_row_num", "is_test")
- test = res.filter(
- (pl.col("_row_num") <= self.second_divide_size) & pl.col("is_test")
- ).drop("_row_num", "is_test")
+ train = res.filter((pl.col("_row_num") > self.second_divide_size) | (~pl.col("is_test"))).drop(
+ "_row_num", "is_test"
+ )
+ test = res.filter((pl.col("_row_num") <= self.second_divide_size) & pl.col("is_test")).drop(
+ "_row_num", "is_test"
+ )
 
  return train, test
 
@@ -365,7 +350,8 @@ class TwoStageSplitter(Splitter):
  if isinstance(interactions, PolarsDataFrame):
  return self._split_quantity_polars(interactions)
 
- raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+ msg = f"{self} is not implemented for {type(interactions)}"
+ raise NotImplementedError(msg)
 
  def _core_split(self, interactions: DataFrameLike) -> SplitterReturnType:
  if 0 <= self.second_divide_size < 1.0:
@@ -373,11 +359,8 @@ class TwoStageSplitter(Splitter):
  elif self.second_divide_size >= 1 and isinstance(self.second_divide_size, int):
  train, test = self._split_quantity(interactions)
  else:
- raise ValueError(
- "`test_size` value must be [0, 1) or "
- "a positive integer; "
- f"test_size={self.second_divide_size}"
- )
+ msg = f"`test_size` value must be [0, 1) or a positive integer; test_size={self.second_divide_size}"
+ raise ValueError(msg)
 
  return train, test
 
@@ -391,9 +374,7 @@ class TwoStageSplitter(Splitter):
  dataframe = dataframe.withColumn("_rand", sf.rand(self.seed))
  dataframe = dataframe.withColumn(
  "_row_num",
- sf.row_number().over(
- Window.partitionBy(self.first_divide_column).orderBy("_rand")
- ),
+ sf.row_number().over(Window.partitionBy(self.first_divide_column).orderBy("_rand")),
  )
  return dataframe
 
@@ -404,14 +385,8 @@ class TwoStageSplitter(Splitter):
  return res
 
  def _add_random_partition_polars(self, dataframe: PolarsDataFrame) -> PolarsDataFrame:
- res = (
- dataframe
- .sample(fraction=1, shuffle=True, seed=self.seed)
- .with_columns(
- pl.cum_count(self.first_divide_column)
- .over(self.first_divide_column)
- .alias("_row_num")
- )
+ res = dataframe.sample(fraction=1, shuffle=True, seed=self.seed).with_columns(
+ pl.cum_count(self.first_divide_column).over(self.first_divide_column).alias("_row_num")
  )
  return res
 
@@ -431,11 +406,7 @@ class TwoStageSplitter(Splitter):
  """
  res = dataframe.withColumn(
  "_row_num",
- sf.row_number().over(
- Window.partitionBy(query_column).orderBy(
- sf.col(date_column).desc()
- )
- ),
+ sf.row_number().over(Window.partitionBy(query_column).orderBy(sf.col(date_column).desc())),
  )
  return res
 
@@ -456,13 +427,7 @@ class TwoStageSplitter(Splitter):
  query_column: str = "query_id",
  date_column: str = "timestamp",
  ) -> PolarsDataFrame:
- res = (
- dataframe
- .sort(date_column, descending=True)
- .with_columns(
- pl.cum_count(query_column)
- .over(query_column)
- .alias("_row_num")
- )
+ res = dataframe.sort(date_column, descending=True).with_columns(
+ pl.cum_count(query_column).over(query_column).alias("_row_num")
  )
  return res
replay/utils/__init__.py CHANGED
@@ -7,6 +7,6 @@ from .types import (
  MissingImportType,
  NumType,
  PandasDataFrame,
- SparkDataFrame,
  PolarsDataFrame,
+ SparkDataFrame,
  )
replay/utils/common.py ADDED
@@ -0,0 +1,65 @@
+ import json
+ from pathlib import Path
+ from typing import Union
+
+ from replay.splitters import (
+ ColdUserRandomSplitter,
+ KFolds,
+ LastNSplitter,
+ NewUsersSplitter,
+ RandomSplitter,
+ RatioSplitter,
+ TimeSplitter,
+ TwoStageSplitter,
+ )
+ from replay.utils import TORCH_AVAILABLE
+
+ SavableObject = Union[
+ ColdUserRandomSplitter,
+ KFolds,
+ LastNSplitter,
+ NewUsersSplitter,
+ RandomSplitter,
+ RatioSplitter,
+ TimeSplitter,
+ TwoStageSplitter,
+ ]
+
+ if TORCH_AVAILABLE:
+ from replay.data.nn import SequenceTokenizer
+
+ SavableObject = Union[
+ ColdUserRandomSplitter,
+ KFolds,
+ LastNSplitter,
+ NewUsersSplitter,
+ RandomSplitter,
+ RatioSplitter,
+ TimeSplitter,
+ TwoStageSplitter,
+ SequenceTokenizer,
+ ]
+
+
+ def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
+ """
+ General function to save RePlay models, splitters and tokenizer.
+
+ :param path: Path to save the object.
+ """
+ obj.save(path)
+
+
+ def load_from_replay(path: Union[str, Path]) -> SavableObject:
+ """
+ General function to load RePlay models, splitters and tokenizer.
+
+ :param path: Path to save the object.
+ """
+ path = Path(path).with_suffix(".replay").resolve()
+ with open(path / "init_args.json", "r") as file:
+ class_name = json.loads(file.read())["_class_name"]
+ obj_type = globals()[class_name]
+ obj = obj_type.load(path)
+
+ return obj
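
The new `replay/utils/common.py` module gives a single entry point for persisting splitters (and, when torch is available, the `SequenceTokenizer`). Below is a minimal usage sketch, not part of the diff: the `./artifacts/time_splitter` path and the chosen splitter are illustrative, and the exact on-disk layout is produced by each object's own `save`/`load`, which are outside this file.

```python
from replay.splitters import TimeSplitter
from replay.utils.common import load_from_replay, save_to_replay

# Any class listed in the SavableObject union works the same way;
# TimeSplitter is used here only as an example.
splitter = TimeSplitter(time_threshold=0.2, query_column="query_id")

# save_to_replay simply delegates to splitter.save(path).
save_to_replay(splitter, "./artifacts/time_splitter")

# load_from_replay resolves the ".replay" suffix, reads "_class_name"
# from init_args.json, looks the class up among the imports in common.py,
# and dispatches to that class's load().
restored = load_from_replay("./artifacts/time_splitter")
assert isinstance(restored, TimeSplitter)
```

Because `load_from_replay` resolves the stored class name through `globals()`, only the classes imported inside `common.py` (the `SavableObject` union) can be restored this way.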