replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/metrics/offline_metrics.py CHANGED
@@ -1,7 +1,7 @@
  import warnings
  from typing import Dict, List, Optional, Tuple, Union

- from replay.utils import PandasDataFrame, SparkDataFrame, PolarsDataFrame
+ from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

  from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType
  from .coverage import Coverage
@@ -10,7 +10,6 @@ from .recall import Recall
  from .surprisal import Surprisal


- # pylint: disable=too-few-public-methods
  class OfflineMetrics:
      """
      Designed for efficient calculation of offline metrics provided by the RePlay.
@@ -146,7 +145,6 @@ class OfflineMetrics:
          "Recall": ["ground_truth"],
      }

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          metrics: List[Metric],
@@ -220,15 +218,11 @@ class OfflineMetrics:
              default_metric._check_duplicates_polars(recommendations)
          unchanged_recs = recommendations

-         # pylint: disable=too-many-function-args
-         result_dict["default"] = default_metric._get_enriched_recommendations(
-             recommendations, ground_truth
-         )
+         result_dict["default"] = default_metric._get_enriched_recommendations(recommendations, ground_truth)

          for metric in self.metrics:
              # find Coverage
              if metric.__class__.__name__ == "Coverage":
-                 # pylint: disable=protected-access
                  result_dict["Coverage"] = Coverage(
                      topk=2,
                      query_column=query_column,
@@ -244,9 +238,7 @@ class OfflineMetrics:
                      item_column=item_column,
                      rating_column=rating_column,
                  )
-                 cur_recs = novelty_metric._get_enriched_recommendations(
-                     unchanged_recs, train
-                 )
+                 cur_recs = novelty_metric._get_enriched_recommendations(unchanged_recs, train)
                  if is_spark:
                      cur_recs = cur_recs.withColumnRenamed("ground_truth", "train")
                  else:
@@ -265,12 +257,10 @@ class OfflineMetrics:

          return result_dict, train

-     # pylint: disable=no-self-use
      def _cache_dataframes(self, dataframes: Dict[str, SparkDataFrame]) -> None:
          for data in dataframes.values():
              data.cache()

-     # pylint: disable=no-self-use
      def _unpersist_dataframes(self, dataframes: Dict[str, SparkDataFrame]) -> None:
          for data in dataframes.values():
              data.unpersist()
@@ -294,22 +284,18 @@ class OfflineMetrics:
              else:
                  metric_args["recs"] = enriched_recs_dict["default"]

-             # pylint: disable=protected-access
              if is_spark:
                  result.update(metric._spark_compute(**metric_args))
              else:
                  result.update(metric._polars_compute(**metric_args))
          return result

-     # pylint: disable=no-self-use
      def _check_dataframes_types(
          self,
          recommendations: MetricsDataFrameLike,
          ground_truth: MetricsDataFrameLike,
          train: Optional[MetricsDataFrameLike],
-         base_recommendations: Optional[
-             Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]
-         ],
+         base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]],
      ) -> None:
          types = set()
          types.add(type(recommendations))
@@ -317,7 +303,7 @@ class OfflineMetrics:
          if train is not None:
              types.add(type(train))
          if isinstance(base_recommendations, dict):
-             for _, df in base_recommendations.items():
+             for df in base_recommendations.values():
                  if not isinstance(df, list):
                      types.add(type(df))
                  else:
@@ -327,7 +313,8 @@ class OfflineMetrics:
              types.add(type(base_recommendations))

          if len(types) != 1:
-             raise ValueError("All given data frames must have the same type")
+             msg = "All given data frames must have the same type"
+             raise ValueError(msg)

      def _check_query_column_present(
          self,
@@ -350,7 +337,8 @@ class OfflineMetrics:
              dataset_names = dataset.columns

          if not isinstance(dataset, dict) and query_column not in dataset_names:
-             raise KeyError(f"Query column {query_column} is not present in {dataset_name} dataframe")
+             msg = f"Query column {query_column} is not present in {dataset_name} dataframe"
+             raise KeyError(msg)

      def _get_unique_queries(
          self,
@@ -386,14 +374,12 @@ class OfflineMetrics:
          if queries.issubset(other_queries) is False:
              warnings.warn(f"{dataset_name} contains queries that are not presented in recommendations")

-     def __call__(  # pylint: disable=too-many-branches, too-many-locals, too-many-statements
+     def __call__(  # noqa: C901
          self,
          recommendations: MetricsDataFrameLike,
          ground_truth: MetricsDataFrameLike,
          train: Optional[MetricsDataFrameLike] = None,
-         base_recommendations: Optional[
-             Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]
-         ] = None,
+         base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]] = None,
      ) -> Dict[str, float]:
          """
          Compute metrics.
@@ -424,9 +410,7 @@ class OfflineMetrics:

          :return: metric values
          """
-         self._check_dataframes_types(
-             recommendations, ground_truth, train, base_recommendations
-         )
+         self._check_dataframes_types(recommendations, ground_truth, train, base_recommendations)

          if len(self.main_metrics) > 0:
              query_column = self.main_metrics[0].query_column
@@ -443,31 +427,22 @@ class OfflineMetrics:

          if train is not None:
              self._check_query_column_present(train, query_column, "train")
-             self._check_contains(
-                 recs_queries,
-                 self._get_unique_queries(train, query_column),
-                 "train"
-             )
+             self._check_contains(recs_queries, self._get_unique_queries(train, query_column), "train")
          if base_recommendations is not None:
-             if (not isinstance(base_recommendations, dict)
-                     or isinstance(next(iter(base_recommendations.values())), list)):
+             if not isinstance(base_recommendations, dict) or isinstance(
+                 next(iter(base_recommendations.values())), list
+             ):
                  base_recommendations = {"base_recommendations": base_recommendations}
              for name, dataset in base_recommendations.items():
                  self._check_query_column_present(dataset, query_column, name)
-                 self._check_contains(
-                     recs_queries,
-                     self._get_unique_queries(dataset, query_column),
-                     name
-                 )
+                 self._check_contains(recs_queries, self._get_unique_queries(dataset, query_column), name)

          result = {}
          if isinstance(recommendations, (SparkDataFrame, PolarsDataFrame)):
              is_spark = isinstance(recommendations, SparkDataFrame)
              assert isinstance(ground_truth, type(recommendations))
              assert train is None or isinstance(train, type(recommendations))
-             enriched_recs_dict, train = self._get_enriched_recommendations(
-                 recommendations, ground_truth, train
-             )
+             enriched_recs_dict, train = self._get_enriched_recommendations(recommendations, ground_truth, train)

              if is_spark and self._allow_caching:
                  self._cache_dataframes(enriched_recs_dict)
@@ -480,12 +455,8 @@ class OfflineMetrics:
                  "train": train,
              }
              for metric in self.metrics:
-                 args_to_call: Dict[str, Union[PandasDataFrame, Dict]] = {
-                     "recommendations": recommendations
-                 }
-                 for data_name in self._metrics_call_requirement_map[
-                     str(metric.__class__.__name__)
-                 ]:
+                 args_to_call: Dict[str, Union[PandasDataFrame, Dict]] = {"recommendations": recommendations}
+                 for data_name in self._metrics_call_requirement_map[str(metric.__class__.__name__)]:
                      args_to_call[data_name] = current_map[data_name]
                  result.update(metric(**args_to_call))
          unexpectedness_result = {}
@@ -493,23 +464,17 @@ class OfflineMetrics:

          if len(self.unexpectedness_metric) != 0:
              if base_recommendations is None:
-                 raise ValueError(
-                     "Can not calculate Unexpectedness because base_recommendations is None"
-                 )
-             if isinstance(base_recommendations, dict) and not isinstance(
-                 list(base_recommendations.values())[0], list
-             ):
+                 msg = "Can not calculate Unexpectedness because base_recommendations is None"
+                 raise ValueError(msg)
+             first_element = next(iter(base_recommendations.values()))
+             if isinstance(base_recommendations, dict) and not isinstance(first_element, list):
                  for unexp in self.unexpectedness_metric:
                      for model_name in base_recommendations:
-                         cur_result = unexp(
-                             recommendations, base_recommendations[model_name]
-                         )
+                         cur_result = unexp(recommendations, base_recommendations[model_name])
                          for metric_name in cur_result:
                              splitted = metric_name.split("@")
                              splitted[0] += "_" + model_name
-                             unexpectedness_result["@".join(splitted)] = cur_result[
-                                 metric_name
-                             ]
+                             unexpectedness_result["@".join(splitted)] = cur_result[metric_name]

          if len(self.diversity_metric) != 0:
              for diversity in self.diversity_metric:
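
For orientation, below is a minimal usage sketch of the OfflineMetrics entry point whose __call__ signature is tidied in the hunks above. It is an assumption-heavy sketch, not taken from the package docs: it assumes pandas inputs with the default query_id/item_id/rating columns, that the metric classes are importable from replay.metrics, and that each metric accepts a topk argument the way Coverage(topk=2, ...) does in this diff; exact constructor signatures may differ.

    # Hypothetical usage sketch of OfflineMetrics; argument names are assumptions.
    import pandas as pd

    from replay.metrics import Coverage, OfflineMetrics, Precision, Recall, Unexpectedness

    recommendations = pd.DataFrame(
        {"query_id": [1, 1, 2, 2], "item_id": [3, 7, 5, 8], "rating": [0.9, 0.5, 0.8, 0.3]}
    )
    ground_truth = pd.DataFrame({"query_id": [1, 2], "item_id": [7, 6]})
    train = pd.DataFrame({"query_id": [1, 1, 2], "item_id": [1, 2, 3]})
    base_recommendations = pd.DataFrame(
        {"query_id": [1, 1, 2, 2], "item_id": [3, 4, 5, 6], "rating": [1.0, 0.7, 0.6, 0.2]}
    )

    metrics = OfflineMetrics(
        [Precision(topk=2), Recall(topk=2), Coverage(topk=2), Unexpectedness(topk=2)]
    )
    # train feeds Coverage/Novelty-style metrics and base_recommendations feeds
    # Unexpectedness, mirroring the branches of __call__ shown above.
    result = metrics(
        recommendations,
        ground_truth,
        train=train,
        base_recommendations=base_recommendations,
    )
    print(result)  # e.g. {"Precision@2": ..., "Recall@2": ..., ...}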
replay/metrics/precision.py CHANGED
@@ -3,16 +3,15 @@ from typing import List
  from .base_metric import Metric


- # pylint: disable=too-few-public-methods
  class Precision(Metric):
      """
      Mean percentage of relevant items among top ``K`` recommendations.

      .. math::
-         Precision@K(i) = \\frac {\sum_{j=1}^{K}\mathbb{1}_{r_{ij}}}{K}
+         Precision@K(i) = \\frac {\\sum_{j=1}^{K}\\mathbb{1}_{r_{ij}}}{K}

      .. math::
-         Precision@K = \\frac {\sum_{i=1}^{N}Precision@K(i)}{N}
+         Precision@K = \\frac {\\sum_{i=1}^{N}Precision@K(i)}{N}

      :math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing that user :math:`i` interacted with item :math:`j`

@@ -62,9 +61,7 @@ class Precision(Metric):
      """

      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], ground_truth: List, pred: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
          if not ground_truth or not pred:
              return [0.0 for _ in ks]
          set_gt = set(ground_truth)
replay/metrics/recall.py CHANGED
@@ -3,7 +3,6 @@ from typing import List
  from .base_metric import Metric


- # pylint: disable=too-few-public-methods
  class Recall(Metric):
      """
      Recall measures the coverage of the recommended items, and is defined as:
@@ -11,10 +10,10 @@ class Recall(Metric):
      Mean percentage of relevant items, that was shown among top ``K`` recommendations.

      .. math::
-         Recall@K(i) = \\frac {\sum_{j=1}^{K}\mathbb{1}_{r_{ij}}}{|Rel_i|}
+         Recall@K(i) = \\frac {\\sum_{j=1}^{K}\\mathbb{1}_{r_{ij}}}{|Rel_i|}

      .. math::
-         Recall@K = \\frac {\sum_{i=1}^{N}Recall@K(i)}{N}
+         Recall@K = \\frac {\\sum_{i=1}^{N}Recall@K(i)}{N}

      :math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing that user :math:`i` interacted with item :math:`j`

@@ -66,9 +65,7 @@ class Recall(Metric):
      """

      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], ground_truth: List, pred: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
          if not ground_truth or not pred:
              return [0.0 for _ in ks]
          set_gt = set(ground_truth)
replay/metrics/rocauc.py CHANGED
@@ -3,7 +3,6 @@ from typing import List
  from .base_metric import Metric


- # pylint: disable=too-few-public-methods
  class RocAuc(Metric):
      """
      Receiver Operating Characteristic/Area Under the Curve is the aggregated performance measure,
@@ -13,21 +12,21 @@ class RocAuc(Metric):
      The bigger the value of AUC, the better the classification model.

      .. math::
-         ROCAUC@K(i) = \\frac {\sum_{s=1}^{K}\sum_{t=1}^{K}
-         \mathbb{1}_{r_{si}<r_{ti}}
-         \mathbb{1}_{gt_{si}<gt_{ti}}}
-         {\sum_{s=1}^{K}\sum_{t=1}^{K} \mathbb{1}_{gt_{si}<gt_{tj}}}
+         ROCAUC@K(i) = \\frac {\\sum_{s=1}^{K}\\sum_{t=1}^{K}
+         \\mathbb{1}_{r_{si}<r_{ti}}
+         \\mathbb{1}_{gt_{si}<gt_{ti}}}
+         {\\sum_{s=1}^{K}\\sum_{t=1}^{K} \\mathbb{1}_{gt_{si}<gt_{tj}}}

      :math:`\\mathbb{1}_{r_{si}<r_{ti}}` -- indicator function showing that recommendation score for
      user :math:`i` for item :math:`s` is bigger than for item :math:`t`

-     :math:`\mathbb{1}_{gt_{si}<gt_{ti}}` -- indicator function showing that
+     :math:`\\mathbb{1}_{gt_{si}<gt_{ti}}` -- indicator function showing that
      user :math:`i` values item :math:`s` more than item :math:`t`.

      Metric is averaged by all users.

      .. math::
-         ROCAUC@K = \\frac {\sum_{i=1}^{N}ROCAUC@K(i)}{N}
+         ROCAUC@K = \\frac {\\sum_{i=1}^{N}ROCAUC@K(i)}{N}

      >>> recommendations
         query_id item_id rating
@@ -75,9 +74,7 @@ class RocAuc(Metric):
      """

      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], ground_truth: List, pred: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
          if not ground_truth or not pred:
              return [0.0 for _ in ks]
          set_gt = set(ground_truth)
replay/metrics/surprisal.py CHANGED
@@ -4,7 +4,7 @@ from typing import Dict, List, Union
  import numpy as np
  import polars as pl

- from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame, PolarsDataFrame
+ from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame

  from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType

@@ -12,13 +12,12 @@ if PYSPARK_AVAILABLE:
      from pyspark.sql import functions as sf


- # pylint: disable=too-few-public-methods
  class Surprisal(Metric):
      """
      Measures how many surprising rare items are present in recommendations.

      .. math::
-         \\textit{Self-Information}(j)= -\log_2 \\frac {u_j}{N}
+         \\textit{Self-Information}(j)= -\\log_2 \\frac {u_j}{N}

      :math:`u_j` -- number of users that interacted with item :math:`j`.
      Cold items are treated as if they were rated by 1 user.
@@ -32,12 +31,12 @@ class Surprisal(Metric):
      Recommendation list surprisal is the average surprisal of items in it.

      .. math::
-         Surprisal@K(i) = \\frac {\sum_{j=1}^{K}Surprisal(j)} {K}
+         Surprisal@K(i) = \\frac {\\sum_{j=1}^{K}Surprisal(j)} {K}

      Final metric is averaged by users.

      .. math::
-         Surprisal@K = \\frac {\sum_{i=1}^{N}Surprisal@K(i)}{N}
+         Surprisal@K = \\frac {\\sum_{i=1}^{N}Surprisal@K(i)}{N}

      :math:`N` -- the number of users.

@@ -83,7 +82,6 @@ class Surprisal(Metric):
      <BLANKLINE>
      """

-     # pylint: disable=no-self-use
      def _get_weights(self, train: Dict) -> Dict:
          n_users = len(train.keys())
          items_counter = defaultdict(set)
@@ -102,7 +100,6 @@ class Surprisal(Metric):
              recs_with_weights[user] = [weights.get(i, 1) for i in items]
          return recs_with_weights

-     # pylint: disable=arguments-renamed
      def _get_enriched_recommendations(
          self,
          recommendations: Union[PolarsDataFrame, SparkDataFrame],
@@ -113,38 +110,28 @@ class Surprisal(Metric):
          else:
              return self._get_enriched_recommendations_polars(recommendations, train)

-     def _get_enriched_recommendations_spark(  # pylint: disable=arguments-renamed
+     def _get_enriched_recommendations_spark(
          self, recommendations: SparkDataFrame, train: SparkDataFrame
      ) -> SparkDataFrame:
          n_users = train.select(self.query_column).distinct().count()
          item_weights = train.groupby(self.item_column).agg(
-             (
-                 sf.log2(n_users / sf.countDistinct(self.query_column)) / np.log2(n_users)
-             ).alias("weight")
+             (sf.log2(n_users / sf.countDistinct(self.query_column)) / np.log2(n_users)).alias("weight")
          )
-         recommendations = recommendations.join(
-             item_weights, on=self.item_column, how="left"
-         ).fillna(1.0)
+         recommendations = recommendations.join(item_weights, on=self.item_column, how="left").fillna(1.0)

-         sorted_by_score_recommendations = self._get_items_list_per_user(
-             recommendations, "weight"
-         )
+         sorted_by_score_recommendations = self._get_items_list_per_user(recommendations, "weight")
          return self._rearrange_columns(sorted_by_score_recommendations)

-     def _get_enriched_recommendations_polars(  # pylint: disable=arguments-renamed
+     def _get_enriched_recommendations_polars(
          self, recommendations: PolarsDataFrame, train: PolarsDataFrame
      ) -> PolarsDataFrame:
          n_users = train.select(self.query_column).n_unique()
          item_weights = train.group_by(self.item_column).agg(
              (np.log2(n_users / pl.col(self.query_column).n_unique()) / np.log2(n_users)).alias("weight")
          )
-         recommendations = recommendations.join(
-             item_weights, on=self.item_column, how="left"
-         ).fill_nan(1.0)
+         recommendations = recommendations.join(item_weights, on=self.item_column, how="left").fill_nan(1.0)

-         sorted_by_score_recommendations = self._get_items_list_per_user(
-             recommendations, "weight"
-         )
+         sorted_by_score_recommendations = self._get_items_list_per_user(recommendations, "weight")
          return self._rearrange_columns(sorted_by_score_recommendations)

      def __call__(
@@ -183,9 +170,7 @@ class Surprisal(Metric):
              else self._convert_dict_to_dict_with_score(recommendations)
          )
          self._check_duplicates_dict(recommendations)
-         train = (
-             self._convert_pandas_to_dict_without_score(train) if is_pandas else train
-         )
+         train = self._convert_pandas_to_dict_without_score(train) if is_pandas else train
          assert isinstance(train, dict)

          weights = self._get_recommendation_weights(recommendations, train)
@@ -196,9 +181,7 @@ class Surprisal(Metric):
          )

      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], pred_item_ids: List, pred_weights: List
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], pred_item_ids: List, pred_weights: List) -> List[float]:
          if not pred_item_ids:
              return [0.0 for _ in ks]
          res = []
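
The Spark and Polars branches above compute the same normalized self-information weight for every item in the train data. A plain-NumPy rendering of that expression, with invented counts, shows how popular and rare items are scored:

    # Numeric illustration of log2(n_users / u_j) / log2(n_users) from the agg() calls
    # above; the user counts are made up for the example.
    import numpy as np

    n_users = 1000
    users_per_item = {"popular_item": 900, "rare_item": 3}

    weights = {
        item: np.log2(n_users / u_j) / np.log2(n_users)
        for item, u_j in users_per_item.items()
    }
    weights["cold_item"] = 1.0  # items missing from train keep the fill value used after the join

    print(weights)  # popular_item is close to 0, rare_item is about 0.84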
replay/metrics/torch_metrics_builder.py CHANGED
@@ -28,7 +28,6 @@ DEFAULT_METRICS: List[MetricName] = [
  DEFAULT_KS: List[int] = [1, 5, 10, 20]


- # pylint: disable=too-many-instance-attributes
  @dataclass
  class _MetricRequirements:
      """
@@ -113,7 +112,6 @@ class _CoverageHelper:
          self._train_hist = torch.zeros(self.item_count)
          self._pred_hist: Dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}

-     # pylint: disable=attribute-defined-outside-init
      def _ensure_hists_on_device(self, device: torch.device) -> None:
          self._train_hist = self._train_hist.to(device)
          for k in self._top_k:
@@ -192,13 +190,11 @@ class _MetricBuilder(abc.ABC):
          """


- # pylint: disable=too-many-instance-attributes
  class TorchMetricsBuilder(_MetricBuilder):
      """
      Computes specified metrics over multiple batches
      """

-     # pylint: disable=dangerous-default-value
      def __init__(
          self,
          metrics: List[MetricName] = DEFAULT_METRICS,
replay/metrics/unexpectedness.py CHANGED
@@ -1,11 +1,10 @@
  from typing import List, Optional, Union

- from replay.utils import PandasDataFrame, SparkDataFrame, PolarsDataFrame
+ from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

  from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType


- # pylint: disable=too-few-public-methods
  class Unexpectedness(Metric):
      """
      Fraction of recommended items that are not present in some baseline\
@@ -13,11 +12,12 @@ class Unexpectedness(Metric):

      .. math::
          Unexpectedness@K(i) = 1 -
-         \\frac {\parallel R^{i}_{1..\min(K, \parallel R^{i} \parallel)} \cap BR^{i}_{1..\min(K, \parallel BR^{i} \parallel)} \parallel}
+         \\frac {\\parallel R^{i}_{1..\\min(K, \\parallel R^{i} \\parallel)}
+         \\cap BR^{i}_{1..\\min(K, \\parallel BR^{i} \\parallel)} \\parallel}
          {K}

      .. math::
-         Unexpectedness@K = \\frac {1}{N}\sum_{i=1}^{N}Unexpectedness@K(i)
+         Unexpectedness@K = \\frac {1}{N}\\sum_{i=1}^{N}Unexpectedness@K(i)

      :math:`R_{1..j}^{i}` -- the first :math:`j` recommendations for the :math:`i`-th user.

@@ -61,7 +61,7 @@ class Unexpectedness(Metric):
      'Unexpectedness-ConfidenceInterval@4': 0.0}
      <BLANKLINE>
      """
-     # pylint: disable=arguments-renamed
+
      def _get_enriched_recommendations(
          self,
          recommendations: Union[PolarsDataFrame, SparkDataFrame],
@@ -72,14 +72,14 @@ class Unexpectedness(Metric):
          else:
              return self._get_enriched_recommendations_polars(recommendations, base_recommendations)

-     def _get_enriched_recommendations_spark(  # pylint: disable=arguments-renamed
+     def _get_enriched_recommendations_spark(
          self, recommendations: SparkDataFrame, base_recommendations: SparkDataFrame
      ) -> SparkDataFrame:
          sorted_by_score_recommendations = self._get_items_list_per_user(recommendations)

-         sorted_by_score_base_recommendations = self._get_items_list_per_user(
-             base_recommendations
-         ).withColumnRenamed("pred_item_id", "base_pred_item_id")
+         sorted_by_score_base_recommendations = self._get_items_list_per_user(base_recommendations).withColumnRenamed(
+             "pred_item_id", "base_pred_item_id"
+         )

          enriched_recommendations = sorted_by_score_recommendations.join(
              sorted_by_score_base_recommendations, how="left", on=self.query_column
@@ -87,14 +87,14 @@ class Unexpectedness(Metric):

          return self._rearrange_columns(enriched_recommendations)

-     def _get_enriched_recommendations_polars(  # pylint: disable=arguments-renamed
+     def _get_enriched_recommendations_polars(
          self, recommendations: PolarsDataFrame, base_recommendations: PolarsDataFrame
      ) -> PolarsDataFrame:
          sorted_by_score_recommendations = self._get_items_list_per_user(recommendations)

-         sorted_by_score_base_recommendations = self._get_items_list_per_user(
-             base_recommendations
-         ).rename({"pred_item_id": "base_pred_item_id"})
+         sorted_by_score_base_recommendations = self._get_items_list_per_user(base_recommendations).rename(
+             {"pred_item_id": "base_pred_item_id"}
+         )

          enriched_recommendations = sorted_by_score_recommendations.join(
              sorted_by_score_base_recommendations, how="left", on=self.query_column
@@ -152,12 +152,7 @@ class Unexpectedness(Metric):
          )

      @staticmethod
-     def _get_metric_value_by_user(  # pylint: disable=arguments-differ
-         ks: List[int], base_recs: Optional[List], recs: Optional[List]
-     ) -> List[float]:
+     def _get_metric_value_by_user(ks: List[int], base_recs: Optional[List], recs: Optional[List]) -> List[float]:
          if not base_recs or not recs:
              return [0.0 for _ in ks]
-         res = []
-         for k in ks:
-             res.append(1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k)
-         return res
+         return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]
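
A quick worked example of the per-user formula that the list comprehension above now computes, with invented recommendation lists:

    # 1 - |top-k(recs) ∩ top-k(base_recs)| / k, as in _get_metric_value_by_user above.
    recs = [10, 20, 30, 40]
    base_recs = [20, 50, 60, 40]

    ks = [2, 4]
    values = [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]
    print(values)  # [0.5, 0.5]: one shared item in the top 2, two shared items in the top 4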
replay/models/__init__.py CHANGED
@@ -12,6 +12,7 @@ from .association_rules import AssociationRulesItemRec
  from .base_rec import Recommender
  from .cat_pop_rec import CatPopRec
  from .cluster import ClusterRec
+ from .kl_ucb import KLUCB
  from .knn import ItemKNN
  from .pop_rec import PopRec
  from .query_pop_rec import QueryPopRec
@@ -19,7 +20,5 @@ from .random_rec import RandomRec
  from .slim import SLIM
  from .thompson_sampling import ThompsonSampling
  from .ucb import UCB
- # pylint: disable=cyclic-import
- from .kl_ucb import KLUCB
  from .wilson import Wilson
  from .word2vec import Word2VecRec
replay/models/als.py CHANGED
@@ -2,9 +2,10 @@ from os.path import join
  from typing import Optional, Tuple

  from replay.data import Dataset
- from .base_rec import ItemVectorModel, Recommender
  from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame

+ from .base_rec import ItemVectorModel, Recommender
+
  if PYSPARK_AVAILABLE:
      import pyspark.sql.functions as sf
      from pyspark.ml.recommendation import ALS, ALSModel
@@ -13,7 +14,6 @@ if PYSPARK_AVAILABLE:
      from replay.utils.spark_utils import list_to_vector_udf


- # pylint: disable=too-many-instance-attributes
  class ALSWrap(Recommender, ItemVectorModel):
      """Wrapper for `Spark ALS
      <https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS>`_.
@@ -24,7 +24,6 @@ class ALSWrap(Recommender, ItemVectorModel):
          "rank": {"type": "loguniform_int", "args": [8, 256]},
      }

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          rank: int = 10,
@@ -98,7 +97,6 @@ class ALSWrap(Recommender, ItemVectorModel):
          self.model.itemFactors.unpersist()
          self.model.userFactors.unpersist()

-     # pylint: disable=too-many-arguments
      def _predict(
          self,
          dataset: Optional[Dataset],
@@ -107,10 +105,8 @@ class ALSWrap(Recommender, ItemVectorModel):
          items: SparkDataFrame,
          filter_seen_items: bool = True,
      ) -> SparkDataFrame:
-
          if (items.count() == self.fit_items.count()) and (
-             items.join(self.fit_items, on=self.item_column, how="inner").count()
-             == self.fit_items.count()
+             items.join(self.fit_items, on=self.item_column, how="inner").count() == self.fit_items.count()
          ):
              max_seen = 0
              if filter_seen_items and dataset is not None:
@@ -125,9 +121,7 @@ class ALSWrap(Recommender, ItemVectorModel):

          recs_als = self.model.recommendForUserSubset(queries, k + max_seen)
          return (
-             recs_als.withColumn(
-                 "recommendations", sf.explode("recommendations")
-             )
+             recs_als.withColumn("recommendations", sf.explode("recommendations"))
              .withColumn(self.item_column, sf.col(f"recommendations.{self.item_column}"))
              .withColumn(
                  self.rating_column,
@@ -144,7 +138,7 @@ class ALSWrap(Recommender, ItemVectorModel):
      def _predict_pairs(
          self,
          pairs: SparkDataFrame,
-         dataset: Optional[Dataset] = None,
+         dataset: Optional[Dataset] = None,  # noqa: ARG002
      ) -> SparkDataFrame:
          return (
              self.model.transform(pairs)
@@ -153,15 +147,13 @@ class ALSWrap(Recommender, ItemVectorModel):
          )

      def _get_features(
-         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
+         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
      ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
          entity = "user" if self.query_column in ids.columns else "item"
          entity_col = self.query_column if self.query_column in ids.columns else self.item_column

          als_factors = getattr(self.model, f"{entity}Factors")
-         als_factors = als_factors.withColumnRenamed(
-             "id", entity_col
-         ).withColumnRenamed("features", f"{entity}_factors")
+         als_factors = als_factors.withColumnRenamed("id", entity_col).withColumnRenamed("features", f"{entity}_factors")
          return (
              als_factors.join(ids, how="right", on=entity_col),
              self.model.rank,