replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
@@ -1,661 +0,0 @@
1
- """
2
- Base classes for quality and diversity metrics.
3
- """
4
- import logging
5
- from abc import ABC, abstractmethod
6
- from typing import Dict, List, Optional, Union
7
-
8
- from scipy.stats import norm
9
-
10
- from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, IntOrList, NumType, PandasDataFrame, SparkDataFrame
11
- from replay.utils.session_handler import State
12
- from replay.utils.spark_utils import convert2spark, get_top_k_recs
13
-
14
- if PYSPARK_AVAILABLE:
15
- from pyspark.sql import Column, Window
16
- from pyspark.sql import functions as sf
17
- from pyspark.sql import types as st
18
- from pyspark.sql.column import _to_java_column, _to_seq
19
- from pyspark.sql.types import DataType
20
-
21
-
22
def fill_na_with_empty_array(
    df: SparkDataFrame, col_name: str, element_type: DataType
) -> SparkDataFrame:
    """
    Replace null values in an array column with an empty array of `element_type`.

    :param df: dataframe containing a `col_name` column of ArrayType(`element_type`)
    :param col_name: name of the column whose missing values are filled
    :param element_type: DataType of an array element
    :return: dataframe where nulls in `col_name` became empty arrays
    """
    empty_array = sf.array().cast(st.ArrayType(element_type))
    return df.withColumn(col_name, sf.coalesce(col_name, empty_array))
39
-
40
-
41
def preprocess_gt(
    ground_truth: DataFrameLike,
    ground_truth_users: Optional[DataFrameLike] = None,
) -> SparkDataFrame:
    """
    Prepare `ground_truth` data for metric calculation.

    :param ground_truth: spark dataframe with columns ``[user_idx, item_idx, relevance]``
    :param ground_truth_users: spark dataframe with column ``[user_idx]``
    :return: spark dataframe with columns ``[user_idx, ground_truth]``
    """
    gt_spark = convert2spark(ground_truth)
    users_spark = convert2spark(ground_truth_users)

    # collect the set of relevant items per user
    items_per_user = gt_spark.groupby("user_idx").agg(
        sf.collect_set("item_idx").alias("ground_truth")
    )

    if users_spark is not None:
        # keep exactly the requested users; users with no interactions
        # get an empty ground-truth array instead of null
        items_per_user = items_per_user.join(
            users_spark, on="user_idx", how="right"
        )
        items_per_user = fill_na_with_empty_array(
            items_per_user,
            "ground_truth",
            gt_spark.schema["item_idx"].dataType,
        )

    return items_per_user
68
-
69
-
70
def drop_duplicates(recommendations: DataFrameLike) -> SparkDataFrame:
    """
    Keep only the most relevant prediction for each (user, item) pair.
    """
    # rank duplicates of a pair by descending relevance; keep rank 1
    ranking_window = Window.partitionBy("user_idx", "item_idx").orderBy(
        sf.col("relevance").desc()
    )
    ranked = recommendations.withColumn(
        "_num", sf.row_number().over(ranking_window)
    )
    return ranked.where(sf.col("_num") == 1).drop("_num")
85
-
86
-
87
def filter_sort(recommendations: SparkDataFrame, extra_column: str = None) -> SparkDataFrame:
    """
    Filters duplicated predictions by choosing items with the highest relevance,
    Sorts items in predictions by its relevance,
    If `extra_column` is not None return DataFrame with extra_column e.g. item weight.

    :param recommendations: recommendation list
    :param extra_column: column in recommendations
        which will be return besides ``[user_idx, item_idx]``
    :return: ``[user_idx, item_idx]`` if extra_column = None
        or ``[user_idx, item_idx, extra_column]`` if extra_column exists.
    """
    # remember element types before aggregation collapses the schema
    item_type = recommendations.schema["item_idx"].dataType
    extra_column_type = recommendations.schema[extra_column].dataType if extra_column else None

    recommendations = drop_duplicates(recommendations)

    # Pack (relevance, item_idx[, extra_column]) structs per user.
    # array_sort orders struct arrays by the first field (relevance) ascending,
    # so reverse() yields a most-relevant-first ordering.
    recommendations = (
        recommendations
        .groupby("user_idx")
        .agg(
            sf.collect_list(
                sf.struct(*[c for c in ["relevance", "item_idx", extra_column] if c is not None]))
            .alias("pred_list"))
        .withColumn("pred_list", sf.reverse(sf.array_sort("pred_list")))
    )

    # unpack the struct array back into plain, correctly typed array columns
    selection = [
        "user_idx",
        sf.col("pred_list.item_idx")
        .cast(st.ArrayType(item_type, True)).alias("pred")
    ]
    if extra_column:
        selection.append(
            sf.col(f"pred_list.{extra_column}")
            .cast(st.ArrayType(extra_column_type, True)).alias(extra_column)
        )

    recommendations = recommendations.select(*selection)

    return recommendations
