replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/models/base_rec.py CHANGED
@@ -1,4 +1,3 @@
- # pylint: disable=too-many-lines
  """
  Base abstract classes:
  - BaseRecommender - the simplest base class
@@ -19,8 +18,8 @@ from copy import deepcopy
  from os.path import join
  from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union

- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy.random import default_rng
  from optuna import create_study
  from optuna.samplers import TPESampler
@@ -33,8 +32,10 @@ from replay.utils.session_handler import State
  from replay.utils.spark_utils import SparkCollectToMasterWarning

  if PYSPARK_AVAILABLE:
- from pyspark.sql import Window
- from pyspark.sql import functions as sf
+ from pyspark.sql import (
+ Window,
+ functions as sf,
+ )

  from replay.utils.spark_utils import (
  cache_temp_view,
@@ -53,7 +54,6 @@ if PYSPARK_AVAILABLE:
  )


- # pylint: disable=too-few-public-methods
  class IsSavable(ABC):
  """
  Common methods and attributes for saving and loading RePlay models
@@ -133,7 +133,7 @@ class RecommenderCommons:
  Create Spark SQL temporary view for df, cache it and add temp view name to self.cached_dfs.
  Temp view name is : "id_<python object id>_model_<RePlay model name>_<df_name>"
  """
- full_name = f"id_{id(self)}_model_{str(self)}_{df_name}"
+ full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
  cache_temp_view(df, full_name)

  if self.cached_dfs is None:
@@ -146,22 +146,19 @@ class RecommenderCommons:
  Temp view to replace will be constructed as
  "id_<python object id>_model_<RePlay model name>_<df_name>"
  """
- full_name = f"id_{id(self)}_model_{str(self)}_{df_name}"
+ full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
  drop_temp_view(full_name)
  if self.cached_dfs is not None:
  self.cached_dfs.discard(full_name)


- # pylint: disable=too-many-instance-attributes
  class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  """Base recommender"""

  model: Any
  can_predict_cold_queries: bool = False
  can_predict_cold_items: bool = False
- _search_space: Optional[
- Dict[str, Union[str, Sequence[Union[str, int, float]]]]
- ] = None
+ _search_space: Optional[Dict[str, Union[str, Sequence[Union[str, int, float]]]]] = None
  _objective = MainObjective
  study = None
  criterion = None
@@ -172,7 +169,6 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  _query_dim_size: int
  _item_dim_size: int

- # pylint: disable=too-many-arguments, too-many-locals, no-member
  def optimize(
  self,
  train_dataset: Dataset,
@@ -211,21 +207,14 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  )

  if self._search_space is None:
- self.logger.warning(
- "%s has no hyper parameters to optimize", str(self)
- )
+ self.logger.warning("%s has no hyper parameters to optimize", str(self))
  return None

  if self.study is None or new_study:
- self.study = create_study(
- direction="maximize", sampler=TPESampler()
- )
+ self.study = create_study(direction="maximize", sampler=TPESampler())

  search_space = self._prepare_param_borders(param_borders)
- if (
- self._init_params_in_search_space(search_space)
- and not self._params_tried()
- ):
+ if self._init_params_in_search_space(search_space) and not self._params_tried():
  self.study.enqueue_trial(self._init_args)

  split_data = self._prepare_split_data(train_dataset, test_dataset)
@@ -244,7 +233,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):

  def _init_params_in_search_space(self, search_space):
  """Check if model params are inside search space"""
- params = self._init_args # pylint: disable=no-member
+ params = self._init_args
  outside_search_space = {}
  for param, value in params.items():
  if param not in search_space:
@@ -252,12 +241,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  borders = search_space[param]["args"]
  param_type = search_space[param]["type"]

