replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/data/nn/utils.py CHANGED
@@ -2,11 +2,11 @@ from typing import Optional
 
 import polars as pl
 
-from replay.utils.spark_utils import spark_to_pandas
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame
+from replay.utils.spark_utils import spark_to_pandas
 
 if PYSPARK_AVAILABLE:  # pragma: no cover
-    import pyspark.sql.functions as F
+    import pyspark.sql.functions as sf
 
 
 def groupby_sequences(events: DataFrameLike, groupby_col: str, sort_col: Optional[str] = None) -> DataFrameLike:
@@ -38,9 +38,7 @@ def groupby_sequences(events: DataFrameLike, groupby_col: str, sort_col: Optional[str] = None) -> DataFrameLike:
         event_cols_without_groupby.insert(0, sort_col)
         events = events.sort(event_cols_without_groupby)
 
-        grouped_sequences = events.group_by(groupby_col).agg(
-            *[pl.col(x) for x in event_cols_without_groupby]
-        )
+        grouped_sequences = events.group_by(groupby_col).agg(*[pl.col(x) for x in event_cols_without_groupby])
     else:
         event_cols_without_groupby = events.columns.copy()
         event_cols_without_groupby.remove(groupby_col)
@@ -49,16 +47,16 @@ def groupby_sequences(events: DataFrameLike, groupby_col: str, sort_col: Optional[str] = None) -> DataFrameLike:
             event_cols_without_groupby.remove(sort_col)
             event_cols_without_groupby.insert(0, sort_col)
 
-        all_cols_struct = F.struct(event_cols_without_groupby)  # type: ignore
+        all_cols_struct = sf.struct(event_cols_without_groupby)
 
-        collect_fn = F.collect_list(all_cols_struct)
+        collect_fn = sf.collect_list(all_cols_struct)
         if sort_col:
-            collect_fn = F.sort_array(collect_fn)
+            collect_fn = sf.sort_array(collect_fn)
 
         grouped_sequences = (
             events.groupby(groupby_col)
             .agg(collect_fn.alias("_"))
-            .select([F.col(groupby_col)] + [F.col(f"_.{col}").alias(col) for col in event_cols_without_groupby])
+            .select([sf.col(groupby_col)] + [sf.col(f"_.{col}").alias(col) for col in event_cols_without_groupby])
             .drop("_")
         )
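
Note: the F → sf alias rename above recurs throughout this diff. For orientation, the polars branch of groupby_sequences reduces to a sort followed by a grouped list-aggregation; a minimal self-contained sketch of that behavior (toy data and column names, not from the package):

    import polars as pl

    # Toy event log; column names are illustrative only.
    events = pl.DataFrame(
        {
            "user_id": [1, 1, 2, 2],
            "timestamp": [20, 10, 5, 7],
            "item_id": [101, 100, 200, 201],
        }
    )

    # Mirrors the polars branch above: sort by the sort column first,
    # then collect every remaining column into one list per user.
    grouped = (
        events.sort(["timestamp", "item_id"])
        .group_by("user_id", maintain_order=True)
        .agg(pl.col("timestamp"), pl.col("item_id"))
    )
    print(grouped)  # user 1 -> timestamps [10, 20], items [100, 101]; etc.
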
replay/data/schema.py CHANGED
@@ -45,7 +45,6 @@ class FeatureInfo:
     Information about a feature.
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         column: str,
@@ -72,7 +71,8 @@ class FeatureInfo:
         self._feature_hint = feature_hint
 
         if feature_type == FeatureType.NUMERICAL and cardinality:
-            raise ValueError("Cardinality is needed only with categorical feature_type.")
+            msg = "Cardinality is needed only with categorical feature_type."
+            raise ValueError(msg)
         self._cardinality = cardinality
 
     @property
@@ -112,14 +112,12 @@ class FeatureInfo:
         :returns: cardinality of the feature.
         """
         if self.feature_type != FeatureType.CATEGORICAL:
-            raise RuntimeError(
-                f"Can not get cardinality because feature_type of {self.column} column is not categorical."
-            )
+            msg = f"Can not get cardinality because feature_type of {self.column} column is not categorical."
+            raise RuntimeError(msg)
         if hasattr(self, "_cardinality_callback") and self._cardinality is None:
             self._cardinality = self._cardinality_callback(self._column)
         return self._cardinality
 
-    # pylint: disable=attribute-defined-outside-init
     def _set_cardinality_callback(self, callback: Callable) -> None:
         self._cardinality_callback = callback
 
@@ -130,7 +128,6 @@ class FeatureInfo:
         self._cardinality = None
 
 
-# pylint: disable=too-many-public-methods
 class FeatureSchema(Mapping[str, FeatureInfo]):
     """
     Key-value like collection with information about all dataset features.
@@ -174,8 +171,9 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
         :returns: extract a feature information from a schema.
         """
         if len(self._features_schema) > 1:
-            raise ValueError("Only one element feature schema can be converted to single feature")
-        return list(self._features_schema.values())[0]
+            msg = "Only one element feature schema can be converted to single feature"
+            raise ValueError(msg)
+        return next(iter(self._features_schema.values()))
 
     def items(self) -> ItemsView[str, FeatureInfo]:
         return self._features_schema.items()
@@ -186,7 +184,7 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
     def values(self) -> ValuesView[FeatureInfo]:
         return self._features_schema.values()
 
-    def get(  # type: ignore
+    def get(
         self,
         key: str,
         default: Optional[FeatureInfo] = None,
@@ -358,7 +356,7 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
         for filtration_func, filtration_param in zip(filter_functions, filter_parameters):
             filtered_features = list(
                 filter(
-                    lambda x: filtration_func(x, filtration_param),  # type: ignore  # pylint: disable=W0640
+                    lambda x: filtration_func(x, filtration_param),
                     filtered_features,
                 )
             )
@@ -391,7 +389,7 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
         for filtration_func, filtration_param in zip(filter_functions, filter_parameters):
             filtered_features = list(
                 filter(
-                    lambda x: filtration_func(x, filtration_param),  # type: ignore  # pylint: disable=W0640
+                    lambda x: filtration_func(x, filtration_param),
                    filtered_features,
                 )
             )
@@ -426,7 +424,6 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
     def _type_drop(value: FeatureInfo, feature_type: FeatureType) -> bool:
         return value.feature_type != feature_type if feature_type else True
 
-    # pylint: disable=no-self-use
     @staticmethod
     def _hint_drop(value: FeatureInfo, feature_hint: FeatureHint) -> bool:
         return value.feature_hint != feature_hint if feature_hint else True
@@ -451,13 +448,16 @@ class FeatureSchema(Mapping[str, FeatureInfo]):
                 item_query_names[feature.feature_hint] += [feature.column]
 
         if len(duplicates) > 0:
-            raise ValueError(
+            msg = (
                 "Features column names should be unique, exept ITEM_ID and QUERY_ID columns. "
-                + f"{duplicates} columns are not unique."
+                f"{duplicates} columns are not unique."
             )
+            raise ValueError(msg)
 
         if len(item_query_names[FeatureHint.ITEM_ID]) > 1:
-            raise ValueError(f"ITEM_ID must be present only once. Rename {item_query_names[FeatureHint.ITEM_ID]}")
+            msg = f"ITEM_ID must be present only once. Rename {item_query_names[FeatureHint.ITEM_ID]}"
+            raise ValueError(msg)
 
         if len(item_query_names[FeatureHint.QUERY_ID]) > 1:
-            raise ValueError(f"QUERY_ID must be present only once. Rename {item_query_names[FeatureHint.QUERY_ID]}")
+            msg = f"QUERY_ID must be present only once. Rename {item_query_names[FeatureHint.QUERY_ID]}"
+            raise ValueError(msg)
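
The assign-then-raise rewrites above recur throughout this release and match the style enforced by linters such as ruff's EM101/EM102 rules: bind the message to a variable so the raise line stays short and the message expression is not repeated in the traceback. A standalone sketch of the pattern (the example function is invented):

    # Before: the message expression sits inside the raise statement.
    def check_old(value: int) -> None:
        if value < 0:
            raise ValueError(f"{value} must be non-negative")

    # After: the message is bound first, as in the 0.17.0 code above.
    def check_new(value: int) -> None:
        if value < 0:
            msg = f"{value} must be non-negative"
            raise ValueError(msg)
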
replay/data/spark_schema.py CHANGED
@@ -4,7 +4,6 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType, TimestampType
 
 
-# pylint: disable=too-many-arguments
 def get_schema(
     query_column: str = "query_id",
     item_column: str = "item_id",
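
get_schema builds a PySpark schema for interaction logs from the column names it is given. The exact fields are not visible in this diff; the following is only a hedged sketch of a schema with the same ingredients (field set and types are assumptions based on the imports and defaults shown above):

    from pyspark.sql.types import (
        DoubleType,
        IntegerType,
        StructField,
        StructType,
        TimestampType,
    )

    # Plausible shape of such a schema: integer ids, timestamp event time,
    # double rating. Not necessarily what get_schema actually returns.
    schema = StructType(
        [
            StructField("query_id", IntegerType()),
            StructField("item_id", IntegerType()),
            StructField("timestamp", TimestampType()),
            StructField("rating", DoubleType()),
        ]
    )
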
replay/metrics/base_metric.py CHANGED
@@ -1,11 +1,11 @@
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import numpy as np
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, SparkDataFrame, PolarsDataFrame
+from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .descriptors import CalculationDescriptor, Mean
 
@@ -27,7 +27,7 @@ class MetricDuplicatesWarning(Warning):
 class Metric(ABC):
     """Base metric class"""
 
-    def __init__(  # pylint: disable=too-many-arguments
+    def __init__(
         self,
         topk: Union[List[int], int],
         query_column: str = "query_id",
@@ -46,11 +46,13 @@ class Metric(ABC):
         if isinstance(topk, list):
             for item in topk:
                 if not isinstance(item, int):
-                    raise ValueError(f"{item} is not int")
+                    msg = f"{item} is not int"
+                    raise ValueError(msg)
         elif isinstance(topk, int):
             topk = [topk]
         else:
-            raise ValueError("topk not list or int")
+            msg = "topk not list or int"
+            raise ValueError(msg)
         self.topk = sorted(topk)
         self.query_column = query_column
         self.item_column = item_column
@@ -60,11 +62,8 @@ class Metric(ABC):
     @property
     def __name__(self) -> str:
         mode_name = self._mode.__name__
-        return str(type(self).__name__) + (
-            f"-{mode_name}" if mode_name != "Mean" else ""
-        )
+        return str(type(self).__name__) + (f"-{mode_name}" if mode_name != "Mean" else "")
 
-    # pylint: disable=no-self-use
     def _check_dataframes_equal_types(
         self,
         recommendations: MetricsDataFrameLike,
@@ -74,39 +73,31 @@
         Types of all data frames must be the same.
         """
         if not isinstance(recommendations, type(ground_truth)):
-            raise ValueError("All given data frames must have the same type")
+            msg = "All given data frames must have the same type"
+            raise ValueError(msg)
 
     def _duplicate_warn(self):
         warnings.warn(
-            "The recommendations contain duplicated users and items."
-            "The metrics may be higher than the actual ones.",
+            "The recommendations contain duplicated users and items.The metrics may be higher than the actual ones.",
             MetricDuplicatesWarning,
         )
 
     def _check_duplicates_spark(self, recommendations: SparkDataFrame) -> None:
         duplicates_count = (
-            recommendations.groupBy(self.query_column, self.item_column)
-            .count()
-            .filter("count >= 2")
-            .count()
+            recommendations.groupBy(self.query_column, self.item_column).count().filter("count >= 2").count()
         )
         if duplicates_count:
             self._duplicate_warn()
 
     def _check_duplicates_dict(self, recommendations: Dict) -> None:
-        for _, items in recommendations.items():
+        for items in recommendations.values():
             items_set = set(items)
             if len(items) != len(items_set):
                 self._duplicate_warn()
                 return
 
     def _check_duplicates_polars(self, recommendations: PolarsDataFrame) -> None:
-        duplicates_count = (
-            recommendations
-            .group_by(self.query_column, self.item_column)
-            .len()
-            .filter(pl.col("len") > 1)
-        )
+        duplicates_count = recommendations.group_by(self.query_column, self.item_column).len().filter(pl.col("len") > 1)
         if not duplicates_count.is_empty():
             self._duplicate_warn()
 
@@ -144,11 +135,7 @@
             else self._convert_dict_to_dict_with_score(recommendations)
         )
         self._check_duplicates_dict(recommendations)
-        ground_truth = (
-            self._convert_pandas_to_dict_without_score(ground_truth)
-            if is_pandas
-            else ground_truth
-        )
+        ground_truth = self._convert_pandas_to_dict_without_score(ground_truth) if is_pandas else ground_truth
         assert isinstance(ground_truth, dict)
         return self._dict_call(
             list(ground_truth),
@@ -164,7 +151,6 @@
             .to_dict()
         )
 
-    # pylint: disable=no-self-use
     def _convert_dict_to_dict_with_score(self, data: Dict) -> Dict:
         converted_data = {}
         for user, items in data.items():
@@ -191,31 +177,21 @@
         distribution_per_user = {}
         for user in users:
             args = [kwargs[key].get(user, None) for key in keys_list]
-            distribution_per_user[user] = self._get_metric_value_by_user(
-                self.topk, *args
-            )  # pylint: disable=protected-access
+            distribution_per_user[user] = self._get_metric_value_by_user(self.topk, *args)
         if self._mode.__name__ == "PerUser":
             return self._aggregate_results_per_user(distribution_per_user)
         distribution = np.stack(list(distribution_per_user.values()))
         assert distribution.shape[1] == len(self.topk)
-        metrics = []
-        for k in range(distribution.shape[1]):
-            metrics.append(self._mode.cpu(distribution[:, k]))
+        metrics = [self._mode.cpu(distribution[:, k]) for k in range(distribution.shape[1])]
         return self._aggregate_results(metrics)
 
     def _get_items_list_per_user_spark(
-        self, recommendations: SparkDataFrame, extra_column: str = None
+        self, recommendations: SparkDataFrame, extra_column: Optional[str] = None
     ) -> SparkDataFrame:
         recommendations = recommendations.groupby(self.query_column).agg(
             sf.sort_array(
                 sf.collect_list(
-                    sf.struct(
-                        *[
-                            c
-                            for c in [self.rating_column, self.item_column, extra_column]
-                            if c is not None
-                        ]
-                    )
+                    sf.struct(*[c for c in [self.rating_column, self.item_column, extra_column] if c is not None])
                 ),
                 False,
             ).alias("pred")
@@ -231,7 +207,7 @@
         return recommendations
 
     def _get_items_list_per_user_polars(
-        self, recommendations: PolarsDataFrame, extra_column: str = None
+        self, recommendations: PolarsDataFrame, extra_column: Optional[str] = None
     ) -> PolarsDataFrame:
         selection = [self.query_column, "pred_item_id"]
         sorting = [self.rating_column, self.item_column]
@@ -242,8 +218,7 @@
         selection.append(extra_column)
 
         recommendations = (
-            recommendations
-            .sort(sorting, descending=True)
+            recommendations.sort(sorting, descending=True)
             .group_by(self.query_column)
             .agg(*agg)
             .rename({self.item_column: "pred_item_id"})
@@ -253,7 +228,7 @@
         return recommendations
 
     def _get_items_list_per_user(
-        self, recommendations: Union[SparkDataFrame, PolarsDataFrame], extra_column: str = None
+        self, recommendations: Union[SparkDataFrame, PolarsDataFrame], extra_column: Optional[str] = None
    ) -> Union[SparkDataFrame, PolarsDataFrame]:
         if isinstance(recommendations, SparkDataFrame):
             return self._get_items_list_per_user_spark(recommendations, extra_column)
@@ -265,7 +240,7 @@
     ) -> Union[SparkDataFrame, PolarsDataFrame]:
         cols = data.columns
         cols.remove(self.query_column)
-        cols = [self.query_column] + sorted(cols)
+        cols = [self.query_column, *sorted(cols)]
         return data.select(*cols)
 
     def _get_enriched_recommendations(
@@ -300,8 +275,7 @@
         ground_truth: PolarsDataFrame,
     ) -> PolarsDataFrame:
         true_items_by_users = (
-            ground_truth
-            .group_by(self.query_column)
+            ground_truth.group_by(self.query_column)
             .agg(pl.col(self.item_column))
             .rename({self.item_column: "ground_truth"})
         )
@@ -313,9 +287,7 @@
         )
         return self._rearrange_columns(enriched_recommendations)
 
-    def _aggregate_results_per_user(
-        self, distribution_per_user: Dict[Any, List[float]]
-    ) -> MetricsPerUserReturnType:
+    def _aggregate_results_per_user(self, distribution_per_user: Dict[Any, List[float]]) -> MetricsPerUserReturnType:
         res: MetricsPerUserReturnType = {}
         for index, val in enumerate(self.topk):
             metric_name = f"{self.__name__}@{val}"
@@ -335,18 +307,12 @@
         """
         Calculating metrics for PySpark DataFrame.
         """
-        recs_with_topk_list = recs.withColumn(
-            "k", sf.array(*[sf.lit(x) for x in self.topk])
-        )
+        recs_with_topk_list = recs.withColumn("k", sf.array(*[sf.lit(x) for x in self.topk]))
         distribution = self._get_metric_distribution(recs_with_topk_list)
         if self._mode.__name__ == "PerUser":
             return self._aggregate_results_per_user(distribution.rdd.collectAsMap())
         metrics = [
-            self._mode.spark(
-                distribution.select(sf.col("value").getItem(i)).withColumnRenamed(
-                    f"value[{i}]", "val"
-                )
-            )
+            self._mode.spark(distribution.select(sf.col("value").getItem(i)).withColumnRenamed(f"value[{i}]", "val"))
             for i in range(len(self.topk))
         ]
         return self._aggregate_results(metrics)
@@ -355,27 +321,23 @@
         distribution = self._get_metric_distribution(recs)
         if self._mode.__name__ == "PerUser":
             return self._aggregate_results_per_user(
-                dict(distribution.select(
-                    self.query_column,
-                    value=pl.concat_list(pl.exclude(self.query_column))
-                ).iter_rows())
+                dict(
+                    distribution.select(
+                        self.query_column, value=pl.concat_list(pl.exclude(self.query_column))
+                    ).iter_rows()
+                )
             )
-        metrics = [self._mode.cpu(distribution.select(column))
-                   for column in distribution.columns[1:]]
+        metrics = [self._mode.cpu(distribution.select(column)) for column in distribution.columns[1:]]
         return self._aggregate_results(metrics)
 
-    def _spark_call(
-        self, recommendations: SparkDataFrame, ground_truth: SparkDataFrame
-    ) -> MetricsReturnType:
+    def _spark_call(self, recommendations: SparkDataFrame, ground_truth: SparkDataFrame) -> MetricsReturnType:
         """
         Implementation for PySpark DataFrame.
         """
         recs = self._get_enriched_recommendations(recommendations, ground_truth)
         return self._spark_compute(recs)
 
-    def _polars_call(
-        self, recommendations: PolarsDataFrame, ground_truth: PolarsDataFrame
-    ) -> MetricsReturnType:
+    def _polars_call(self, recommendations: PolarsDataFrame, ground_truth: PolarsDataFrame) -> MetricsReturnType:
         """
         Implementation for Polars DataFrame.
         """
@@ -383,7 +345,7 @@
         return self._polars_compute(recs)
 
     def _get_metric_distribution(
-        self, recs: Union[PolarsDataFrame, SparkDataFrame]
+        self, recs: Union[PolarsDataFrame, SparkDataFrame]
     ) -> Union[PolarsDataFrame, SparkDataFrame]:
         if isinstance(recs, SparkDataFrame):
             return self._get_metric_distribution_spark(recs)
@@ -406,16 +368,13 @@
         distribution = recs.map_rows(lambda x: (x[0], *cur_class._get_metric_value_by_user(self.topk, *x[1:])))
         distribution = distribution.rename({"column_0": self.query_column})
         distribution = distribution.rename(
-            {distribution.columns[x + 1]: f"value_{self.topk[x]}"
-             for x in range(len(self.topk))}
+            {distribution.columns[x + 1]: f"value_{self.topk[x]}" for x in range(len(self.topk))}
         )
         return distribution
 
     @staticmethod
     @abstractmethod
-    def _get_metric_value_by_user(  # pylint: disable=invalid-name
-        ks: List[int], *args: List
-    ) -> List[float]:  # pragma: no cover
+    def _get_metric_value_by_user(ks: List[int], *args: List) -> List[float]:  # pragma: no cover
         """
         Metric calculation for one user.
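
Several signatures above change extra_column: str = None to Optional[str] = None (hence the new Optional import in the first hunk). The old form relies on implicit Optional, which PEP 484 deprecated and which strict type checkers reject; a standalone illustration (function names invented for the example):

    from typing import Optional

    # Implicit Optional: annotated str but defaulted to None; modern mypy
    # (no_implicit_optional, the default since 0.990) reports an error here.
    def items_for(extra_column: str = None) -> list:  # type checker error
        return []

    # Explicit Optional, matching the 0.17.0 signatures above.
    def items_for_fixed(extra_column: Optional[str] = None) -> list:
        return []
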
replay/metrics/categorical_diversity.py CHANGED
@@ -4,7 +4,7 @@ from typing import Dict, List, Union
 import numpy as np
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame, PolarsDataFrame
+from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .base_metric import (
     Metric,
@@ -16,11 +16,12 @@ from .base_metric import (
 from .descriptors import CalculationDescriptor, Mean
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import Window
-    from pyspark.sql import functions as F
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )
 
 
-# pylint: disable=too-few-public-methods
 class CategoricalDiversity(Metric):
     """
     Metric calculation is as follows:
@@ -59,7 +60,6 @@ class CategoricalDiversity(Metric):
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         topk: Union[List, int],
@@ -108,31 +108,21 @@
             precalculated_answer = self._precalculate_unique_cats(recommendations)
             return self._dict_call(precalculated_answer)
 
-    # pylint: disable=arguments-differ
     def _get_enriched_recommendations(
-        self, recommendations: Union[PolarsDataFrame, SparkDataFrame],
+        self,
+        recommendations: Union[PolarsDataFrame, SparkDataFrame],
     ) -> Union[PolarsDataFrame, SparkDataFrame]:
         if isinstance(recommendations, SparkDataFrame):
             return self._get_enriched_recommendations_spark(recommendations)
         else:
             return self._get_enriched_recommendations_polars(recommendations)
 
-    # pylint: disable=arguments-differ
-    def _get_enriched_recommendations_spark(
-        self, recommendations: SparkDataFrame
-    ) -> SparkDataFrame:
-        window = Window.partitionBy(self.query_column).orderBy(
-            F.col(self.rating_column).desc()
-        )
-        sorted_by_score_recommendations = recommendations.withColumn(
-            "rank", F.row_number().over(window)
-        )
+    def _get_enriched_recommendations_spark(self, recommendations: SparkDataFrame) -> SparkDataFrame:
+        window = Window.partitionBy(self.query_column).orderBy(sf.col(self.rating_column).desc())
+        sorted_by_score_recommendations = recommendations.withColumn("rank", sf.row_number().over(window))
         return sorted_by_score_recommendations
 
-    # pylint: disable=arguments-differ
-    def _get_enriched_recommendations_polars(
-        self, recommendations: PolarsDataFrame
-    ) -> PolarsDataFrame:
+    def _get_enriched_recommendations_polars(self, recommendations: PolarsDataFrame) -> PolarsDataFrame:
         sorted_by_score_recommendations = recommendations.select(
             pl.all().sort_by(self.rating_column, descending=True).over(self.query_column)
         )
@@ -146,13 +136,9 @@
     def _spark_compute_per_user(self, recs: SparkDataFrame) -> MetricsPerUserReturnType:
         distribution_per_user = defaultdict(list)
         for k in self.topk:
-            filtered_recs = recs.filter(F.col("rank") <= k)
-            aggreagated_by_user = filtered_recs.groupBy(self.query_column).agg(
-                F.countDistinct(self.category_column)
-            )
-            aggreagated_by_user_dict = (
-                aggreagated_by_user.rdd.collectAsMap()
-            )  # type:ignore
+            filtered_recs = recs.filter(sf.col("rank") <= k)
+            aggreagated_by_user = filtered_recs.groupBy(self.query_column).agg(sf.countDistinct(self.category_column))
+            aggreagated_by_user_dict = aggreagated_by_user.rdd.collectAsMap()
             for user, metric in aggreagated_by_user_dict.items():
                 distribution_per_user[user].append(metric / k)
         return self._aggregate_results_per_user(dict(distribution_per_user))
@@ -161,12 +147,8 @@
         distribution_per_user = defaultdict(list)
         for k in self.topk:
             filtered_recs = recs.filter(pl.col("rank") <= k)
-            aggreagated_by_user = filtered_recs.group_by(self.query_column).agg(
-                pl.col(self.category_column).n_unique()
-            )
-            aggreagated_by_user_dict = (
-                dict(aggreagated_by_user.iter_rows())
-            )  # type:ignore
+            aggreagated_by_user = filtered_recs.group_by(self.query_column).agg(pl.col(self.category_column).n_unique())
+            aggreagated_by_user_dict = dict(aggreagated_by_user.iter_rows())
             for user, metric in aggreagated_by_user_dict.items():
                 distribution_per_user[user].append(metric / k)
         return self._aggregate_results_per_user(dict(distribution_per_user))
@@ -174,10 +156,10 @@
     def _spark_compute_agg(self, recs: SparkDataFrame) -> MetricsMeanReturnType:
         metrics = []
         for k in self.topk:
-            filtered_recs = recs.filter(F.col("rank") <= k)
+            filtered_recs = recs.filter(sf.col("rank") <= k)
             aggregated_by_user = (
                 filtered_recs.groupBy(self.query_column)
-                .agg(F.countDistinct(self.category_column))
+                .agg(sf.countDistinct(self.category_column))
                 .drop(self.query_column)
             )
             metrics.append(self._mode.spark(aggregated_by_user) / k)
@@ -195,7 +177,6 @@
             metrics.append(self._mode.cpu(aggregated_by_user) / k)
         return self._aggregate_results(metrics)
 
-    # pylint: disable=arguments-differ
     def _spark_call(self, recommendations: SparkDataFrame) -> MetricsReturnType:
         """
         Implementation for Pyspark DataFrame.
@@ -205,7 +186,6 @@
             return self._spark_compute_per_user(recs)
         return self._spark_compute_agg(recs)
 
-    # pylint: disable=arguments-differ
     def _polars_call(self, recommendations: PolarsDataFrame) -> MetricsReturnType:
         """
         Implementation for Polars DataFrame.
@@ -223,7 +203,6 @@
             .to_dict()
         )
 
-    # pylint: disable=no-self-use
     def _precalculate_unique_cats(self, recommendations: Dict) -> Dict:
         """
         Precalculate unique categories for each prefix for each user.
@@ -238,24 +217,16 @@
                 answer[user] = unique_len
         return answer
 
-    # pylint: disable=arguments-renamed,arguments-differ
-    def _dict_compute_per_user(
-        self, precalculated_answer: Dict
-    ) -> MetricsPerUserReturnType:  # type:ignore
+    def _dict_compute_per_user(self, precalculated_answer: Dict) -> MetricsPerUserReturnType:
         distribution_per_user = defaultdict(list)
         for k in self.topk:
             for user, unique_cats in precalculated_answer.items():
-                distribution_per_user[user].append(
-                    unique_cats[min(len(unique_cats), k) - 1] / k
-                )
+                distribution_per_user[user].append(unique_cats[min(len(unique_cats), k) - 1] / k)
         return self._aggregate_results_per_user(distribution_per_user)
 
-    # pylint: disable=arguments-renamed
-    def _dict_compute_mean(
-        self, precalculated_answer: Dict
-    ) -> MetricsMeanReturnType:  # type:ignore
+    def _dict_compute_mean(self, precalculated_answer: Dict) -> MetricsMeanReturnType:
         distribution_list = []
-        for _, unique_cats in precalculated_answer.items():
+        for unique_cats in precalculated_answer.values():
             metrics_per_user = []
             for k in self.topk:
                 metric = unique_cats[min(len(unique_cats), k) - 1] / k
@@ -264,12 +235,9 @@
 
         distribution = np.stack(distribution_list)
         assert distribution.shape[1] == len(self.topk)
-        metrics = []
-        for k in range(distribution.shape[1]):
-            metrics.append(self._mode.cpu(distribution[:, k]))
+        metrics = [self._mode.cpu(distribution[:, k]) for k in range(distribution.shape[1])]
         return self._aggregate_results(metrics)
 
-    # pylint: disable=arguments-differ
     def _dict_call(self, precalculated_answer: Dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
@@ -279,7 +247,5 @@
         return self._dict_compute_mean(precalculated_answer)
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], *args: List
-    ) -> List[float]:  # pragma: no cover
+    def _get_metric_value_by_user(ks: List[int], *args: List) -> List[float]:  # pragma: no cover
         pass
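
End-to-end, CategoricalDiversity computes, for each user and each k in topk, the number of distinct categories among the user's top-k recommendations divided by k, then aggregates with the configured descriptor. A hedged usage sketch for the dict path shown above (the single-argument call and the {user: [category per rank]} input shape are inferred from _precalculate_unique_cats; the data is invented):

    from replay.metrics import CategoricalDiversity

    # Categories of each user's recommended items, ordered by rank.
    recs = {
        "u1": ["news", "sport", "news", "music"],
        "u2": ["music", "music", "music", "music"],
    }

    metric = CategoricalDiversity(topk=[1, 2, 4])
    # Per user: u1 -> 1/1, 2/2, 3/4; u2 -> 1/1, 1/2, 1/4.
    # With the default Mean descriptor the result averages over users,
    # e.g. CategoricalDiversity@2 = (1.0 + 0.5) / 2 = 0.75.
    print(metric(recs))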