128
-
129
-
130
def get_enriched_recommendations(
    recommendations: DataFrameLike,
    ground_truth: DataFrameLike,
    max_k: int,
    ground_truth_users: Optional[DataFrameLike] = None,
) -> SparkDataFrame:
    """
    Leave max_k recommendations for each user,
    merge recommendations and ground truth into a single DataFrame
    and aggregate items into lists so that each user has only one record.

    :param recommendations: recommendation list
    :param ground_truth: test data
    :param max_k: maximal k value to calculate the metric for.
        `max_k` most relevant predictions are left for each user
    :param ground_truth_users: list of users to consider in metric calculation.
        if None, only the users from ground_truth are considered.
    :return: ``[user_idx, pred, ground_truth]``
    """
    recommendations = convert2spark(recommendations)
    # if there are duplicates in recommendations,
    # we will leave fewer than k recommendations after sort_udf
    recommendations = get_top_k_recs(recommendations, k=max_k)

    true_items_by_users = preprocess_gt(ground_truth, ground_truth_users)
    # right join keeps every ground-truth user even if they received no predictions
    joined = filter_sort(recommendations).join(
        true_items_by_users, how="right", on=["user_idx"]
    )

    # users without predictions get an empty `pred` array instead of null
    return fill_na_with_empty_array(
        joined, "pred", recommendations.schema["item_idx"].dataType
    )
162
-
163
-
164
def process_k(func):
    """
    Decorator that normalizes the ``k`` argument of a metric method.

    The wrapped method always receives a list of cut-offs.  If the caller
    passed a single ``int``, the per-k result dict is unpacked and only the
    value for that ``k`` is returned.

    :param func: method with signature ``(self, recs, k_list, *args, **kwargs)``
        returning a ``{k: value}`` dict
    :return: wrapped method accepting ``k`` as either int or list of ints
    """
    # local import keeps the module-level dependency block unchanged
    from functools import wraps

    # fixes: wraps() preserves the decorated method's name/docstring,
    # and keyword arguments are now forwarded instead of silently dropped
    @wraps(func)
    def wrap(self, recs, k, *args, **kwargs):
        # normalize a scalar cut-off to a one-element list
        if isinstance(k, int):
            k_list = [k]
        else:
            k_list = k

        res = func(self, recs, k_list, *args, **kwargs)

        # unpack the single requested value for a scalar k
        if isinstance(k, int):
            return res[k]
        return res

    return wrap