- extra_category = (
- param_type == "categorical" and value not in borders
- )
- param_out_of_bounds = param_type != "categorical" and (
- value < borders[0] or value > borders[1]
- )
+ extra_category = param_type == "categorical" and value not in borders
+ param_out_of_bounds = param_type != "categorical" and (value < borders[0] or value > borders[1])
  if extra_category or param_out_of_bounds:
  outside_search_space[param] = {
  "borders": borders,
@@ -299,11 +284,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  # If used didn't specify some params to be tested optuna still needs to suggest them
  # This part makes sure this suggestion will be constant
  args = self._init_args
- missing_borders = {
- param: args[param]
- for param in search_space
- if param not in param_borders
- }
+ missing_borders = {param: args[param] for param in search_space if param not in param_borders}
  for param, value in missing_borders.items():
  if search_space[param]["type"] == "categorical":
  search_space[param]["args"] = [value]
@@ -315,21 +296,14 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  def _check_borders(self, param, borders):
  """Raise value error if param borders are not valid"""
  if param not in self._search_space:
- raise ValueError(
- f"Hyper parameter {param} is not defined for {str(self)}"
- )
+ msg = f"Hyper parameter {param} is not defined for {self!s}"
+ raise ValueError(msg)
  if not isinstance(borders, list):
- raise ValueError(f"Parameter {param} borders are not a list")
- if (
- self._search_space[param]["type"] != "categorical"
- and len(borders) != 2
- ):
- raise ValueError(
- f"""
- Hyper parameter {param} is numerical
- but bounds are not in ([lower, upper]) format
- """
- )
+ msg = f"Parameter {param} borders are not a list"
+ raise ValueError(msg)
+ if self._search_space[param]["type"] != "categorical" and len(borders) != 2:
+ msg = f"Hyper parameter {param} is numerical but bounds are not in ([lower, upper]) format"
+ raise ValueError(msg)

  def _prepare_split_data(
  self,
@@ -373,16 +347,12 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  item_features = None
  if dataset.query_features is not None:
  query_features = dataset.query_features.join(
- dataset.interactions.select(
- dataset.feature_schema.query_id_column
- ).distinct(),
+ dataset.interactions.select(dataset.feature_schema.query_id_column).distinct(),
  on=dataset.feature_schema.query_id_column,
  )
  if dataset.item_features is not None:
  item_features = dataset.item_features.join(
- dataset.interactions.select(
- dataset.feature_schema.item_id_column
- ).distinct(),
+ dataset.interactions.select(dataset.feature_schema.item_id_column).distinct(),
  on=dataset.feature_schema.item_id_column,
  )

@@ -431,12 +401,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  self.fit_items = sf.broadcast(items)
  self._num_queries = self.fit_queries.count()
  self._num_items = self.fit_items.count()
- self._query_dim_size = (
- self.fit_queries.agg({self.query_column: "max"}).collect()[0][0] + 1
- )
- self._item_dim_size = (
- self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
- )
+ self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).collect()[0][0] + 1
+ self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
  self._fit(dataset)

  @abstractmethod
@@ -452,18 +418,14 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  :return:
  """

- def _filter_seen(
- self, recs: SparkDataFrame, interactions: SparkDataFrame, k: int, queries: SparkDataFrame
- ):
+ def _filter_seen(self, recs: SparkDataFrame, interactions: SparkDataFrame, k: int, queries: SparkDataFrame):
  """
  Filter seen items (presented in interactions) out of the queries' recommendations.
  For each query return from `k` to `k + number of seen by query` recommendations.
  """
  queries_interactions = interactions.join(queries, on=self.query_column)
  self._cache_model_temp_view(queries_interactions, "filter_seen_queries_interactions")
- num_seen = queries_interactions.groupBy(self.query_column).agg(
- sf.count(self.item_column).alias("seen_count")
- )
+ num_seen = queries_interactions.groupBy(self.query_column).agg(sf.count(self.item_column).alias("seen_count"))
  self._cache_model_temp_view(num_seen, "filter_seen_num_seen")

  # count maximal number of items seen by queries
@@ -474,11 +436,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  # crop recommendations to first k + max_seen items for each query
  recs = recs.withColumn(
  "temp_rank",
- sf.row_number().over(
- Window.partitionBy(self.query_column).orderBy(
- sf.col(self.rating_column).desc()
- )
- ),
+ sf.row_number().over(Window.partitionBy(self.query_column).orderBy(sf.col(self.rating_column).desc())),
  ).filter(sf.col("temp_rank") <= sf.lit(max_seen + k))

  # leave k + number of items seen by query recommendations in recs
@@ -494,8 +452,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  queries_interactions.withColumnRenamed(self.item_column, "item")
  .withColumnRenamed(self.query_column, "query")
  .select("query", "item"),
- on=(sf.col(self.query_column) == sf.col("query"))
- & (sf.col(self.item_column) == sf.col("item")),
+ on=(sf.col(self.query_column) == sf.col("query")) & (sf.col(self.item_column) == sf.col("item")),
  how="anti",
  ).drop("query", "item")

@@ -556,7 +513,6 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  )
  return dataset, queries, items

- # pylint: disable=too-many-arguments
  def _predict_wrap(
  self,
  dataset: Optional[Dataset],
@@ -589,9 +545,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  :return: cached recommendation dataframe with columns ``[user_idx, item_idx, rating]``
  or None if `file_path` is provided
  """
- dataset, queries, items = self._filter_interactions_queries_items_dataframes(
- dataset, k, queries, items
- )
+ dataset, queries, items = self._filter_interactions_queries_items_dataframes(dataset, k, queries, items)

  recs = self._predict(
  dataset,
@@ -630,21 +584,16 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  if can_predict_cold:
  return main_df, interactions_df

- num_new, main_df = filter_cold(
- main_df, fit_entities, col_name=column
- )
+ num_new, main_df = filter_cold(main_df, fit_entities, col_name=column)
  if num_new > 0:
  self.logger.info(
  "%s model can't predict cold %ss, they will be ignored",
  self,
  entity,
  )
- _, interactions_df = filter_cold(
- interactions_df, fit_entities, col_name=column
- )
+ _, interactions_df = filter_cold(interactions_df, fit_entities, col_name=column)
  return main_df, interactions_df

- # pylint: disable=too-many-arguments
  @abstractmethod
  def _predict(
  self,
@@ -673,12 +622,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  """

  def _predict_proba(
- self,
- dataset : Dataset,
- k: int,
- queries: SparkDataFrame,
- items: SparkDataFrame,
- filter_seen_items: bool = True
+ self, dataset: Dataset, k: int, queries: SparkDataFrame, items: SparkDataFrame, filter_seen_items: bool = True
  ) -> np.ndarray:
  """
  Inner method where model actually predicts.
@@ -706,11 +650,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  n_users = queries.select("user_idx").count()
  n_items = items.select("item_idx").count()

- recs = self._predict(dataset,
- k,
- queries,
- items,
- filter_seen_items)
+ recs = self._predict(dataset, k, queries, items, filter_seen_items)

  recs = get_top_k_recs(recs, k=k, query_column=self.query_column, rating_column=self.rating_column).select(
  self.query_column, self.item_column, self.rating_column
@@ -718,17 +658,20 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):

  cols = [f"k{i}" for i in range(k)]

- recs_items = recs.groupBy("user_idx").agg(
- sf.collect_list("item_idx").alias("item_idx")).select(
- [sf.col("item_idx")[i].alias(cols[i]) for i in range(k)]
+ recs_items = (
+ recs.groupBy("user_idx")
+ .agg(sf.collect_list("item_idx").alias("item_idx"))
+ .select([sf.col("item_idx")[i].alias(cols[i]) for i in range(k)])
  )

  action_dist = np.zeros(shape=(n_users, n_items, k))

  for i in range(k):
- action_dist[np.arange(n_users),
- recs_items.select(cols[i]).toPandas()[cols[i]].to_numpy(),
- np.ones(n_users, dtype=int) * i] += 1
+ action_dist[
+ np.arange(n_users),
+ recs_items.select(cols[i]).toPandas()[cols[i]].to_numpy(),
+ np.ones(n_users, dtype=int) * i,
+ ] += 1

  return action_dist

@@ -765,10 +708,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  setattr(
  self,
  dim_size,
- fit_entities
- .agg({column: "max"})
- .collect()[0][0]
- + 1,
+ fit_entities.agg({column: "max"}).collect()[0][0] + 1,
  )
  return getattr(self, dim_size)

@@ -829,13 +769,11 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  """
  if dataset is not None:
  interactions, query_features, item_features, pairs = [
- convert2spark(df)
- for df in [dataset.interactions, dataset.query_features, dataset.item_features, pairs]
+ convert2spark(df) for df in [dataset.interactions, dataset.query_features, dataset.item_features, pairs]
  ]
- if set(pairs.columns) != set([self.item_column, self.query_column]):
- raise ValueError(
- "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
- )
+ if set(pairs.columns) != {self.item_column, self.query_column}:
+ msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
+ raise ValueError(msg)
  pairs, interactions = self._filter_cold_for_predict(pairs, interactions, "query")
  pairs, interactions = self._filter_cold_for_predict(pairs, interactions, "item")

@@ -908,13 +846,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
  ) -> Optional[Tuple[SparkDataFrame, int]]:
  if self.query_column not in ids.columns and self.item_column not in ids.columns:
- raise ValueError(f"{self.query_column} or {self.item_column} missing")
+ msg = f"{self.query_column} or {self.item_column} missing"
+ raise ValueError(msg)
  vectors, rank = self._get_features(ids, features)
  return vectors, rank

- # pylint: disable=unused-argument
  def _get_features(
- self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
+ self, ids: SparkDataFrame, features: Optional[SparkDataFrame] # noqa: ARG002
  ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
  """
  Get embeddings from model
@@ -961,39 +899,26 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  k=k,
  )

- nearest_items = nearest_items.withColumnRenamed(
- "item_idx_two", "neighbour_item_idx"
- )
- nearest_items = nearest_items.withColumnRenamed(
- "item_idx_one", self.item_column
- )
+ nearest_items = nearest_items.withColumnRenamed("item_idx_two", "neighbour_item_idx")
+ nearest_items = nearest_items.withColumnRenamed("item_idx_one", self.item_column)
  return nearest_items

  def _get_nearest_items(
  self,
- items: SparkDataFrame,
- metric: Optional[str] = None,
- candidates: Optional[SparkDataFrame] = None,
+ items: SparkDataFrame, # noqa: ARG002
+ metric: Optional[str] = None, # noqa: ARG002
+ candidates: Optional[SparkDataFrame] = None, # noqa: ARG002
  ) -> Optional[SparkDataFrame]:
- raise NotImplementedError(
- f"item-to-item prediction is not implemented for {self}"
- )
+ msg = f"item-to-item prediction is not implemented for {self}"
+ raise NotImplementedError(msg)

  def _params_tried(self):
  """check if current parameters were already evaluated"""
  if self.study is None:
  return False

- params = {
- name: value
- for name, value in self._init_args.items()
- if name in self._search_space
- }
- for trial in self.study.trials:
- if params == trial.params:
- return True
-
- return False
+ params = {name: value for name, value in self._init_args.items() if name in self._search_space}
+ return any(params == trial.params for trial in self.study.trials)

  def _save_model(self, path: str, additional_params: Optional[dict] = None):
  saved_params = {
@@ -1004,10 +929,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  }
  if additional_params is not None:
  saved_params.update(additional_params)
- save_picklable_to_parquet(
- saved_params,
- join(path, "params.dump")
- )
+ save_picklable_to_parquet(saved_params, join(path, "params.dump"))

  def _load_model(self, path: str):
  loaded_params = load_pickled_from_parquet(join(path, "params.dump"))
@@ -1053,10 +975,8 @@ class ItemVectorModel(BaseRecommender):
  spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
  """
  if metric not in self.item_to_item_metrics:
- raise ValueError(
- f"Select one of the valid distance metrics: "
- f"{self.item_to_item_metrics}"
- )
+ msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
+ raise ValueError(msg)

  return self._get_nearest_items_wrap(
  items=items,
@@ -1098,9 +1018,9 @@ class ItemVectorModel(BaseRecommender):
  )
  )

- right_part = items_vectors.withColumnRenamed(
- self.item_column, "item_idx_two"
- ).withColumnRenamed("item_vector", "item_vector_two")
+ right_part = items_vectors.withColumnRenamed(self.item_column, "item_idx_two").withColumnRenamed(
+ "item_vector", "item_vector_two"
+ )

  if candidates is not None:
  right_part = right_part.join(
@@ -1108,25 +1028,18 @@ class ItemVectorModel(BaseRecommender):
  on="item_idx_two",
  )

- joined_factors = left_part.join(
- right_part, on=sf.col("item_idx_one") != sf.col("item_idx_two")
- )
+ joined_factors = left_part.join(right_part, on=sf.col("item_idx_one") != sf.col("item_idx_two"))

  joined_factors = joined_factors.withColumn(
  metric,
- dist_function(
- sf.col("item_vector_one"), sf.col("item_vector_two")
- ),
+ dist_function(sf.col("item_vector_one"), sf.col("item_vector_two")),
  )

- similarity_matrix = joined_factors.select(
- "item_idx_one", "item_idx_two", metric
- )
+ similarity_matrix = joined_factors.select("item_idx_one", "item_idx_two", metric)

  return similarity_matrix


- # pylint: disable=abstract-method
  class HybridRecommender(BaseRecommender, ABC):
  """Base class for models that can use extra features"""

@@ -1143,7 +1056,6 @@ class HybridRecommender(BaseRecommender, ABC):
  """
  self._fit_wrap(dataset=dataset)

- # pylint: disable=too-many-arguments
  def predict(
  self,
  dataset: Dataset,
@@ -1260,7 +1172,6 @@ class HybridRecommender(BaseRecommender, ABC):
  return self._get_features_wrap(ids, features)


- # pylint: disable=abstract-method
  class Recommender(BaseRecommender, ABC):
  """Usual recommender class for models without features."""

@@ -1274,7 +1185,6 @@ class Recommender(BaseRecommender, ABC):
  """
  self._fit_wrap(dataset=dataset)

- # pylint: disable=too-many-arguments
  def predict(
  self,
  dataset: Dataset,
@@ -1340,7 +1250,6 @@ class Recommender(BaseRecommender, ABC):
  k=k,
  )

- # pylint: disable=too-many-arguments
  def fit_predict(
  self,
  dataset: Dataset,
@@ -1406,7 +1315,6 @@ class QueryRecommender(BaseRecommender, ABC):
  """
  self._fit_wrap(dataset=dataset)

- # pylint: disable=too-many-arguments
  def predict(
  self,
  dataset: Dataset,
@@ -1436,7 +1344,8 @@ class QueryRecommender(BaseRecommender, ABC):
  or None if `file_path` is provided
  """
  if not dataset or not dataset.query_features:
- raise ValueError("Query features are missing for predict")
+ msg = "Query features are missing for predict"
+ raise ValueError(msg)

  return self._predict_wrap(
  dataset=dataset,
@@ -1469,7 +1378,8 @@ class QueryRecommender(BaseRecommender, ABC):
  or None if `file_path` is provided
  """
  if not dataset or not dataset.query_features:
- raise ValueError("Query features are missing for predict")
+ msg = "Query features are missing for predict"
+ raise ValueError(msg)

  return self._predict_pairs_wrap(
  pairs=pairs,
@@ -1496,15 +1406,14 @@ class NonPersonalizedRecommender(Recommender, ABC):
  if 0 < cold_weight <= 1:
  self.cold_weight = cold_weight
  else:
- raise ValueError(
- "`cold_weight` value should be in interval (0, 1]"
- )
+ msg = "`cold_weight` value should be in interval (0, 1]"
+ raise ValueError(msg)

  @property
  def _dataframes(self):
  return {"item_popularity": self.item_popularity}

- def _save_model(self, path: str, additional_params: Optional[dict] = None):
+ def _save_model(self, path: str, additional_params: Optional[dict] = None): # noqa: ARG002
  super()._save_model(path, additional_params={"fill": self.fill})

  def _clear_cache(self):
@@ -1517,10 +1426,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  Calculating a fill value a the minimal rating
  calculated during model training multiplied by weight.
  """
- return (
- item_popularity.select(sf.min(rating_column)).collect()[0][0]
- * weight
- )
+ return item_popularity.select(sf.min(rating_column)).collect()[0][0] * weight

  @staticmethod
  def _check_rating(dataset: Dataset):
@@ -1529,7 +1435,8 @@ class NonPersonalizedRecommender(Recommender, ABC):
  (sf.col(rating_column) != 1) & (sf.col(rating_column) != 0)
  )
  if vals.count() > 0:
- raise ValueError("Rating values in interactions must be 0 or 1")
+ msg = "Rating values in interactions must be 0 or 1"
+ raise ValueError(msg)

  def _get_selected_item_popularity(self, items: SparkDataFrame) -> SparkDataFrame:
  """
@@ -1561,7 +1468,6 @@

  return max_hist_len

- # pylint: disable=too-many-arguments
  def _predict_without_sampling(
  self,
  dataset: Dataset,
@@ -1577,11 +1483,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  selected_item_popularity = self._get_selected_item_popularity(items)
  selected_item_popularity = selected_item_popularity.withColumn(
  "rank",
- sf.row_number().over(
- Window.orderBy(
- sf.col(self.rating_column).desc(), sf.col(self.item_column).desc()
- )
- ),
+ sf.row_number().over(Window.orderBy(sf.col(self.rating_column).desc(), sf.col(self.item_column).desc())),
  )

  if filter_seen_items and dataset is not None:
@@ -1594,15 +1496,10 @@ class NonPersonalizedRecommender(Recommender, ABC):
  queries = queries.fillna(0, "num_items")
  # 'selected_item_popularity' truncation by k + max_seen
  max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
- selected_item_popularity = selected_item_popularity\
- .filter(sf.col("rank") <= k + max_seen)
- return queries.join(
- selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left"
- )
+ selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
+ return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")

- return queries.crossJoin(
- selected_item_popularity.filter(sf.col("rank") <= k)
- ).drop("rank")
+ return queries.crossJoin(selected_item_popularity.filter(sf.col("rank") <= k)).drop("rank")

  def get_items_pd(self, items: SparkDataFrame) -> pd.DataFrame:
  """
@@ -1612,26 +1509,22 @@ class NonPersonalizedRecommender(Recommender, ABC):
  selected_item_popularity = self._get_selected_item_popularity(items)
  selected_item_popularity = selected_item_popularity.withColumn(
  self.rating_column,
- sf.when(sf.col(self.rating_column) == sf.lit(0.0), 0.1**6).otherwise(
- sf.col(self.rating_column)
- ),
+ sf.when(sf.col(self.rating_column) == sf.lit(0.0), 0.1**6).otherwise(sf.col(self.rating_column)),
  )

  warnings.warn(
  "Prediction with sampling performs spark to pandas convertion to master node, "
  "this may lead to OOM exception for large item catalogue.",
- SparkCollectToMasterWarning
+ SparkCollectToMasterWarning,
  )

  items_pd = selected_item_popularity.withColumn(
  "probability",
- sf.col(self.rating_column)
- / selected_item_popularity.select(sf.sum(self.rating_column)).first()[0],
+ sf.col(self.rating_column) / selected_item_popularity.select(sf.sum(self.rating_column)).first()[0],
  ).toPandas()

  return items_pd

- # pylint: disable=too-many-locals
  def _predict_with_sampling(
  self,
  dataset: Dataset,
@@ -1667,10 +1560,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  query_idx = pandas_df[query_column][0]
  cnt = pandas_df["cnt"][0]

- if seed is not None:
- local_rng = default_rng(seed + query_idx)
- else:
- local_rng = default_rng()
+ local_rng = default_rng(seed + query_idx) if seed is not None else default_rng()

  items_positions = local_rng.choice(
  np.arange(items_pd.shape[0]),
@@ -1716,7 +1606,6 @@ class NonPersonalizedRecommender(Recommender, ABC):

  return recs.groupby(self.query_column).applyInPandas(grouped_map, rec_schema)

- # pylint: disable=too-many-arguments
  def _predict(
  self,
  dataset: Dataset,
@@ -1725,7 +1614,6 @@ class NonPersonalizedRecommender(Recommender, ABC):
  items: SparkDataFrame,
  filter_seen_items: bool = True,
  ) -> SparkDataFrame:
-
  if self.sample:
  return self._predict_with_sampling(
  dataset=dataset,
@@ -1735,14 +1623,12 @@ class NonPersonalizedRecommender(Recommender, ABC):
  filter_seen_items=filter_seen_items,
  )
  else:
- return self._predict_without_sampling(
- dataset, k, queries, items, filter_seen_items
- )
+ return self._predict_without_sampling(dataset, k, queries, items, filter_seen_items)

  def _predict_pairs(
  self,
  pairs: SparkDataFrame,
- dataset: Optional[Dataset] = None,
+ dataset: Optional[Dataset] = None, # noqa: ARG002
  ) -> SparkDataFrame:
  return (
  pairs.join(
@@ -1755,12 +1641,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  )

  def _predict_proba(
- self,
- dataset : Dataset,
- k: int,
- queries: SparkDataFrame,
- items: SparkDataFrame,
- filter_seen_items: bool = True
+ self, dataset: Dataset, k: int, queries: SparkDataFrame, items: SparkDataFrame, filter_seen_items: bool = True
  ) -> np.ndarray:
  """
  Inner method where model actually predicts.
@@ -1799,8 +1680,4 @@ class NonPersonalizedRecommender(Recommender, ABC):

  return np.tile(items_pd, (n_users, k)).reshape(n_users, k, n_items).transpose((0, 2, 1))

- return super()._predict_proba(dataset,
- k,
- queries,
- items,
- filter_seen_items)
+ return super()._predict_proba(dataset, k, queries, items, filter_seen_items)