replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/splitters/cold_user_random_splitter.py
@@ -1,7 +1,7 @@
- from typing import Optional, Union
+ from typing import Optional, Tuple
+
  import polars as pl

- from .base_splitter import Splitter, SplitterReturnType
  from replay.utils import (
      PYSPARK_AVAILABLE,
      DataFrameLike,
@@ -10,11 +10,12 @@ from replay.utils import (
      SparkDataFrame,
  )

+ from .base_splitter import Splitter, SplitterReturnType
+
  if PYSPARK_AVAILABLE:
      import pyspark.sql.functions as sf


- # pylint: disable=too-few-public-methods, duplicate-code
  class ColdUserRandomSplitter(Splitter):
      """
      Test set consists of all actions of randomly chosen users.
@@ -28,7 +29,6 @@ class ColdUserRandomSplitter(Splitter):
          "item_column",
      ]

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          test_size: float,
@@ -52,14 +52,13 @@
          )
          self.seed = seed
          if test_size <= 0 or test_size >= 1:
-             raise ValueError("test_size must between 0 and 1")
+             msg = "test_size must between 0 and 1"
+             raise ValueError(msg)
          self.test_size = test_size

      def _core_split_pandas(
-         self,
-         interactions: PandasDataFrame,
-         threshold: float
-     ) -> Union[PandasDataFrame, PandasDataFrame]:
+         self, interactions: PandasDataFrame, threshold: float
+     ) -> Tuple[PandasDataFrame, PandasDataFrame]:
          users = PandasDataFrame(interactions[self.query_column].unique(), columns=[self.query_column])
          train_users = users.sample(frac=(1 - threshold), random_state=self.seed)
          train_users["is_test"] = False
@@ -74,19 +73,15 @@ class ColdUserRandomSplitter(Splitter):
          return train, test

      def _core_split_spark(
-         self,
-         interactions: SparkDataFrame,
-         threshold: float
-     ) -> Union[SparkDataFrame, SparkDataFrame]:
+         self, interactions: SparkDataFrame, threshold: float
+     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
          users = interactions.select(self.query_column).distinct()
          train_users, _ = users.randomSplit(
              [1 - threshold, threshold],
              seed=self.seed,
          )
          interactions = interactions.join(
-             train_users.withColumn("is_test", sf.lit(False)),
-             on=self.query_column,
-             how="left"
+             train_users.withColumn("is_test", sf.lit(False)), on=self.query_column, how="left"
          ).na.fill({"is_test": True})

          train = interactions.filter(~sf.col("is_test")).drop("is_test")
@@ -95,27 +90,18 @@ class ColdUserRandomSplitter(Splitter):
          return train, test

      def _core_split_polars(
-         self,
-         interactions: PolarsDataFrame,
-         threshold: float
-     ) -> Union[PolarsDataFrame, PolarsDataFrame]:
+         self, interactions: PolarsDataFrame, threshold: float
+     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
          train_users = (
-             interactions
-             .select(self.query_column)
+             interactions.select(self.query_column)
              .unique()
              .sample(fraction=(1 - threshold), seed=self.seed)
              .with_columns(pl.lit(False).alias("is_test"))
          )

-         interactions = (
-             interactions
-             .join(
-                 train_users,
-                 on=self.query_column, how="left")
-             .fill_null(True)
-         )
+         interactions = interactions.join(train_users, on=self.query_column, how="left").fill_null(True)

-         train = interactions.filter(~pl.col("is_test")).drop("is_test")  # pylint: disable=invalid-unary-operand-type
+         train = interactions.filter(~pl.col("is_test")).drop("is_test")
          test = interactions.filter(pl.col("is_test")).drop("is_test")
          return train, test

@@ -127,4 +113,5 @@
          if isinstance(interactions, PolarsDataFrame):
              return self._core_split_polars(interactions, self.test_size)

-         raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+         msg = f"{self} is not implemented for {type(interactions)}"
+         raise NotImplementedError(msg)
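Note on the hunk above: the old return annotations `Union[PandasDataFrame, PandasDataFrame]` collapse to a single type, so the switch to `Tuple[...]` is what actually describes the `(train, test)` pair these helpers return, and the `raise` statements now bind the message to `msg` first. A minimal usage sketch of the splitter this file defines; the `replay.splitters` import path and the public `split()` entry point are assumptions, since neither appears in this hunk:

# Sketch only: assumes ColdUserRandomSplitter is re-exported from replay.splitters
# and that the public split() method dispatches to the _core_split_* helpers shown above.
import pandas as pd
from replay.splitters import ColdUserRandomSplitter

interactions = pd.DataFrame({"query_id": [1, 1, 2, 2, 3], "item_id": [10, 11, 10, 12, 11]})
splitter = ColdUserRandomSplitter(test_size=0.2, seed=42)  # 0 < test_size < 1, otherwise ValueError
train, test = splitter.split(interactions)                 # a (train, test) pair, matching the new Tuple annotation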
replay/splitters/k_folds.py
@@ -1,9 +1,11 @@
  from typing import Literal, Optional, Tuple
+
  import polars as pl

- from .base_splitter import Splitter, SplitterReturnType
  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame

+ from .base_splitter import Splitter, SplitterReturnType
+
  if PYSPARK_AVAILABLE:
      import pyspark.sql.functions as sf
      from pyspark.sql import Window
@@ -11,11 +13,11 @@ if PYSPARK_AVAILABLE:
  StrategyName = Literal["query"]


- # pylint: disable=too-few-public-methods
  class KFolds(Splitter):
      """
      Splits interactions inside each query into folds at random.
      """
+
      _init_arg_names = [
          "n_folds",
          "strategy",
@@ -29,7 +31,6 @@ class KFolds(Splitter):
          "session_id_processing_strategy",
      ]

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          n_folds: Optional[int] = 5,
@@ -64,11 +65,12 @@ class KFolds(Splitter):
              item_column=item_column,
              timestamp_column=timestamp_column,
              session_id_column=session_id_column,
-             session_id_processing_strategy=session_id_processing_strategy
+             session_id_processing_strategy=session_id_processing_strategy,
          )
          self.n_folds = n_folds
          if strategy not in {"query"}:
-             raise ValueError(f"Wrong splitter parameter: {strategy}")
+             msg = f"Wrong splitter parameter: {strategy}"
+             raise ValueError(msg)
          self.strategy = strategy
          self.seed = seed

@@ -85,16 +87,10 @@ class KFolds(Splitter):
          dataframe = interactions.withColumn("_rand", sf.rand(self.seed))
          dataframe = dataframe.withColumn(
              "fold",
-             sf.row_number().over(
-                 Window.partitionBy(self.query_column).orderBy("_rand")
-             )
-             % self.n_folds,
+             sf.row_number().over(Window.partitionBy(self.query_column).orderBy("_rand")) % self.n_folds,
          ).drop("_rand")
          for i in range(self.n_folds):
-             dataframe = dataframe.withColumn(
-                 "is_test",
-                 sf.when(sf.col("fold") == i, True).otherwise(False)
-             )
+             dataframe = dataframe.withColumn("is_test", sf.when(sf.col("fold") == i, True).otherwise(False))
              if self.session_id_column:
                  dataframe = self._recalculate_with_session_id_column(dataframe)

@@ -122,28 +118,21 @@ class KFolds(Splitter):
      def _query_split_polars(self, interactions: PolarsDataFrame) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
          dataframe = interactions.sample(fraction=1, shuffle=True, seed=self.seed).sort(self.query_column)
          dataframe = dataframe.with_columns(
-             (pl.cum_count(self.query_column).over(self.query_column) % self.n_folds)
-             .alias("fold")
+             (pl.cum_count(self.query_column).over(self.query_column) % self.n_folds).alias("fold")
          )
          for i in range(self.n_folds):
              dataframe = dataframe.with_columns(
-                 pl.when(
-                     pl.col("fold") == i
-                 )
-                 .then(True)
-                 .otherwise(False)
-                 .alias("is_test")
+                 pl.when(pl.col("fold") == i).then(True).otherwise(False).alias("is_test")
              )
              if self.session_id_column:
                  dataframe = self._recalculate_with_session_id_column(dataframe)

-             train = dataframe.filter(~pl.col("is_test")).drop("is_test", "fold")  # pylint: disable=invalid-unary-operand-type
+             train = dataframe.filter(~pl.col("is_test")).drop("is_test", "fold")
              test = dataframe.filter(pl.col("is_test")).drop("is_test", "fold")

              test = self._drop_cold_items_and_users(train, test)
              yield train, test

-     # pylint: disable=inconsistent-return-statements
      def _core_split(self, interactions: DataFrameLike) -> SplitterReturnType:
          if self.strategy == "query":
              if isinstance(interactions, SparkDataFrame):
@@ -153,4 +142,5 @@
              if isinstance(interactions, PolarsDataFrame):
                  return self._query_split_polars(interactions)

-         raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+         msg = f"{self} is not implemented for {type(interactions)}"
+         raise NotImplementedError(msg)
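Note on the hunk above: `_query_split_polars` ends with `yield train, test` inside the fold loop, so KFolds produces one `(train, test)` pair per fold rather than a single split, and `"query"` is the only accepted `strategy`. A hedged iteration sketch; the import path, the public `split()` name, and the claim that it hands back the fold generator unchanged are all assumptions:

# Sketch only: assumes KFolds is re-exported from replay.splitters and that split()
# forwards to _core_split(), which returns a generator yielding one fold at a time.
import pandas as pd
from replay.splitters import KFolds

interactions = pd.DataFrame({"query_id": [1, 1, 1, 2, 2, 2], "item_id": [10, 11, 12, 10, 11, 13]})
cv = KFolds(n_folds=3, seed=7, query_column="query_id")  # any strategy other than "query" raises ValueError
for fold_number, (train, test) in enumerate(cv.split(interactions)):
    ...  # fit and evaluate on each of the three folds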
replay/splitters/last_n_splitter.py
@@ -4,9 +4,10 @@ import numpy as np
  import pandas as pd
  import polars as pl

- from .base_splitter import Splitter
  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame

+ from .base_splitter import Splitter
+
  if PYSPARK_AVAILABLE:
      import pyspark.sql.functions as sf
      from pyspark.sql import Window
@@ -14,7 +15,6 @@ if PYSPARK_AVAILABLE:
  StrategyName = Literal["interactions", "timedelta"]


- # pylint: disable=too-few-public-methods
  class LastNSplitter(Splitter):
      """
      Split interactions by last N interactions/timedelta per user.
@@ -88,10 +88,11 @@ class LastNSplitter(Splitter):
      14 3 2 2020-01-05
      <BLANKLINE>
      """
+
      _init_arg_names = [
          "N",
          "divide_column",
-         "timestamp_col_format",
+         "time_column_format",
          "strategy",
          "drop_cold_users",
          "drop_cold_items",
@@ -102,10 +103,9 @@ class LastNSplitter(Splitter):
          "session_id_processing_strategy",
      ]

-     # pylint: disable=invalid-name, too-many-arguments
      def __init__(
          self,
-         N: int,
+         N: int,  # noqa: N803
          divide_column: str = "query_id",
          time_column_format: str = "yyyy-MM-dd HH:mm:ss",
          strategy: StrategyName = "interactions",
@@ -147,7 +147,8 @@ class LastNSplitter(Splitter):
              default: ``test``.
          """
          if strategy not in ["interactions", "timedelta"]:
-             raise ValueError("strategy must be equal 'interactions' or 'timedelta'")
+             msg = "strategy must be equal 'interactions' or 'timedelta'"
+             raise ValueError(msg)
          super().__init__(
              drop_cold_users=drop_cold_users,
              drop_cold_items=drop_cold_items,
@@ -160,9 +161,9 @@ class LastNSplitter(Splitter):
          self.N = N
          self.strategy = strategy
          self.divide_column = divide_column
-         self.timestamp_col_format = None
+         self.time_column_format = None
          if self.strategy == "timedelta":
-             self.timestamp_col_format = time_column_format
+             self.time_column_format = time_column_format

      def _add_time_partition(self, interactions: DataFrameLike) -> DataFrameLike:
          if isinstance(interactions, SparkDataFrame):
@@ -172,7 +173,8 @@ class LastNSplitter(Splitter):
          if isinstance(interactions, PolarsDataFrame):
              return self._add_time_partition_to_polars(interactions)

-         raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+         msg = f"{self} is not implemented for {type(interactions)}"
+         raise NotImplementedError(msg)

      def _add_time_partition_to_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
          res = interactions.copy(deep=True)
@@ -191,8 +193,7 @@ class LastNSplitter(Splitter):

      def _add_time_partition_to_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
          res = interactions.sort(self.timestamp_column).with_columns(
-             pl.col(self.divide_column).cumcount().over(pl.col(self.divide_column))
-             .alias("row_num")
+             pl.col(self.divide_column).cumcount().over(pl.col(self.divide_column)).alias("row_num")
          )

          return res
@@ -205,7 +206,8 @@ class LastNSplitter(Splitter):
          if isinstance(interactions, PolarsDataFrame):
              return self._to_unix_timestamp_polars(interactions)

-         raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+         msg = f"{self} is not implemented for {type(interactions)}"
+         raise NotImplementedError(msg)

      def _to_unix_timestamp_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
          time_column_type = dict(interactions.dtypes)[self.timestamp_column]
@@ -221,7 +223,7 @@ class LastNSplitter(Splitter):
          time_column_type = dict(interactions.dtypes)[self.timestamp_column]
          if time_column_type == "date":
              interactions = interactions.withColumn(
-                 self.timestamp_column, sf.unix_timestamp(self.timestamp_column, self.timestamp_col_format)
+                 self.timestamp_column, sf.unix_timestamp(self.timestamp_column, self.time_column_format)
              )

          return interactions
@@ -233,20 +235,19 @@ class LastNSplitter(Splitter):

          return interactions

-     # pylint: disable=invalid-name
-     def _partial_split_interactions(self, interactions: DataFrameLike, N: int) -> Tuple[DataFrameLike, DataFrameLike]:
+     def _partial_split_interactions(self, interactions: DataFrameLike, n: int) -> Tuple[DataFrameLike, DataFrameLike]:
          res = self._add_time_partition(interactions)
          if isinstance(interactions, SparkDataFrame):
-             return self._partial_split_interactions_spark(res, N)
+             return self._partial_split_interactions_spark(res, n)
          if isinstance(interactions, PandasDataFrame):
-             return self._partial_split_interactions_pandas(res, N)
-         return self._partial_split_interactions_polars(res, N)
+             return self._partial_split_interactions_pandas(res, n)
+         return self._partial_split_interactions_polars(res, n)

      def _partial_split_interactions_pandas(
-         self, interactions: PandasDataFrame, N: int
+         self, interactions: PandasDataFrame, n: int
      ) -> Tuple[PandasDataFrame, PandasDataFrame]:
          interactions["count"] = interactions.groupby(self.divide_column, sort=False)[self.divide_column].transform(len)
-         interactions["is_test"] = interactions["row_num"] > (interactions["count"] - float(N))
+         interactions["is_test"] = interactions["row_num"] > (interactions["count"] - float(n))
          if self.session_id_column:
              interactions = self._recalculate_with_session_id_column(interactions)

@@ -256,14 +257,14 @@ class LastNSplitter(Splitter):
          return train, test

      def _partial_split_interactions_spark(
-         self, interactions: SparkDataFrame, N: int
+         self, interactions: SparkDataFrame, n: int
      ) -> Tuple[SparkDataFrame, SparkDataFrame]:
          interactions = interactions.withColumn(
              "count", sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column))
          )
          # float(n) - because DataFrame.filter is changing order
          # of sorted DataFrame to descending
-         interactions = interactions.withColumn("is_test", sf.col("row_num") > sf.col("count") - sf.lit(float(N)))
+         interactions = interactions.withColumn("is_test", sf.col("row_num") > sf.col("count") - sf.lit(float(n)))
          if self.session_id_column:
              interactions = self._recalculate_with_session_id_column(interactions)

@@ -273,27 +274,22 @@ class LastNSplitter(Splitter):
          return train, test

      def _partial_split_interactions_polars(
-         self, interactions: PolarsDataFrame, N: int
+         self, interactions: PolarsDataFrame, n: int
      ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
          interactions = interactions.with_columns(
-             pl.col(self.timestamp_column).count().over(self.divide_column)
-             .alias("count")
-         )
-         interactions = interactions.with_columns(
-             (pl.col("row_num") > (pl.col("count") - N))
-             .alias("is_test")
+             pl.col(self.timestamp_column).count().over(self.divide_column).alias("count")
          )
+         interactions = interactions.with_columns((pl.col("row_num") > (pl.col("count") - n)).alias("is_test"))
          if self.session_id_column:
              interactions = self._recalculate_with_session_id_column(interactions)

-         train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "is_test")  # pylint: disable=invalid-unary-operand-type
+         train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "is_test")
          test = interactions.filter(pl.col("is_test")).drop("row_num", "count", "is_test")

          return train, test

      def _partial_split_timedelta(
-         self,
-         interactions: DataFrameLike, timedelta: int
+         self, interactions: DataFrameLike, timedelta: int
      ) -> Tuple[DataFrameLike, DataFrameLike]:
          if isinstance(interactions, SparkDataFrame):
              return self._partial_split_timedelta_spark(interactions, timedelta)
@@ -341,22 +337,16 @@
      def _partial_split_timedelta_polars(
          self, interactions: PolarsDataFrame, timedelta: int
      ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
-         res = (
-             interactions
-             .with_columns(
-                 (pl.col(self.timestamp_column).max().over(self.divide_column) - pl.col(self.timestamp_column))
-                 .alias("diff_timestamp")
+         res = interactions.with_columns(
+             (pl.col(self.timestamp_column).max().over(self.divide_column) - pl.col(self.timestamp_column)).alias(
+                 "diff_timestamp"
              )
-             .with_columns(
-                 (pl.col("diff_timestamp") < timedelta)
-                 .alias("is_test")
-             )
-         )
+         ).with_columns((pl.col("diff_timestamp") < timedelta).alias("is_test"))

          if self.session_id_column:
              res = self._recalculate_with_session_id_column(res)

-         train = res.filter(~pl.col("is_test")).drop("diff_timestamp", "is_test")  # pylint: disable=invalid-unary-operand-type
+         train = res.filter(~pl.col("is_test")).drop("diff_timestamp", "is_test")
          test = res.filter(pl.col("is_test")).drop("diff_timestamp", "is_test")

          return train, test
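Note on the hunk above: besides the reflowed expressions, the stored attribute (and its `_init_arg_names` entry) is renamed from `timestamp_col_format` to `time_column_format`, matching the constructor argument; code that read the old attribute from a constructed splitter needs updating. A small sketch of the rename; the `replay.splitters` import path is an assumption:

# Sketch only: illustrates the attribute rename; the import path is an assumption.
from replay.splitters import LastNSplitter

splitter = LastNSplitter(N=2, strategy="timedelta", time_column_format="yyyy-MM-dd HH:mm:ss")
splitter.time_column_format      # 0.17.0 attribute name (only set for the "timedelta" strategy)
# splitter.timestamp_col_format  # 0.16.0 attribute name, no longer assigned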
replay/splitters/new_users_splitter.py
@@ -1,15 +1,16 @@
- from typing import Optional, Union
+ from typing import Optional, Tuple
+
  import polars as pl

- from .base_splitter import Splitter, SplitterReturnType
  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame

+ from .base_splitter import Splitter, SplitterReturnType
+
  if PYSPARK_AVAILABLE:
      import pyspark.sql.functions as sf
      from pyspark.sql import Window


- # pylint: disable=too-few-public-methods, duplicate-code
  class NewUsersSplitter(Splitter):
      """
      Only new users will be assigned to test set.
@@ -63,7 +64,6 @@ class NewUsersSplitter(Splitter):
          "session_id_processing_strategy",
      ]

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          test_size: float,
@@ -91,24 +91,23 @@ class NewUsersSplitter(Splitter):
              item_column=item_column,
              timestamp_column=timestamp_column,
              session_id_column=session_id_column,
-             session_id_processing_strategy=session_id_processing_strategy
+             session_id_processing_strategy=session_id_processing_strategy,
          )
          if test_size < 0 or test_size > 1:
-             raise ValueError("test_size must between 0 and 1")
+             msg = "test_size must between 0 and 1"
+             raise ValueError(msg)
          self.test_size = test_size

      def _core_split_pandas(
-         self,
-         interactions: PandasDataFrame,
-         threshold: float
-     ) -> Union[PandasDataFrame, PandasDataFrame]:
-         start_date_by_user = interactions.groupby(self.query_column).agg(
-             _start_dt_by_user=(self.timestamp_column, "min")
-         ).reset_index()
+         self, interactions: PandasDataFrame, threshold: float
+     ) -> Tuple[PandasDataFrame, PandasDataFrame]:
+         start_date_by_user = (
+             interactions.groupby(self.query_column).agg(_start_dt_by_user=(self.timestamp_column, "min")).reset_index()
+         )
          test_start_date = (
-             start_date_by_user
-             .groupby("_start_dt_by_user")
-             .agg(_num_users_by_start_date=(self.query_column, "count")).reset_index()
+             start_date_by_user.groupby("_start_dt_by_user")
+             .agg(_num_users_by_start_date=(self.query_column, "count"))
+             .reset_index()
              .sort_values(by="_start_dt_by_user", ascending=False)
          )
          test_start_date["_cum_num_users_to_dt"] = test_start_date["_num_users_by_start_date"].cumsum()
@@ -120,9 +119,7 @@ class NewUsersSplitter(Splitter):

          train = interactions[interactions[self.timestamp_column] < test_start]
          test = interactions.merge(
-             start_date_by_user[start_date_by_user["_start_dt_by_user"] >= test_start],
-             how="inner",
-             on=self.query_column
+             start_date_by_user[start_date_by_user["_start_dt_by_user"] >= test_start], how="inner", on=self.query_column
          ).drop(columns=["_start_dt_by_user"])

          if self.session_id_column:
@@ -136,10 +133,8 @@ class NewUsersSplitter(Splitter):
          return train, test

      def _core_split_spark(
-         self,
-         interactions: SparkDataFrame,
-         threshold: float
-     ) -> Union[SparkDataFrame, SparkDataFrame]:
+         self, interactions: SparkDataFrame, threshold: float
+     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
          start_date_by_user = interactions.groupby(self.query_column).agg(
              sf.min(self.timestamp_column).alias("_start_dt_by_user")
          )
@@ -175,53 +170,33 @@
          return train, test

      def _core_split_polars(
-         self,
-         interactions: PolarsDataFrame,
-         threshold: float
-     ) -> Union[PolarsDataFrame, PolarsDataFrame]:
-         start_date_by_user = (
-             interactions
-             .group_by(self.query_column).agg(
-                 pl.col(self.timestamp_column).min()
-                 .alias("_start_dt_by_user")
-             )
+         self, interactions: PolarsDataFrame, threshold: float
+     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
+         start_date_by_user = interactions.group_by(self.query_column).agg(
+             pl.col(self.timestamp_column).min().alias("_start_dt_by_user")
          )
          test_start_date = (
-             start_date_by_user
-             .group_by("_start_dt_by_user").agg(
-                 pl.col(self.query_column).count()
-                 .alias("_num_users_by_start_date")
-             )
+             start_date_by_user.group_by("_start_dt_by_user")
+             .agg(pl.col(self.query_column).count().alias("_num_users_by_start_date"))
              .sort("_start_dt_by_user", descending=True)
              .with_columns(
-                 pl.col("_num_users_by_start_date").cum_sum()
-                 .alias("cum_sum_users"),
+                 pl.col("_num_users_by_start_date").cum_sum().alias("cum_sum_users"),
              )
-             .filter(
-                 pl.col("cum_sum_users") >= pl.col("cum_sum_users").max() * threshold
-             )
-             ["_start_dt_by_user"]
+             .filter(pl.col("cum_sum_users") >= pl.col("cum_sum_users").max() * threshold)["_start_dt_by_user"]
              .max()
          )

          train = interactions.filter(pl.col(self.timestamp_column) < test_start_date)
          test = interactions.join(
-             start_date_by_user.filter(pl.col("_start_dt_by_user") >= test_start_date),
-             on=self.query_column,
-             how="inner"
+             start_date_by_user.filter(pl.col("_start_dt_by_user") >= test_start_date), on=self.query_column, how="inner"
          ).drop("_start_dt_by_user")

          if self.session_id_column:
              interactions = interactions.with_columns(
-                 pl.when(
-                     pl.col(self.timestamp_column) < test_start_date
-                 )
-                 .then(False)
-                 .otherwise(True)
-                 .alias("is_test")
+                 pl.when(pl.col(self.timestamp_column) < test_start_date).then(False).otherwise(True).alias("is_test")
              )
              interactions = self._recalculate_with_session_id_column(interactions)
-             train = interactions.filter(~pl.col("is_test")).drop("is_test")  # pylint: disable=invalid-unary-operand-type
+             train = interactions.filter(~pl.col("is_test")).drop("is_test")
              test = interactions.filter(pl.col("is_test")).drop("is_test")

          return train, test
@@ -234,4 +209,5 @@
          if isinstance(interactions, PolarsDataFrame):
              return self._core_split_polars(interactions, self.test_size)

-         raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+         msg = f"{self} is not implemented for {type(interactions)}"
+         raise NotImplementedError(msg)
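Note on the hunk above: the reflowed pandas and polars branches make the boundary rule easier to read. Users are grouped by the date of their first interaction, user counts are accumulated from the newest start date backwards, and the test boundary is the newest start date at which that cumulative count reaches `test_size` of all users; every interaction of users who first appeared on or after that date goes to test. A toy recomputation of that rule with made-up numbers (not taken from the diff):

# Illustrative numbers only: mirrors the cumulative-user rule in _core_split_pandas/_core_split_polars.
import pandas as pd

users_per_start_date = pd.Series({4: 1, 3: 2, 2: 3, 1: 4})  # day -> users first seen that day, newest day first
cum_users = users_per_start_date.cumsum()                    # 1, 3, 6, 10 users starting on or after each day
test_start = cum_users[cum_users >= cum_users.max() * 0.3].index.max()  # day 3 when test_size = 0.3
# train: interactions strictly before day 3; test: all interactions of users first seen on day 3 or later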
replay/splitters/random_splitter.py
@@ -1,10 +1,10 @@
- from typing import Optional, Union
+ from typing import Optional, Tuple

- from .base_splitter import Splitter, SplitterReturnType
  from replay.utils import DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame

+ from .base_splitter import Splitter, SplitterReturnType
+

- # pylint: disable=too-few-public-methods, duplicate-code
  class RandomSplitter(Splitter):
      """Assign records into train and test at random."""

@@ -17,7 +17,6 @@ class RandomSplitter(Splitter):
          "item_column",
      ]

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          test_size: float,
@@ -25,7 +24,7 @@ class RandomSplitter(Splitter):
          drop_cold_users: bool = False,
          seed: Optional[int] = None,
          query_column: str = "query_id",
-         item_column: str = "item_id"
+         item_column: str = "item_id",
      ):
          """
          :param test_size: test size 0 to 1
@@ -39,37 +38,30 @@ class RandomSplitter(Splitter):
              drop_cold_items=drop_cold_items,
              drop_cold_users=drop_cold_users,
              query_column=query_column,
-             item_column=item_column
+             item_column=item_column,
          )
          self.seed = seed
          if test_size < 0 or test_size > 1:
-             raise ValueError("test_size must between 0 and 1")
+             msg = "test_size must between 0 and 1"
+             raise ValueError(msg)
          self.test_size = test_size

      def _random_split_spark(
-         self,
-         interactions: SparkDataFrame,
-         threshold: float
-     ) -> Union[SparkDataFrame, SparkDataFrame]:
-         train, test = interactions.randomSplit(
-             [1 - threshold, threshold], self.seed
-         )
+         self, interactions: SparkDataFrame, threshold: float
+     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+         train, test = interactions.randomSplit([1 - threshold, threshold], self.seed)
          return train, test

      def _random_split_pandas(
-         self,
-         interactions: PandasDataFrame,
-         threshold: float
-     ) -> Union[PandasDataFrame, PandasDataFrame]:
+         self, interactions: PandasDataFrame, threshold: float
+     ) -> Tuple[PandasDataFrame, PandasDataFrame]:
          train = interactions.sample(frac=(1 - threshold), random_state=self.seed)
          test = interactions.drop(train.index)
          return train, test

      def _random_split_polars(
-         self,
-         interactions: PolarsDataFrame,
-         threshold: float
-     ) -> Union[PolarsDataFrame, PolarsDataFrame]:
+         self, interactions: PolarsDataFrame, threshold: float
+     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
          train_size = int(len(interactions) * (1 - threshold)) + 1
          shuffled_interactions = interactions.sample(fraction=1, shuffle=True, seed=self.seed)
          train = shuffled_interactions[:train_size]
@@ -84,4 +76,5 @@
          if isinstance(interactions, PolarsDataFrame):
              return self._random_split_polars(interactions, self.test_size)

-         raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+         msg = f"{self} is not implemented for {type(interactions)}"
+         raise NotImplementedError(msg)