180
-
181
-
182
class Metric(ABC):
    """
    Base metric class: wires recommendation enrichment to the per-user metric
    distribution and its aggregations (mean, median, confidence interval).
    """

    # lazily created logger shared via the class attribute
    _logger: Optional[logging.Logger] = None
    # subclasses set this to enable Scala-UDF computation
    _scala_udf_name: Optional[str] = None

    def __init__(self, use_scala_udf: bool = False) -> None:
        """
        :param use_scala_udf: if True, compute per-user metric values with a
            Scala UDF instead of a Python RDD map
        """
        self._use_scala_udf = use_scala_udf

    @property
    def logger(self) -> logging.Logger:
        """
        :returns: get library logger
        """
        if self._logger is None:
            self._logger = logging.getLogger("replay")
        return self._logger

    @property
    def scala_udf_name(self) -> str:
        """Returns UDF name from `org.apache.spark.replay.utils.ScalaPySparkUDFs`"""
        if self._scala_udf_name:
            return self._scala_udf_name
        raise NotImplementedError(f"Scala UDF not implemented for {type(self).__name__} class!")

    def __str__(self):
        return type(self).__name__

    def __call__(
        self,
        recommendations: DataFrameLike,
        ground_truth: DataFrameLike,
        k: IntOrList,
        ground_truth_users: Optional[DataFrameLike] = None,
    ) -> Union[Dict[int, NumType], NumType]:
        """
        :param recommendations: model predictions in a
            DataFrame ``[user_idx, item_idx, relevance]``
        :param ground_truth: test data
            ``[user_idx, item_idx, timestamp, relevance]``
        :param k: depth cut-off. Truncates recommendation lists to top-k items.
        :param ground_truth_users: list of users to consider in metric calculation.
            if None, only the users from ground_truth are considered.
        :return: metric value
        """
        recs = get_enriched_recommendations(
            recommendations,
            ground_truth,
            max_k=k if isinstance(k, int) else max(k),
            ground_truth_users=ground_truth_users,
        )
        return self._mean(recs, k)

    @process_k
    def _conf_interval(self, recs: SparkDataFrame, k_list: list, alpha: float):
        """Half-width of the normal `alpha`-confidence interval per cut-off."""
        res = {}
        quantile = norm.ppf((1 + alpha) / 2)
        for k in k_list:
            distribution = self._get_metric_distribution(recs, k)
            value = (
                distribution.agg(
                    sf.stddev("value").alias("std"),
                    sf.count("value").alias("count"),
                )
                .select(
                    # stddev over a single row is NaN/null; treat it as zero spread
                    sf.when(
                        sf.isnan(sf.col("std")) | sf.col("std").isNull(),
                        sf.lit(0.0),
                    )
                    .otherwise(sf.col("std"))
                    .cast("float")
                    .alias("std"),
                    "count",
                )
                .first()
            )
            res[k] = quantile * value["std"] / (value["count"] ** 0.5)
        return res

    @process_k
    def _median(self, recs: SparkDataFrame, k_list: list):
        """Median of the per-user metric distribution per cut-off."""
        res = {}
        for k in k_list:
            distribution = self._get_metric_distribution(recs, k)
            value = distribution.agg(
                sf.expr("percentile_approx(value, 0.5)").alias("value")
            ).first()["value"]
            res[k] = value
        return res

    @process_k
    def _mean(self, recs: SparkDataFrame, k_list: list):
        """Mean of the per-user metric distribution per cut-off."""
        res = {}
        for k in k_list:
            distribution = self._get_metric_distribution(recs, k)
            value = distribution.agg(sf.avg("value").alias("value")).first()[
                "value"
            ]
            res[k] = value
        return res

    def _get_metric_distribution(self, recs: SparkDataFrame, k: int) -> SparkDataFrame:
        """
        :param recs: recommendations
        :param k: depth cut-off
        :return: metric distribution for different cut-offs and users
        """
        if self._use_scala_udf:
            # all columns after user_idx are forwarded to the Scala UDF
            metric_value_col = self.get_scala_udf(
                self.scala_udf_name, [sf.lit(k).alias("k"), *recs.columns[1:]]
            ).alias("value")
            return recs.select("user_idx", metric_value_col)

        cur_class = self.__class__
        distribution = recs.rdd.flatMap(
            # pylint: disable=protected-access
            lambda x: [
                (x[0], float(cur_class._get_metric_value_by_user(k, *x[1:])))
            ]
        ).toDF(
            f"user_idx {recs.schema['user_idx'].dataType.typeName()}, value double"
        )
        return distribution

    @staticmethod
    @abstractmethod
    def _get_metric_value_by_user(k, pred, ground_truth) -> float:
        """
        Metric calculation for one user.

        :param k: depth cut-off
        :param pred: recommendations
        :param ground_truth: test data
        :return: metric value for current user
        """

    # pylint: disable=too-many-arguments
    def user_distribution(
        self,
        log: DataFrameLike,
        recommendations: DataFrameLike,
        ground_truth: DataFrameLike,
        k: IntOrList,
        ground_truth_users: Optional[DataFrameLike] = None,
    ) -> PandasDataFrame:
        """
        Get mean value of metric for all users with the same number of ratings.

        :param log: history DataFrame to calculate number of ratings per user
        :param recommendations: prediction DataFrame
        :param ground_truth: test data
        :param k: depth cut-off
        :param ground_truth_users: list of users to consider in metric calculation.
            if None, only the users from ground_truth are considered.
        :return: pandas DataFrame
        """
        # local import keeps the module-level dependency block unchanged
        import pandas as pd

        log = convert2spark(log)
        count = log.groupBy("user_idx").count()
        # RecOnlyMetric subclasses supply their own enrichment method
        if hasattr(self, "_get_enriched_recommendations"):
            recs = self._get_enriched_recommendations(
                recommendations,
                ground_truth,
                max_k=k if isinstance(k, int) else max(k),
                ground_truth_users=ground_truth_users,
            )
        else:
            recs = get_enriched_recommendations(
                recommendations,
                ground_truth,
                max_k=k if isinstance(k, int) else max(k),
                ground_truth_users=ground_truth_users,
            )
        k_list = [k] if isinstance(k, int) else k
        partial_results = []
        for cut_off in k_list:
            dist = self._get_metric_distribution(recs, cut_off)
            val = count.join(dist, on="user_idx", how="right").fillna(
                0, subset="count"
            )
            val = (
                val.groupBy("count")
                .agg(sf.avg("value").alias("value"))
                .orderBy(["count"])
                .select("count", "value")
                .toPandas()
            )
            partial_results.append(val)
        # fix: DataFrame.append was removed in pandas 2.0; concat the per-k frames
        if not partial_results:
            return PandasDataFrame()
        return pd.concat(partial_results, ignore_index=True)

    @staticmethod
    def get_scala_udf(udf_name: str, params: List) -> Column:
        """
        Returns expression of calling scala UDF as column

        :param udf_name: UDF name from `org.apache.spark.replay.utils.ScalaPySparkUDFs`
        :param params: list of UDF params in right order
        :return: column expression
        """
        sc = State().session.sparkContext  # pylint: disable=invalid-name
        scala_udf = getattr(
            sc._jvm.org.apache.spark.replay.utils.ScalaPySparkUDFs, udf_name
        )()
        return Column(scala_udf.apply(_to_seq(sc, params, _to_java_column)))
