replay-rec 0.16.0 → 0.17.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/preprocessing/label_encoder.py

@@ -6,9 +6,11 @@ Contains classes for encoding categorical data
 ``LabelEncoder`` to apply multiple LabelEncodingRule to dataframe.
 """
 import abc
-import polars as pl
+import warnings
 from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union
 
+import polars as pl
+
 from replay.utils import (
     PYSPARK_AVAILABLE,
     DataFrameLike,
@@ -20,13 +22,16 @@ from replay.utils import (
 
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
+    from pyspark.sql.types import LongType, StructType
     from pyspark.storagelevel import StorageLevel
-    from pyspark.sql.types import StructType, LongType
 
-HandleUnknownStrategies = Literal["error", "use_default_value"]
+HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]
+
+
+class LabelEncoderTransformWarning(Warning):
+    """Label encoder transform warning."""
 
 
-# pylint: disable=missing-function-docstring
 class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover
     """
     Interface of the label encoding rule
@@ -70,7 +75,6 @@ class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover
         raise NotImplementedError()
 
 
-# pylint: disable=too-many-instance-attributes
 class LabelEncodingRule(BaseLabelEncodingRule):
     """
     Implementation of the encoding rule for categorical variables of PySpark and Pandas Data Frames.
@@ -79,7 +83,7 @@ class LabelEncodingRule(BaseLabelEncodingRule):
     """
 
     _ENCODED_COLUMN_SUFFIX: str = "_encoded"
-    _HANDLE_UNKNOWN_STRATEGIES = ("error", "use_default_value")
+    _HANDLE_UNKNOWN_STRATEGIES = ("error", "use_default_value", "drop")
     _TRANSFORM_PERFORMANCE_THRESHOLD_FOR_PANDAS = 100_000
 
     def __init__(
@@ -99,6 +103,7 @@ class LabelEncodingRule(BaseLabelEncodingRule):
            When set to ``error`` an error will be raised in case an unknown label is present during transform.
            When set to ``use_default_value``, the encoded value of unknown label will be set
            to the value given for the parameter default_value.
+            When set to ``drop``, the unknown labels will be dropped.
            Default: ``error``.
        :param default_value: Default value that will fill the unknown labels after transform.
            When the parameter handle_unknown is set to ``use_default_value``,
@@ -110,7 +115,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
            Default: ``None``.
        """
        if handle_unknown not in self._HANDLE_UNKNOWN_STRATEGIES:
-            raise ValueError(f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}.")
+            msg = f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}."
+            raise ValueError(msg)
        self._handle_unknown = handle_unknown
        if (
            self._handle_unknown == "use_default_value"
@@ -118,7 +124,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
            and not isinstance(default_value, int)
            and default_value != "last"
        ):
-            raise ValueError("Default value should be None, int or 'last'")
+            msg = "Default value should be None, int or 'last'"
+            raise ValueError(msg)
 
        self._default_value = default_value
        self._col = column
@@ -135,12 +142,14 @@ class LabelEncodingRule(BaseLabelEncodingRule):
 
    def get_mapping(self) -> Mapping:
        if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)
        return self._mapping
 
    def get_inverse_mapping(self) -> Mapping:
        if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)
        return self._inverse_mapping
 
    def _make_inverse_mapping(self) -> Mapping:
@@ -159,17 +168,14 @@ class LabelEncodingRule(BaseLabelEncodingRule):
            unique_col_values.rdd.zipWithIndex()
            .toDF(
                StructType()
-                .add("_1",
-                     StructType()
-                     .add(self._col, df.schema[self._col].dataType, True),
-                     True)
+                .add("_1", StructType().add(self._col, df.schema[self._col].dataType, True), True)
                .add("_2", LongType(), True)
            )
            .select(sf.col(f"_1.{self._col}").alias(self._col), sf.col("_2").alias(self._target_col))
            .persist(StorageLevel.MEMORY_ONLY)
        )
 
-        self._mapping = mapping_on_spark.rdd.collectAsMap()  # type: ignore
+        self._mapping = mapping_on_spark.rdd.collectAsMap()
        mapping_on_spark.unpersist()
        unique_col_values.unpersist()
 
@@ -198,17 +204,18 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        elif isinstance(df, PolarsDataFrame):
            self._fit_polars(df)
        else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
        self._inverse_mapping = self._make_inverse_mapping()
        self._inverse_mapping_list = self._make_inverse_mapping_list()
-        if self._handle_unknown == "use_default_value":
-            if self._default_value in self._inverse_mapping:
-                raise ValueError(
-                    "The used value for default_value "
-                    f"{self._default_value} is one of the "
-                    "values already used for encoding the "
-                    "seen labels."
-                )
+        if self._handle_unknown == "use_default_value" and self._default_value in self._inverse_mapping:
+            msg = (
+                "The used value for default_value "
+                f"{self._default_value} is one of the "
+                "values already used for encoding the "
+                "seen labels."
+            )
+            raise ValueError(msg)
        self._is_fitted = True
        return self
 
@@ -226,18 +233,15 @@ class LabelEncodingRule(BaseLabelEncodingRule):
            new_unique_values.rdd.zipWithIndex()
            .toDF(
                StructType()
-                .add("_1",
-                     StructType()
-                     .add(self._col, df.schema[self._col].dataType),
-                     True)
+                .add("_1", StructType().add(self._col, df.schema[self._col].dataType), True)
                .add("_2", LongType(), True)
            )
            .select(sf.col(f"_1.{self._col}").alias(self._col), sf.col("_2").alias(self._target_col))
            .withColumn(self._target_col, sf.col(self._target_col) + max_value)
-            .rdd.collectAsMap()  # type: ignore
+            .rdd.collectAsMap()
        )
-        self._mapping.update(new_data)  # type: ignore
-        self._inverse_mapping.update({v: k for k, v in new_data.items()})  # type: ignore
+        self._mapping.update(new_data)
+        self._inverse_mapping.update({v: k for k, v in new_data.items()})
        self._inverse_mapping_list.extend(new_data.keys())
        new_unique_values.unpersist()
 
@@ -245,9 +249,10 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        assert self._mapping is not None
 
        new_unique_values = set(df[self._col].tolist()) - set(self._mapping)
-        new_data: dict = {value: max(self._mapping.values()) + i for i, value in enumerate(new_unique_values, start=1)}
-        self._mapping.update(new_data)  # type: ignore
-        self._inverse_mapping.update({v: k for k, v in new_data.items()})  # type: ignore
+        last_mapping_value = max(self._mapping.values())
+        new_data: dict = {value: last_mapping_value + i for i, value in enumerate(new_unique_values, start=1)}
+        self._mapping.update(new_data)
+        self._inverse_mapping.update({v: k for k, v in new_data.items()})
        self._inverse_mapping_list.extend(new_data.keys())
 
    def _partial_fit_polars(self, df: PolarsDataFrame) -> None:
@@ -255,8 +260,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
 
        new_unique_values = set(df.select(self._col).unique().to_series().to_list()) - set(self._mapping)
        new_data: dict = {value: max(self._mapping.values()) + i for i, value in enumerate(new_unique_values, start=1)}
-        self._mapping.update(new_data)  # type: ignore
-        self._inverse_mapping.update({v: k for k, v in new_data.items()})  # type: ignore
+        self._mapping.update(new_data)
+        self._inverse_mapping.update({v: k for k, v in new_data.items()})
        self._inverse_mapping_list.extend(new_data.keys())
 
    def partial_fit(self, df: DataFrameLike) -> "LabelEncodingRule":
@@ -276,7 +281,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        elif isinstance(df, PolarsDataFrame):
            self._partial_fit_polars(df)
        else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
 
        self._is_fitted = True
        return self
@@ -299,14 +305,24 @@ class LabelEncodingRule(BaseLabelEncodingRule):
            joined_df.loc[unknown_mask, self._target_col] = -1
            is_unknown_label |= unknown_mask.sum() > 0
 
-        if is_unknown_label and default_value != -1:
+        if is_unknown_label:
            unknown_mask = joined_df[self._target_col] == -1
-            if self._handle_unknown == "error":
+            if self._handle_unknown == "drop":
+                joined_df.drop(joined_df[unknown_mask].index, inplace=True)
+                if joined_df.empty:
+                    warnings.warn(
+                        f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                        "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                        LabelEncoderTransformWarning,
+                    )
+            elif self._handle_unknown == "error":
                unknown_unique_labels = joined_df[self._col][unknown_mask].unique().tolist()
                msg = f"Found unknown labels {unknown_unique_labels} in column {self._col} during transform"
                raise ValueError(msg)
-            joined_df[self._target_col] = joined_df[self._target_col].astype("int")
-            joined_df[self._target_col] = joined_df[self._target_col].replace({-1: default_value})
+            else:
+                if default_value != -1:
+                    joined_df[self._target_col] = joined_df[self._target_col].astype("int")
+                    joined_df[self._target_col] = joined_df[self._target_col].replace({-1: default_value})
 
        result_df = joined_df.drop(self._col, axis=1).rename(columns={self._target_col: self._col})
        return result_df
@@ -318,17 +334,24 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        transformed_df = df.join(mapping_on_spark, on=self._col, how="left").withColumn(
            "unknown_mask", sf.isnull(self._target_col)
        )
-        unknown_label_count = transformed_df.select(sf.sum(sf.col("unknown_mask").cast("long"))).first()[
-            0
-        ]  # type: ignore
+        unknown_label_count = transformed_df.select(sf.sum(sf.col("unknown_mask").cast("long"))).first()[0]
        if unknown_label_count > 0:
-            if self._handle_unknown == "error":
+            if self._handle_unknown == "drop":
+                transformed_df = transformed_df.filter("unknown_mask == False")
+                if transformed_df.rdd.isEmpty():
+                    warnings.warn(
+                        f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                        "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                        LabelEncoderTransformWarning,
+                    )
+            elif self._handle_unknown == "error":
                collected_list = transformed_df.filter("unknown_mask == True").select(self._col).distinct().collect()
                unique_labels = [row[self._col] for row in collected_list]
                msg = f"Found unknown labels {unique_labels} in column {self._col} during transform"
                raise ValueError(msg)
-            if default_value:
-                transformed_df = transformed_df.fillna({self._target_col: default_value})
+            else:
+                if default_value:
+                    transformed_df = transformed_df.fillna({self._target_col: default_value})
 
        result_df = transformed_df.drop(self._col, "unknown_mask").withColumnRenamed(self._target_col, self._col)
        return result_df
@@ -338,20 +361,27 @@ class LabelEncodingRule(BaseLabelEncodingRule):
            [list(self.get_mapping().keys()), list(self.get_mapping().values())],
            schema=[self._col, self._target_col],
        )
-        mapping_on_polars = mapping_on_polars.with_columns(
-            pl.col(self._col).cast(df.get_column(self._col).dtype)
-        )
+        mapping_on_polars = mapping_on_polars.with_columns(pl.col(self._col).cast(df.get_column(self._col).dtype))
        transformed_df = df.join(mapping_on_polars, on=self._col, how="left").with_columns(
            pl.col(self._target_col).is_null().alias("unknown_mask")
        )
        unknown_df = transformed_df.filter(pl.col("unknown_mask"))
        if not unknown_df.is_empty():
-            if self._handle_unknown == "error":
+            if self._handle_unknown == "drop":
+                transformed_df = transformed_df.filter(pl.col("unknown_mask") == "false")
+                if transformed_df.is_empty():
+                    warnings.warn(
+                        f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                        "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                        LabelEncoderTransformWarning,
+                    )
+            elif self._handle_unknown == "error":
                unique_labels = unknown_df.select(self._col).unique().to_series().to_list()
                msg = f"Found unknown labels {unique_labels} in column {self._col} during transform"
                raise ValueError(msg)
-            if default_value:
-                transformed_df = transformed_df.with_columns(pl.col(self._target_col).fill_null(default_value))
+            else:
+                if default_value:
+                    transformed_df = transformed_df.with_columns(pl.col(self._target_col).fill_null(default_value))
 
        result_df = transformed_df.drop([self._col, "unknown_mask"]).rename({self._target_col: self._col})
        return result_df
@@ -364,18 +394,20 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        :returns: transformed dataframe.
        """
        if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)
 
        default_value = len(self._mapping) if self._default_value == "last" else self._default_value
 
        if isinstance(df, PandasDataFrame):
-            transformed_df = self._transform_pandas(df, default_value)  # type: ignore
+            transformed_df = self._transform_pandas(df, default_value)
        elif isinstance(df, SparkDataFrame):
-            transformed_df = self._transform_spark(df, default_value)  # type: ignore
+            transformed_df = self._transform_spark(df, default_value)
        elif isinstance(df, PolarsDataFrame):
-            transformed_df = self._transform_polars(df, default_value)  # type: ignore
+            transformed_df = self._transform_polars(df, default_value)
        else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
        return transformed_df
 
    def _inverse_transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
@@ -414,7 +446,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        :returns: initial dataframe.
        """
        if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)
 
        if isinstance(df, PandasDataFrame):
            transformed_df = self._inverse_transform_pandas(df)
@@ -423,7 +456,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        elif isinstance(df, PolarsDataFrame):
            transformed_df = self._inverse_transform_polars(df)
        else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
        return transformed_df
 
    def set_default_value(self, default_value: Optional[Union[int, str]]) -> None:
@@ -434,7 +468,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        :param default_value: default value.
        """
        if default_value is not None and not isinstance(default_value, int) and default_value != "last":
-            raise ValueError("Default value should be None, int or 'last'")
+            msg = "Default value should be None, int or 'last'"
+            raise ValueError(msg)
        self._default_value = default_value
 
    def set_handle_unknown(self, handle_unknown: HandleUnknownStrategies) -> None:
@@ -444,7 +479,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
        :param handle_unknown: handle unknown strategy.
        """
        if handle_unknown not in self._HANDLE_UNKNOWN_STRATEGIES:
-            raise ValueError(f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}.")
+            msg = f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}."
+            raise ValueError(msg)
        self._handle_unknown = handle_unknown
 
 
@@ -582,11 +618,12 @@ class LabelEncoder:
            If ``str`` value, should be \"last\" only, then fill by n_classes number.
            Default ``None``.
        """
-        columns = [i.column for i in self.rules]  # pylint: disable=W0212
+        columns = [i.column for i in self.rules]
        for column, handle_unknown in handle_unknown_rules.items():
            if column not in columns:
-                raise ValueError(f"Column {column} not found.")
-            rule = list(filter(lambda x: x.column == column, self.rules))  # pylint: disable = W0212, W0640
+                msg = f"Column {column} not found."
+                raise ValueError(msg)
+            rule = list(filter(lambda x: x.column == column, self.rules))
            rule[0].set_handle_unknown(handle_unknown)
 
    def set_default_values(self, default_value_rules: Dict[str, Optional[Union[int, str]]]) -> None:
@@ -605,9 +642,10 @@ class LabelEncoder:
            to the value given for the parameter default_value.
            Default: ``error``.
        """
-        columns = [i.column for i in self.rules]  # pylint: disable=W0212
+        columns = [i.column for i in self.rules]
        for column, default_value in default_value_rules.items():
            if column not in columns:
-                raise ValueError(f"Column {column} not found.")
-            rule = list(filter(lambda x: x.column == column, self.rules))  # pylint: disable = W0212, W0640
+                msg = f"Column {column} not found."
+                raise ValueError(msg)
+            rule = list(filter(lambda x: x.column == column, self.rules))
            rule[0].set_default_value(default_value)
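
The functional change in this file is the new ``drop`` strategy for unknown labels and the accompanying ``LabelEncoderTransformWarning``. Below is a minimal sketch of how the strategy might be used; the constructor arguments and the warning's location follow this diff, while the ``LabelEncoder``/``LabelEncodingRule`` import paths and pandas usage are assumed to match the 0.16.x public API and are illustrative, not taken from this diff.

    import warnings
    import pandas as pd

    from replay.preprocessing import LabelEncoder, LabelEncodingRule
    from replay.preprocessing.label_encoder import LabelEncoderTransformWarning

    # Fit the rule on the known labels only.
    train = pd.DataFrame({"item_id": ["a", "b", "c"]})
    encoder = LabelEncoder([LabelEncodingRule("item_id", handle_unknown="drop")])
    encoder.fit(train)

    # "d" was never seen during fit: with handle_unknown="drop" its row is removed,
    # instead of raising ("error") or being filled with default_value ("use_default_value").
    test = pd.DataFrame({"item_id": ["a", "d"]})
    encoded = encoder.transform(test)  # only the row for "a" survives

    # If every value is unknown, transform returns an empty dataframe and emits
    # LabelEncoderTransformWarning, per the warnings.warn calls in the hunks above.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        empty = encoder.transform(pd.DataFrame({"item_id": ["x", "y"]}))
    assert empty.empty
    assert any(issubclass(w.category, LabelEncoderTransformWarning) for w in caught)

The strategy can also be switched after construction through the ``set_handle_unknown``/``set_default_value`` setters shown in the diff.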

replay/preprocessing/sessionizer.py

@@ -10,7 +10,6 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.window import Window
 
 
-# pylint: disable=too-many-instance-attributes, too-few-public-methods
 class Sessionizer:
     """
     Create and filter sessions from given interactions.
@@ -51,7 +50,6 @@ class Sessionizer:
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         user_column: str = "user_id",
@@ -191,7 +189,6 @@ class Sessionizer:
                Window.partitionBy(self.user_column).orderBy(sf.col(self.time_column), sf.col("timestamp_diff").desc())
            ),
        )
-        # data_with_sum_timediff.cache()
 
        grouped_users = data_with_sum_timediff.groupBy(self.user_column).count()
        grouped_users_with_cumsum = grouped_users.withColumn(
@@ -212,11 +209,9 @@ class Sessionizer:
            )
        )
 
-        # data_with_sum_timediff.unpersist()
        return result
 
    def _filter_sessions(self, interactions: DataFrameLike) -> DataFrameLike:
-        # interactions.cache()
        if isinstance(interactions, SparkDataFrame):
            return self._filter_sessions_spark(interactions)
 
@@ -254,8 +249,6 @@ class Sessionizer:
            entries_counter.select(self.session_column), self.session_column, how="right"
        )
 
-        # filtered_interactions.cache()
-
        nunique = filtered_interactions.groupby(self.user_column).agg(
            sf.expr("count(distinct session_id)").alias("nunique")
        )
@@ -284,9 +277,6 @@ class Sessionizer:
            result = self._filter_sessions(result)
            columns_order += [self.session_column]
 
-        if isinstance(result, SparkDataFrame):
-            result = result.select(*columns_order)
-        else:
-            result = result[columns_order]
+        result = result.select(*columns_order) if isinstance(result, SparkDataFrame) else result[columns_order]
 
        return result

replay/scenarios/fallback.py

@@ -1,16 +1,14 @@
-# pylint: disable=protected-access
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from replay.data import Dataset
-from replay.preprocessing.filters import MinCountFilter
 from replay.metrics import NDCG, Metric
 from replay.models import PopRec
 from replay.models.base_rec import BaseRecommender
+from replay.preprocessing.filters import MinCountFilter
 from replay.utils import SparkDataFrame
 from replay.utils.spark_utils import fallback, get_unique_entities
 
 
-# pylint: disable=too-many-instance-attributes
 class Fallback(BaseRecommender):
     """Fill missing recommendations using fallback model.
     Behaves like a recommender and have the same interface."""
@@ -33,16 +31,15 @@ class Fallback(BaseRecommender):
         self.threshold = threshold
         self.hot_queries = None
         self.main_model = main_model
-        # pylint: disable=invalid-name
         self.fb_model = fallback_model
 
-    # TO DO: add save/load for scenarios
+    # TODO: add save/load for scenarios
     @property
     def _init_args(self):
         return {"threshold": self.threshold}
 
     def __str__(self):
-        return f"Fallback_{str(self.main_model)}_{str(self.fb_model)}"
+        return f"Fallback_{self.main_model!s}_{self.fb_model!s}"
 
     def fit(
         self,
@@ -67,7 +64,6 @@ class Fallback(BaseRecommender):
         self._fit_wrap(hot_dataset)
         self.fb_model._fit_wrap(dataset)
 
-    # pylint: disable=too-many-arguments
     def predict(
         self,
         dataset: Dataset,
@@ -125,7 +121,6 @@ class Fallback(BaseRecommender):
         pred = fallback(hot_pred, cold_pred, k)
         return pred
 
-    # pylint: disable=too-many-arguments, too-many-locals
     def optimize(
         self,
         train_dataset: Dataset,

replay/splitters/base_splitter.py

@@ -1,4 +1,6 @@
+import json
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Optional, Tuple
 
 import polars as pl
@@ -7,19 +9,20 @@ from replay.utils import (
     PYSPARK_AVAILABLE,
     DataFrameLike,
     PandasDataFrame,
-    SparkDataFrame,
     PolarsDataFrame,
+    SparkDataFrame,
 )
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import Window
-    from pyspark.sql import functions as sf
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )
 
 
 SplitterReturnType = Tuple[DataFrameLike, DataFrameLike]
 
 
-# pylint: disable=too-few-public-methods, too-many-instance-attributes
 class Splitter(ABC):
     """Base class"""
 
@@ -33,7 +36,6 @@ class Splitter(ABC):
         "session_id_processing_strategy",
     ]
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         drop_cold_items: bool = False,
@@ -68,17 +70,43 @@ class Splitter(ABC):
    def _init_args(self):
        return {name: getattr(self, name) for name in self._init_arg_names}
 
+    def save(self, path: str) -> None:
+        """
+        Method for saving splitter in `.replay` directory.
+        """
+        base_path = Path(path).with_suffix(".replay").resolve()
+        base_path.mkdir(parents=True, exist_ok=True)
+
+        splitter_dict = {}
+        splitter_dict["init_args"] = self._init_args
+        splitter_dict["_class_name"] = str(self)
+
+        with open(base_path / "init_args.json", "w+") as file:
+            json.dump(splitter_dict, file)
+
+    @classmethod
+    def load(cls, path: str) -> "Splitter":
+        """
+        Method for loading splitter from `.replay` directory.
+        """
+        base_path = Path(path).with_suffix(".replay").resolve()
+        with open(base_path / "init_args.json", "r") as file:
+            splitter_dict = json.loads(file.read())
+        splitter = cls(**splitter_dict["init_args"])
+
+        return splitter
+
    def __str__(self):
        return type(self).__name__
 
-    # pylint: disable=too-many-arguments
    def _drop_cold_items_and_users(
        self,
        train: DataFrameLike,
        test: DataFrameLike,
    ) -> DataFrameLike:
        if isinstance(train, type(test)) is False:
-            raise TypeError("Train and test dataframes must have consistent types")
+            msg = "Train and test dataframes must have consistent types"
+            raise TypeError(msg)
 
        if isinstance(test, SparkDataFrame):
            return self._drop_cold_items_and_users_from_spark(train, test)
@@ -105,7 +133,6 @@ class Splitter(ABC):
        train: SparkDataFrame,
        test: SparkDataFrame,
    ) -> SparkDataFrame:
-
        if self.drop_cold_items:
            train_tmp = train.select(sf.col(self.item_column).alias("item")).distinct()
            test = test.join(train_tmp, train_tmp["item"] == test[self.item_column]).drop("item")
@@ -121,7 +148,6 @@ class Splitter(ABC):
        train: PolarsDataFrame,
        test: PolarsDataFrame,
    ) -> PolarsDataFrame:
-
        if self.drop_cold_items:
            train_tmp = train.select(self.item_column).unique()
            test = test.join(train_tmp, on=self.item_column)
@@ -164,9 +190,9 @@ class Splitter(ABC):
    def _recalculate_with_session_id_column_pandas(self, data: PandasDataFrame) -> PandasDataFrame:
        agg_function_name = "first" if self.session_id_processing_strategy == "train" else "last"
        res = data.copy()
-        res["is_test"] = res.groupby(
-            [self.query_column, self.session_id_column]
-        )["is_test"].transform(agg_function_name)
+        res["is_test"] = res.groupby([self.query_column, self.session_id_column])["is_test"].transform(
+            agg_function_name
+        )
 
        return res
 
@@ -176,7 +202,7 @@ class Splitter(ABC):
            "is_test",
            agg_function("is_test").over(
                Window.orderBy(self.timestamp_column)
-                .partitionBy(self.query_column, self.session_id_column)  # type: ignore
+                .partitionBy(self.query_column, self.session_id_column)
                .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
            ),
        )
@@ -186,7 +212,9 @@ class Splitter(ABC):
    def _recalculate_with_session_id_column_polars(self, data: PolarsDataFrame) -> PolarsDataFrame:
        agg_function = pl.Expr.first if self.session_id_processing_strategy == "train" else pl.Expr.last
        res = data.with_columns(
-            agg_function(pl.col("is_test").sort_by(self.timestamp_column))
-            .over([self.query_column, self.session_id_column]))
+            agg_function(pl.col("is_test").sort_by(self.timestamp_column)).over(
+                [self.query_column, self.session_id_column]
+            )
+        )
 
        return res
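
The new piece of functionality in this file is the ``save``/``load`` pair on the splitter base class, which persists the constructor arguments as JSON inside a ``.replay`` directory. A rough usage sketch, assuming a concrete splitter such as ``TimeSplitter`` from ``replay.splitters``; its constructor argument below is an assumption for illustration and is not taken from this diff.

    from replay.splitters import TimeSplitter

    splitter = TimeSplitter(time_threshold="2024-01-01")  # argument name/value are assumptions
    splitter.save("checkpoints/my_splitter")
    # -> writes checkpoints/my_splitter.replay/init_args.json
    #    containing {"init_args": {...}, "_class_name": "TimeSplitter"}

    # load() is a classmethod that calls cls(**init_args), so it must be invoked
    # on the same splitter class that produced the saved file.
    restored = TimeSplitter.load("checkpoints/my_splitter")

Note that only ``_init_args`` are serialized; the stored ``_class_name`` is written but not used to dispatch to the right class on load.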