389
-
390
-
391
# pylint: disable=too-few-public-methods
class RecOnlyMetric(Metric):
    """Base class for metrics that do not need holdout data"""

    @abstractmethod
    def __init__(self, log: DataFrameLike, *args, **kwargs):  # pylint: disable=super-init-not-called
        pass

    # pylint: disable=no-self-use
    @abstractmethod
    def _get_enriched_recommendations(
        self,
        recommendations: DataFrameLike,
        ground_truth: Optional[DataFrameLike],
        max_k: int,
        ground_truth_users: Optional[DataFrameLike] = None,
    ) -> SparkDataFrame:
        pass

    def __call__(
        self,
        recommendations: DataFrameLike,
        k: IntOrList,
        ground_truth_users: Optional[DataFrameLike] = None,
    ) -> Union[Dict[int, NumType], NumType]:
        """
        Compute the metric value from recommendations alone.

        :param recommendations: predictions of a model,
            DataFrame ``[user_idx, item_idx, relevance]``
        :param k: depth cut-off
        :param ground_truth_users: list of users to consider in metric calculation.
            if None, only the users from ground_truth are considered.
        :return: metric value
        """
        # the largest requested cut-off bounds how many predictions to keep
        max_cut_off = k if isinstance(k, int) else max(k)
        enriched = self._get_enriched_recommendations(
            recommendations,
            None,
            max_k=max_cut_off,
            ground_truth_users=ground_truth_users,
        )
        return self._mean(enriched, k)

    @staticmethod
    @abstractmethod
    def _get_metric_value_by_user(k, *args) -> float:
        """
        Metric calculation for one user.

        :param k: depth cut-off
        :param *args: extra parameters, returned by
            ``self._get_enriched_recommendations`` method
        :return: metric value for current user
        """
443
-
444
-
445
class NCISMetric(Metric):
    r"""
    RePlay implements Normalized Capped Importance Sampling for metric calculation with ``NCISMetric`` class.
    This method is mostly applied to RL-based recommendation systems to perform counterfactual evaluation, but could be
    used for any kind of recommender systems. See an article
    `Offline A/B testing for Recommender Systems <http://arxiv.org/abs/1801.07030>` for details.

    *Reward* (metric value for a user-item pair) is weighed by
    the ratio of *current policy score* (current relevance) on *previous policy score* (historical relevance).

    The *weight* is clipped by the *threshold* and put into interval :math:`[\frac{1}{threshold}, threshold]`.
    Activation function (e.g. softmax, sigmoid) could be applied to the scores before weights calculation.

    Normalization weight for recommended item is calculated as follows:

    .. math::
        w_{ui} = \frac{f(\pi^t_{ui}, \pi^t_u)}{f(\pi^p_{ui}, \pi^p_u)}

    Where:

    :math:`\pi^t_{ui}` - current policy value (predicted relevance) of the user-item interaction

    :math:`\pi^p_{ui}` - previous policy value (historical relevance) of the user-item interaction.
    Only values for user-item pairs present in current recommendations are used for calculation.

    :math:`\pi_u` - all predicted / historical policy values for selected user :math:`u`

    :math:`f(\pi_{ui}, \pi_u)` - activation function applied to policy values (optional)

    :math:`w_{ui}` - weight of user-item interaction for normalized metric calculation before clipping


    Calculated weights are clipped as follows:

    .. math::
        \hat{w_{ui}} = min(max(\frac{1}{threshold}, w_{ui}), threshold)

    Normalization metric value for a user is calculated as follows:

    .. math::
        R_u = \frac{r_{ui} \hat{w_{ui}}}{\sum_{i}\hat{w_{ui}}}

    Where:

    :math:`r_{ui}` - metric value (reward) for user-item interaction

    :math:`R_u` - metric value (reward) for user :math:`u`

    Weight calculation is implemented in ``_get_enriched_recommendations`` method.
    """

    def __init__(
        self,
        prev_policy_weights: DataFrameLike,
        threshold: float = 10.0,
        activation: Optional[str] = None,
        use_scala_udf: bool = False,
    ):  # pylint: disable=super-init-not-called
        """
        :param prev_policy_weights: historical item of user-item relevance (previous policy values)
        :param threshold: capping threshold, applied after activation,
            relevance values are cropped to interval [1/`threshold`, `threshold`]
        :param activation: activation function, applied over relevance values.
            "logit"/"sigmoid", "softmax" or None
        :param use_scala_udf: if True, compute the metric with a Scala UDF
        """
        self._use_scala_udf = use_scala_udf
        # rename so historical relevance can sit next to current relevance after a join
        self.prev_policy_weights = convert2spark(
            prev_policy_weights
        ).withColumnRenamed("relevance", "prev_relevance")
        self.threshold = threshold
        if activation is None or activation in ("logit", "sigmoid", "softmax"):
            self.activation = activation
            if activation == "softmax":
                self.logger.info(
                    "For accurate softmax calculation pass only one `k` value "
                    "in the NCISMetric metrics `call`"
                )
        else:
            raise ValueError(f"Unexpected `activation` - {activation}")
        # NOTE(review): threshold is validated after the spark conversion above;
        # a non-positive threshold still fails fast before any metric call
        if threshold <= 0:
            raise ValueError("Threshold should be positive real number")

    @staticmethod
    def _softmax_by_user(df: SparkDataFrame, col_name: str) -> SparkDataFrame:
        """
        Subtract minimal value (relevance) by user from `col_name`
        and apply softmax by user to `col_name`.
        """
        return (
            df.withColumn(
                "_min_rel_user",
                sf.min(col_name).over(Window.partitionBy("user_idx")),
            )
            # shifting by the per-user minimum keeps exp() from overflowing
            .withColumn(
                col_name, sf.exp(sf.col(col_name) - sf.col("_min_rel_user"))
            )
            .withColumn(
                col_name,
                sf.col(col_name)
                / sf.sum(col_name).over(Window.partitionBy("user_idx")),
            )
            .drop("_min_rel_user")
        )

    @staticmethod
    def _sigmoid(df: SparkDataFrame, col_name: str) -> SparkDataFrame:
        """
        Apply sigmoid/logistic function to column `col_name`
        """
        return df.withColumn(
            col_name, sf.lit(1.0) / (sf.lit(1.0) + sf.exp(-sf.col(col_name)))
        )

    @staticmethod
    def _weigh_and_clip(
        df: SparkDataFrame,
        threshold: float,
        target_policy_col: str = "relevance",
        prev_policy_col: str = "prev_relevance",
    ):
        """
        Clip weights to fit into interval [1/threshold, threshold].

        :param df: dataframe with current and previous policy scores
        :param threshold: capping threshold
        :param target_policy_col: column with current policy scores
        :param prev_policy_col: column with previous policy scores
        :return: ``[user_idx, item_idx, relevance, weight]``
        """
        lower, upper = 1 / threshold, threshold
        return (
            df.withColumn(
                "weight_unbounded",
                sf.col(target_policy_col) / sf.col(prev_policy_col),
            )
            .withColumn(
                "weight",
                # a zero previous-policy score would divide by zero: cap at upper bound
                sf.when(sf.col(prev_policy_col) == sf.lit(0.0), sf.lit(upper))
                .when(
                    sf.col("weight_unbounded") < sf.lit(lower), sf.lit(lower)
                )
                .when(
                    sf.col("weight_unbounded") > sf.lit(upper), sf.lit(upper)
                )
                .otherwise(sf.col("weight_unbounded")),
            )
            .select("user_idx", "item_idx", "relevance", "weight")
        )

    def _reweighing(self, recommendations):
        # apply the configured activation to both policies, then clip the ratio
        if self.activation == "softmax":
            recommendations = self._softmax_by_user(
                recommendations, col_name="prev_relevance"
            )
            recommendations = self._softmax_by_user(
                recommendations, col_name="relevance"
            )
        elif self.activation in ["logit", "sigmoid"]:
            recommendations = self._sigmoid(
                recommendations, col_name="prev_relevance"
            )
            recommendations = self._sigmoid(
                recommendations, col_name="relevance"
            )

        return self._weigh_and_clip(recommendations, self.threshold)

    def _get_enriched_recommendations(
        self,
        recommendations: DataFrameLike,
        ground_truth: DataFrameLike,
        max_k: int,
        ground_truth_users: Optional[DataFrameLike] = None,
    ) -> SparkDataFrame:
        """
        Merge recommendations and ground truth into a single DataFrame
        and aggregate items into lists so that each user has only one record.

        :param recommendations: recommendation list
        :param ground_truth: test data
        :param max_k: maximal k value to calculate the metric for.
            `max_k` most relevant predictions are left for each user
        :param ground_truth_users: list of users to consider in metric calculation.
            if None, only the users from ground_truth are considered.
        :return: ``[user_idx, pred, ground_truth]``
        """
        recommendations = convert2spark(recommendations)
        ground_truth = convert2spark(ground_truth)
        ground_truth_users = convert2spark(ground_truth_users)

        true_items_by_users = ground_truth.groupby("user_idx").agg(
            sf.collect_set("item_idx").alias("ground_truth")
        )

        # previous-policy weights may be global per item or specific per (user, item)
        group_on = ["item_idx"]
        if "user_idx" in self.prev_policy_weights.columns:
            group_on.append("user_idx")
        recommendations = get_top_k_recs(recommendations, k=max_k)

        # recommendations absent from the previous policy get zero historical score
        recommendations = recommendations.join(
            self.prev_policy_weights, on=group_on, how="left"
        ).na.fill(0.0, subset=["prev_relevance"])

        recommendations = self._reweighing(recommendations)

        weight_type = recommendations.schema["weight"].dataType
        item_type = ground_truth.schema["item_idx"].dataType

        recommendations = filter_sort(recommendations, "weight")

        if ground_truth_users is not None:
            true_items_by_users = true_items_by_users.join(
                ground_truth_users, on="user_idx", how="right"
            )

        # right join keeps every ground-truth user; missing arrays become empty
        recommendations = recommendations.join(
            true_items_by_users, how="right", on=["user_idx"]
        )
        return fill_na_with_empty_array(
            fill_na_with_empty_array(recommendations, "pred", item_type),
            "weight",
            weight_type,
        )