replay-rec 0.17.1__py3-none-any.whl → 0.17.1rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +61 -0
  4. replay/experimental/metrics/base_metric.py +601 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +10 -0
  18. replay/experimental/models/admm_slim.py +205 -0
  19. replay/experimental/models/base_neighbour_rec.py +204 -0
  20. replay/experimental/models/base_rec.py +1271 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +452 -0
  23. replay/experimental/models/ddpg.py +921 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +265 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/implicit_wrap.py +131 -0
  32. replay/experimental/models/lightfm_wrap.py +302 -0
  33. replay/experimental/models/mult_vae.py +331 -0
  34. replay/experimental/models/neuromf.py +405 -0
  35. replay/experimental/models/scala_als.py +296 -0
  36. replay/experimental/nn/data/__init__.py +1 -0
  37. replay/experimental/nn/data/schema_builder.py +55 -0
  38. replay/experimental/preprocessing/__init__.py +3 -0
  39. replay/experimental/preprocessing/data_preparator.py +838 -0
  40. replay/experimental/preprocessing/padder.py +229 -0
  41. replay/experimental/preprocessing/sequence_generator.py +208 -0
  42. replay/experimental/scenarios/__init__.py +1 -0
  43. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  44. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  45. replay/experimental/scenarios/obp_wrapper/replay_offline.py +248 -0
  46. replay/experimental/scenarios/obp_wrapper/utils.py +87 -0
  47. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  48. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  49. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  50. replay/experimental/utils/__init__.py +0 -0
  51. replay/experimental/utils/logger.py +24 -0
  52. replay/experimental/utils/model_handler.py +181 -0
  53. replay/experimental/utils/session_handler.py +44 -0
  54. {replay_rec-0.17.1.dist-info → replay_rec-0.17.1rc0.dist-info}/METADATA +9 -1
  55. replay_rec-0.17.1rc0.dist-info/NOTICE +41 -0
  56. {replay_rec-0.17.1.dist-info → replay_rec-0.17.1rc0.dist-info}/RECORD +58 -5
  57. {replay_rec-0.17.1.dist-info → replay_rec-0.17.1rc0.dist-info}/WHEEL +1 -1
  58. {replay_rec-0.17.1.dist-info → replay_rec-0.17.1rc0.dist-info}/LICENSE +0 -0
replay/__init__.py CHANGED
@@ -1,2 +1,2 @@
  """ RecSys library """
- __version__ = "0.17.1"
+ __version__ = "0.17.1.preview"
replay/experimental/__init__.py ADDED
File without changes
replay/experimental/metrics/__init__.py ADDED
@@ -0,0 +1,61 @@
+ """
+ Most metrics require a dataframe with recommendations
+ and a dataframe with ground truth values:
+ the objects each user interacted with.
+
+ - recommendations (Union[pandas.DataFrame, spark.DataFrame]):
+     predictions of a recommender system,
+     DataFrame with columns ``[user_id, item_id, relevance]``
+ - ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
+     test data, DataFrame with columns
+     ``[user_id, item_id, timestamp, relevance]``
+
+ The metric is calculated for all users present in ``ground_truth``,
+ so the result stays accurate even when the recommender system generated
+ recommendations only for some of the users. It is assumed that all users
+ we want to calculate the metric for have positive interactions.
+
+ However, if there are users who saw the recommendations but did not respond,
+ those users will be ignored and the metric will be overestimated.
+ For such cases we provide an additional optional parameter ``ground_truth_users``,
+ a dataframe with all users that should be considered during metric calculation.
+
+ - ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
+     full list of users to calculate the metric for, DataFrame with a ``user_id`` column
+
+ Every metric is calculated using the top ``K`` items for each user.
+ It is also possible to calculate metrics
+ for multiple values of ``K`` simultaneously.
+ In that case the result will be a dictionary rather than a single number.
+
+ Make sure your recommendations do not contain user-item duplicates,
+ as duplicates can lead to incorrect results.
+
+ - k (Union[Iterable[int], int]):
+     a single number or a list specifying the
+     truncation length of the recommendation list for each user
+
+ By default, metrics are averaged over users,
+ but you can alternatively use the method ``metric.median``.
+ You can also get the lower bound
+ of ``conf_interval`` for a given ``alpha``.
+
+ Diversity metrics require extra parameters at the initialization stage
+ but do not use the ``ground_truth`` parameter.
+
+ For each metric, the formula for its calculation is given, because this is
+ important for the correct comparison of algorithms, as mentioned in our
+ `article <https://arxiv.org/abs/2206.12858>`_.
+ """
+ from replay.experimental.metrics.base_metric import Metric, NCISMetric
+ from replay.experimental.metrics.coverage import Coverage
+ from replay.experimental.metrics.hitrate import HitRate
+ from replay.experimental.metrics.map import MAP
+ from replay.experimental.metrics.mrr import MRR
+ from replay.experimental.metrics.ncis_precision import NCISPrecision
+ from replay.experimental.metrics.ndcg import NDCG
+ from replay.experimental.metrics.precision import Precision
+ from replay.experimental.metrics.recall import Recall
+ from replay.experimental.metrics.rocauc import RocAuc
+ from replay.experimental.metrics.surprisal import Surprisal
+ from replay.experimental.metrics.unexpectedness import Unexpectedness
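
A minimal usage sketch of this experimental metrics API, based only on what the diff shows. Note that the docstring above names columns ``user_id``/``item_id`` while the implementation in base_metric.py below operates on ``user_idx``/``item_idx``; the snippet follows the implementation. The data values are invented, and a Spark session is assumed to be available (replay's State helper is expected to provide one).

    import pandas as pd

    from replay.experimental.metrics import HitRate

    recs = pd.DataFrame(
        {"user_idx": [1, 1, 2], "item_idx": [10, 11, 10], "relevance": [0.9, 0.5, 0.7]}
    )
    ground_truth = pd.DataFrame(
        {"user_idx": [1, 2], "item_idx": [11, 12], "relevance": [1.0, 1.0]}
    )

    metric = HitRate()
    print(metric(recs, ground_truth, k=2))       # single k -> single number
    print(metric(recs, ground_truth, k=[1, 2]))  # list of k -> dict {k: value}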
replay/experimental/metrics/base_metric.py ADDED
@@ -0,0 +1,601 @@
+ """
+ Base classes for quality and diversity metrics.
+ """
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import Dict, List, Optional, Union
+
+ import pandas as pd
+ from scipy.stats import norm
+
+ from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, IntOrList, NumType, PandasDataFrame, SparkDataFrame
+ from replay.utils.session_handler import State
+ from replay.utils.spark_utils import convert2spark, get_top_k_recs
+
+ if PYSPARK_AVAILABLE:
+     from pyspark.sql import (
+         Column,
+         Window,
+         functions as sf,
+         types as st,
+     )
+     from pyspark.sql.column import _to_java_column, _to_seq
+     from pyspark.sql.types import DataType
+
+
+ def fill_na_with_empty_array(df: SparkDataFrame, col_name: str, element_type: DataType) -> SparkDataFrame:
+     """
+     Fill missing values in an array column with an empty array of `element_type` values.
+
+     :param df: dataframe with a `col_name` column of ArrayType(`element_type`)
+     :param col_name: name of the column to fill missing values in
+     :param element_type: DataType of an array element
+     :return: df with `col_name` na values filled with empty arrays
+     """
+     return df.withColumn(
+         col_name,
+         sf.coalesce(
+             col_name,
+             sf.array().cast(st.ArrayType(element_type)),
+         ),
+     )
+
+
+ def preprocess_gt(
+     ground_truth: DataFrameLike,
+     ground_truth_users: Optional[DataFrameLike] = None,
+ ) -> SparkDataFrame:
+     """
+     Preprocess `ground_truth` data before metric calculation.
+
+     :param ground_truth: spark dataframe with columns ``[user_idx, item_idx, relevance]``
+     :param ground_truth_users: spark dataframe with column ``[user_idx]``
+     :return: spark dataframe with columns ``[user_idx, ground_truth]``
+     """
+     ground_truth = convert2spark(ground_truth)
+     ground_truth_users = convert2spark(ground_truth_users)
+
+     true_items_by_users = ground_truth.groupby("user_idx").agg(sf.collect_set("item_idx").alias("ground_truth"))
+     if ground_truth_users is not None:
+         true_items_by_users = true_items_by_users.join(ground_truth_users, on="user_idx", how="right")
+         true_items_by_users = fill_na_with_empty_array(
+             true_items_by_users,
+             "ground_truth",
+             ground_truth.schema["item_idx"].dataType,
+         )
+
+     return true_items_by_users
+
+
+ def drop_duplicates(recommendations: DataFrameLike) -> SparkDataFrame:
+     """
+     Filter duplicated predictions, keeping the most relevant one.
+     """
+     return (
+         recommendations.withColumn(
+             "_num",
+             sf.row_number().over(Window.partitionBy("user_idx", "item_idx").orderBy(sf.col("relevance").desc())),
+         )
+         .where(sf.col("_num") == 1)
+         .drop("_num")
+     )
+
+
+ def filter_sort(recommendations: SparkDataFrame, extra_column: Optional[str] = None) -> SparkDataFrame:
+     """
+     Filter duplicated predictions by choosing items with the highest relevance
+     and sort the items in predictions by their relevance.
+     If `extra_column` is not None, return a DataFrame with the extra column, e.g. item weight.
+
+     :param recommendations: recommendation list
+     :param extra_column: column in recommendations
+         which will be returned besides ``[user_idx, item_idx]``
+     :return: ``[user_idx, item_idx]`` if extra_column is None
+         or ``[user_idx, item_idx, extra_column]`` if extra_column exists.
+     """
+     item_type = recommendations.schema["item_idx"].dataType
+     extra_column_type = recommendations.schema[extra_column].dataType if extra_column else None
+
+     recommendations = drop_duplicates(recommendations)
+
+     recommendations = (
+         recommendations.groupby("user_idx")
+         .agg(
+             sf.collect_list(
+                 sf.struct(*[c for c in ["relevance", "item_idx", extra_column] if c is not None])
+             ).alias("pred_list")
+         )
+         .withColumn("pred_list", sf.reverse(sf.array_sort("pred_list")))
+     )
+
+     selection = ["user_idx", sf.col("pred_list.item_idx").cast(st.ArrayType(item_type, True)).alias("pred")]
+     if extra_column:
+         selection.append(
+             sf.col(f"pred_list.{extra_column}").cast(st.ArrayType(extra_column_type, True)).alias(extra_column)
+         )
+
+     recommendations = recommendations.select(*selection)
+
+     return recommendations
+
+
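
A hypothetical check (invented values) of what ``drop_duplicates`` and ``filter_sort`` produce, assuming a Spark session is available to ``convert2spark``:

    import pandas as pd

    from replay.experimental.metrics.base_metric import filter_sort
    from replay.utils.spark_utils import convert2spark

    recs = convert2spark(pd.DataFrame({
        "user_idx": [1, 1, 1],
        "item_idx": [10, 10, 11],
        "relevance": [0.9, 0.2, 0.5],
    }))
    # The duplicated (1, 10) pair keeps only its highest relevance (0.9), and
    # each user's items are collected into one list sorted by descending relevance:
    filter_sort(recs).show()  # expected: user_idx=1, pred=[10, 11]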
+ def get_enriched_recommendations(
+     recommendations: DataFrameLike,
+     ground_truth: DataFrameLike,
+     max_k: int,
+     ground_truth_users: Optional[DataFrameLike] = None,
+ ) -> SparkDataFrame:
+     """
+     Leave max_k recommendations for each user,
+     merge recommendations and ground truth into a single DataFrame
+     and aggregate items into lists so that each user has only one record.
+
+     :param recommendations: recommendation list
+     :param ground_truth: test data
+     :param max_k: maximal k value to calculate the metric for.
+         `max_k` most relevant predictions are left for each user
+     :param ground_truth_users: list of users to consider in metric calculation.
+         If None, only the users from ground_truth are considered.
+     :return: ``[user_idx, pred, ground_truth]``
+     """
+     recommendations = convert2spark(recommendations)
+     # if there are duplicates in recommendations,
+     # we will leave fewer than k recommendations after filter_sort
+     recommendations = get_top_k_recs(recommendations, k=max_k)
+
+     true_items_by_users = preprocess_gt(ground_truth, ground_truth_users)
+     joined = filter_sort(recommendations).join(true_items_by_users, how="right", on=["user_idx"])
+
+     return fill_na_with_empty_array(joined, "pred", recommendations.schema["item_idx"].dataType)
+
+
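
Putting the helpers together, a shape check of the enriched frame that every metric consumes (invented values; a Spark session is assumed):

    import pandas as pd

    from replay.experimental.metrics.base_metric import get_enriched_recommendations

    recs = pd.DataFrame({"user_idx": [1, 1], "item_idx": [10, 11], "relevance": [0.9, 0.5]})
    truth = pd.DataFrame({"user_idx": [1, 2], "item_idx": [11, 12], "relevance": [1.0, 1.0]})

    get_enriched_recommendations(recs, truth, max_k=2).show()
    # expected: user_idx=1 -> pred=[10, 11], ground_truth=[11]
    #           user_idx=2 -> pred=[],       ground_truth=[12]  (kept by the right join)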
+ def process_k(func):
+     """Decorator that converts k to list and unpacks result"""
+
+     def wrap(self, recs: SparkDataFrame, k: IntOrList, *args):
+         k_list = [k] if isinstance(k, int) else k
+
+         res = func(self, recs, k_list, *args)
+
+         if isinstance(k, int):
+             return res[k]
+         return res
+
+     return wrap
+
+
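
The decorator's contract can be seen without Spark; a standalone sketch that mirrors its logic on a stub aggregator:

    def process_k(func):  # copy of the decorator above, for illustration
        def wrap(self, recs, k, *args):
            k_list = [k] if isinstance(k, int) else k
            res = func(self, recs, k_list, *args)
            return res[k] if isinstance(k, int) else res
        return wrap

    class Demo:
        @process_k
        def _mean(self, recs, k_list):
            return {k: float(k) for k in k_list}  # stand-in for the Spark aggregation

    assert Demo()._mean(None, 3) == 3.0                    # int k -> bare value
    assert Demo()._mean(None, [1, 2]) == {1: 1.0, 2: 2.0}  # list of k -> dict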
+ class Metric(ABC):
+     """Base metric class"""
+
+     _logger: Optional[logging.Logger] = None
+     _scala_udf_name: Optional[str] = None
+
+     def __init__(self, use_scala_udf: bool = False) -> None:
+         self._use_scala_udf = use_scala_udf
+
+     @property
+     def logger(self) -> logging.Logger:
+         """
+         :returns: get library logger
+         """
+         if self._logger is None:
+             self._logger = logging.getLogger("replay")
+         return self._logger
+
+     @property
+     def scala_udf_name(self) -> str:
+         """Returns UDF name from `org.apache.spark.replay.utils.ScalaPySparkUDFs`"""
+         if self._scala_udf_name:
+             return self._scala_udf_name
+         else:
+             msg = f"Scala UDF not implemented for {type(self).__name__} class!"
+             raise NotImplementedError(msg)
+
+     def __str__(self):
+         return type(self).__name__
+
+     def __call__(
+         self,
+         recommendations: DataFrameLike,
+         ground_truth: DataFrameLike,
+         k: IntOrList,
+         ground_truth_users: Optional[DataFrameLike] = None,
+     ) -> Union[Dict[int, NumType], NumType]:
+         """
+         :param recommendations: model predictions in a
+             DataFrame ``[user_idx, item_idx, relevance]``
+         :param ground_truth: test data
+             ``[user_idx, item_idx, timestamp, relevance]``
+         :param k: depth cut-off. Truncates recommendation lists to top-k items.
+         :param ground_truth_users: list of users to consider in metric calculation.
+             If None, only the users from ground_truth are considered.
+         :return: metric value
+         """
+         recs = get_enriched_recommendations(
+             recommendations,
+             ground_truth,
+             max_k=k if isinstance(k, int) else max(k),
+             ground_truth_users=ground_truth_users,
+         )
+         return self._mean(recs, k)
+
+     @process_k
+     def _conf_interval(self, recs: SparkDataFrame, k_list: list, alpha: float):
+         res = {}
+         quantile = norm.ppf((1 + alpha) / 2)
+         for k in k_list:
+             distribution = self._get_metric_distribution(recs, k)
+             value = (
+                 distribution.agg(
+                     sf.stddev("value").alias("std"),
+                     sf.count("value").alias("count"),
+                 )
+                 .select(
+                     sf.when(
+                         sf.isnan(sf.col("std")) | sf.col("std").isNull(),
+                         sf.lit(0.0),
+                     )
+                     .otherwise(sf.col("std"))
+                     .cast("float")
+                     .alias("std"),
+                     "count",
+                 )
+                 .first()
+             )
+             res[k] = quantile * value["std"] / (value["count"] ** 0.5)
+         return res
+
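
The interval half-width follows the usual normal approximation, quantile * std / sqrt(n). A standalone check of the arithmetic (invented std and count):

    from scipy.stats import norm

    alpha = 0.95
    quantile = norm.ppf((1 + alpha) / 2)  # ~1.96 for a two-sided 95% interval
    std, count = 0.2, 400                 # invented per-user std and user count
    half_width = quantile * std / count ** 0.5
    print(round(half_width, 4))           # ~0.0196; the metric mean +/- this value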
+     @process_k
+     def _median(self, recs: SparkDataFrame, k_list: list):
+         res = {}
+         for k in k_list:
+             distribution = self._get_metric_distribution(recs, k)
+             value = distribution.agg(sf.expr("percentile_approx(value, 0.5)").alias("value")).first()["value"]
+             res[k] = value
+         return res
+
+     @process_k
+     def _mean(self, recs: SparkDataFrame, k_list: list):
+         res = {}
+         for k in k_list:
+             distribution = self._get_metric_distribution(recs, k)
+             value = distribution.agg(sf.avg("value").alias("value")).first()["value"]
+             res[k] = value
+         return res
+
+     def _get_metric_distribution(self, recs: SparkDataFrame, k: int) -> SparkDataFrame:
+         """
+         :param recs: recommendations
+         :param k: depth cut-off
+         :return: metric distribution for different cut-offs and users
+         """
+         if self._use_scala_udf:
+             metric_value_col = self.get_scala_udf(
+                 self.scala_udf_name, [sf.lit(k).alias("k"), *recs.columns[1:]]
+             ).alias("value")
+             return recs.select("user_idx", metric_value_col)
+
+         cur_class = self.__class__
+         distribution = recs.rdd.flatMap(
+             lambda x: [(x[0], float(cur_class._get_metric_value_by_user(k, *x[1:])))]
+         ).toDF(f"user_idx {recs.schema['user_idx'].dataType.typeName()}, value double")
+         return distribution
+
+     @staticmethod
+     @abstractmethod
+     def _get_metric_value_by_user(k, pred, ground_truth) -> float:
+         """
+         Metric calculation for one user.
+
+         :param k: depth cut-off
+         :param pred: recommendations
+         :param ground_truth: test data
+         :return: metric value for current user
+         """
+
+     def user_distribution(
+         self,
+         log: DataFrameLike,
+         recommendations: DataFrameLike,
+         ground_truth: DataFrameLike,
+         k: IntOrList,
+         ground_truth_users: Optional[DataFrameLike] = None,
+     ) -> PandasDataFrame:
+         """
+         Get the mean value of the metric for all users with the same number of ratings.
+
+         :param log: history DataFrame to calculate the number of ratings per user
+         :param recommendations: prediction DataFrame
+         :param ground_truth: test data
+         :param k: depth cut-off
+         :param ground_truth_users: list of users to consider in metric calculation.
+             If None, only the users from ground_truth are considered.
+         :return: pandas DataFrame
+         """
+         log = convert2spark(log)
+         count = log.groupBy("user_idx").count()
+         if hasattr(self, "_get_enriched_recommendations"):
+             recs = self._get_enriched_recommendations(
+                 recommendations,
+                 ground_truth,
+                 max_k=k if isinstance(k, int) else max(k),
+                 ground_truth_users=ground_truth_users,
+             )
+         else:
+             recs = get_enriched_recommendations(
+                 recommendations,
+                 ground_truth,
+                 max_k=k if isinstance(k, int) else max(k),
+                 ground_truth_users=ground_truth_users,
+             )
+         k_list = [k] if isinstance(k, int) else k
+         res_parts = []
+         for cut_off in k_list:
+             dist = self._get_metric_distribution(recs, cut_off)
+             val = count.join(dist, on="user_idx", how="right").fillna(0, subset="count")
+             val = (
+                 val.groupBy("count")
+                 .agg(sf.avg("value").alias("value"))
+                 .orderBy(["count"])
+                 .select("count", "value")
+                 .toPandas()
+             )
+             res_parts.append(val)
+         # DataFrame.append was removed in pandas 2.0; concatenate the per-k frames instead
+         res = pd.concat(res_parts, ignore_index=True)
+         return res
+
+     @staticmethod
+     def get_scala_udf(udf_name: str, params: List) -> Column:
+         """
+         Returns an expression calling a Scala UDF as a column.
+
+         :param udf_name: UDF name from `org.apache.spark.replay.utils.ScalaPySparkUDFs`
+         :param params: list of UDF params in the right order
+         :return: column expression
+         """
+         sc = State().session.sparkContext
+         scala_udf = getattr(sc._jvm.org.apache.spark.replay.utils.ScalaPySparkUDFs, udf_name)()
+         return Column(scala_udf.apply(_to_seq(sc, params, _to_java_column)))
+
+
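
A hypothetical subclass sketch showing the extension point: a new quality metric only needs the per-user static method, which receives the cut-off, the sorted prediction list, and the ground-truth set from the enriched frame (the bundled Precision class implements essentially this):

    from replay.experimental.metrics.base_metric import Metric

    class MyPrecision(Metric):
        """Share of the top-k recommendations found in the ground truth."""

        @staticmethod
        def _get_metric_value_by_user(k, pred, ground_truth) -> float:
            if not pred:
                return 0.0
            return len(set(pred[:k]) & set(ground_truth)) / k

    # MyPrecision()(recs, ground_truth, k=[5, 10]) averages this value over users.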
+ class RecOnlyMetric(Metric):
+     """Base class for metrics that do not need holdout data"""
+
+     @abstractmethod
+     def __init__(self, log: DataFrameLike, *args, **kwargs):
+         pass
+
+     @abstractmethod
+     def _get_enriched_recommendations(
+         self,
+         recommendations: DataFrameLike,
+         ground_truth: Optional[DataFrameLike],
+         max_k: int,
+         ground_truth_users: Optional[DataFrameLike] = None,
+     ) -> SparkDataFrame:
+         pass
+
+     def __call__(
+         self,
+         recommendations: DataFrameLike,
+         k: IntOrList,
+         ground_truth_users: Optional[DataFrameLike] = None,
+     ) -> Union[Dict[int, NumType], NumType]:
+         """
+         :param recommendations: predictions of a model,
+             DataFrame ``[user_idx, item_idx, relevance]``
+         :param k: depth cut-off
+         :param ground_truth_users: list of users to consider in metric calculation.
+             If None, only the users from ground_truth are considered.
+         :return: metric value
+         """
+         recs = self._get_enriched_recommendations(
+             recommendations,
+             None,
+             max_k=k if isinstance(k, int) else max(k),
+             ground_truth_users=ground_truth_users,
+         )
+         return self._mean(recs, k)
+
+     @staticmethod
+     @abstractmethod
+     def _get_metric_value_by_user(k, *args) -> float:
+         """
+         Metric calculation for one user.
+
+         :param k: depth cut-off
+         :param *args: extra parameters, returned by the
+             ``self._get_enriched_recommendations`` method
+         :return: metric value for current user
+         """
+
+
+ class NCISMetric(Metric):
+     """
+     RePlay implements Normalized Capped Importance Sampling for metric calculation with the ``NCISMetric`` class.
+     This method is mostly applied to RL-based recommender systems to perform counterfactual evaluation, but it can
+     be used with any kind of recommender system. See the article
+     `Offline A/B testing for Recommender Systems <http://arxiv.org/abs/1801.07030>`_ for details.
+
+     The *reward* (metric value for a user-item pair) is weighted by the ratio of
+     the *current policy score* (current relevance) to the *previous policy score* (historical relevance).
+
+     The *weight* is clipped by the *threshold* and put into the interval :math:`[\\frac{1}{threshold}, threshold]`.
+     An activation function (e.g. softmax, sigmoid) can be applied to the scores before the weights are calculated.
+
+     The normalization weight for a recommended item is calculated as follows:
+
+     .. math::
+         w_{ui} = \\frac{f(\\pi^t_{ui}, \\pi^t_u)}{f(\\pi^p_{ui}, \\pi^p_u)}
+
+     Where:
+
+     :math:`\\pi^t_{ui}` - current policy value (predicted relevance) of the user-item interaction
+
+     :math:`\\pi^p_{ui}` - previous policy value (historical relevance) of the user-item interaction.
+     Only values for user-item pairs present in the current recommendations are used for the calculation.
+
+     :math:`\\pi_u` - all predicted/historical policy values for the selected user :math:`u`
+
+     :math:`f(\\pi_{ui}, \\pi_u)` - activation function applied to the policy values (optional)
+
+     :math:`w_{ui}` - weight of the user-item interaction for normalized metric calculation before clipping
+
+     The calculated weights are clipped as follows:
+
+     .. math::
+         \\hat{w_{ui}} = \\min(\\max(\\frac{1}{threshold}, w_{ui}), threshold)
+
+     The normalized metric value for a user is calculated as follows:
+
+     .. math::
+         R_u = \\frac{\\sum_{i} r_{ui} \\hat{w_{ui}}}{\\sum_{i} \\hat{w_{ui}}}
+
+     Where:
+
+     :math:`r_{ui}` - metric value (reward) for the user-item interaction
+
+     :math:`R_u` - metric value (reward) for user :math:`u`
+
+     Weight calculation is implemented in the ``_get_enriched_recommendations`` method.
+     """
+
+     def __init__(
+         self,
+         prev_policy_weights: DataFrameLike,
+         threshold: float = 10.0,
+         activation: Optional[str] = None,
+         use_scala_udf: bool = False,
+     ):
+         """
+         :param prev_policy_weights: historical user-item relevance values (previous policy values)
+         :param threshold: capping threshold, applied after activation;
+             relevance values are clipped to the interval [1/`threshold`, `threshold`]
+         :param activation: activation function applied to relevance values:
+             "logit"/"sigmoid", "softmax" or None
+         :param use_scala_udf: if True, calculate the metric with a Scala UDF
+         """
+         self._use_scala_udf = use_scala_udf
+         self.prev_policy_weights = convert2spark(prev_policy_weights).withColumnRenamed("relevance", "prev_relevance")
+         self.threshold = threshold
+         if activation is None or activation in ("logit", "sigmoid", "softmax"):
+             self.activation = activation
+             if activation == "softmax":
+                 self.logger.info(
+                     "For accurate softmax calculation pass only one `k` value in the NCISMetric metrics `call`"
+                 )
+         else:
+             msg = f"Unexpected `activation` - {activation}"
+             raise ValueError(msg)
+         if threshold <= 0:
+             msg = "Threshold should be a positive real number"
+             raise ValueError(msg)
+
+     @staticmethod
+     def _softmax_by_user(df: SparkDataFrame, col_name: str) -> SparkDataFrame:
+         """
+         Subtract the minimal value (relevance) per user from `col_name`
+         and apply a softmax over `col_name` for each user.
+         """
+         return (
+             df.withColumn(
+                 "_min_rel_user",
+                 sf.min(col_name).over(Window.partitionBy("user_idx")),
+             )
+             .withColumn(col_name, sf.exp(sf.col(col_name) - sf.col("_min_rel_user")))
+             .withColumn(
+                 col_name,
+                 sf.col(col_name) / sf.sum(col_name).over(Window.partitionBy("user_idx")),
+             )
+             .drop("_min_rel_user")
+         )
+
+     @staticmethod
+     def _sigmoid(df: SparkDataFrame, col_name: str) -> SparkDataFrame:
+         """
+         Apply the sigmoid/logistic function to column `col_name`.
+         """
+         return df.withColumn(col_name, sf.lit(1.0) / (sf.lit(1.0) + sf.exp(-sf.col(col_name))))
+
+     @staticmethod
+     def _weigh_and_clip(
+         df: SparkDataFrame,
+         threshold: float,
+         target_policy_col: str = "relevance",
+         prev_policy_col: str = "prev_relevance",
+     ):
+         """
+         Clip weights to fit into the interval [1/threshold, threshold].
+         """
+         lower, upper = 1 / threshold, threshold
+         return (
+             df.withColumn(
+                 "weight_unbounded",
+                 sf.col(target_policy_col) / sf.col(prev_policy_col),
+             )
+             .withColumn(
+                 "weight",
+                 sf.when(sf.col(prev_policy_col) == sf.lit(0.0), sf.lit(upper))
+                 .when(sf.col("weight_unbounded") < sf.lit(lower), sf.lit(lower))
+                 .when(sf.col("weight_unbounded") > sf.lit(upper), sf.lit(upper))
+                 .otherwise(sf.col("weight_unbounded")),
+             )
+             .select("user_idx", "item_idx", "relevance", "weight")
+         )
+
+     def _reweighing(self, recommendations):
+         if self.activation == "softmax":
+             recommendations = self._softmax_by_user(recommendations, col_name="prev_relevance")
+             recommendations = self._softmax_by_user(recommendations, col_name="relevance")
+         elif self.activation in ["logit", "sigmoid"]:
+             recommendations = self._sigmoid(recommendations, col_name="prev_relevance")
+             recommendations = self._sigmoid(recommendations, col_name="relevance")
+
+         return self._weigh_and_clip(recommendations, self.threshold)
+
+     def _get_enriched_recommendations(
+         self,
+         recommendations: DataFrameLike,
+         ground_truth: DataFrameLike,
+         max_k: int,
+         ground_truth_users: Optional[DataFrameLike] = None,
+     ) -> SparkDataFrame:
+         """
+         Merge recommendations and ground truth into a single DataFrame
+         and aggregate items into lists so that each user has only one record.
+
+         :param recommendations: recommendation list
+         :param ground_truth: test data
+         :param max_k: maximal k value to calculate the metric for.
+             `max_k` most relevant predictions are left for each user
+         :param ground_truth_users: list of users to consider in metric calculation.
+             If None, only the users from ground_truth are considered.
+         :return: ``[user_idx, pred, ground_truth]``
+         """
+         recommendations = convert2spark(recommendations)
+         ground_truth = convert2spark(ground_truth)
+         ground_truth_users = convert2spark(ground_truth_users)
+
+         true_items_by_users = ground_truth.groupby("user_idx").agg(sf.collect_set("item_idx").alias("ground_truth"))
+
+         group_on = ["item_idx"]
+         if "user_idx" in self.prev_policy_weights.columns:
+             group_on.append("user_idx")
+         recommendations = get_top_k_recs(recommendations, k=max_k)
+
+         recommendations = recommendations.join(self.prev_policy_weights, on=group_on, how="left").na.fill(
+             0.0, subset=["prev_relevance"]
+         )
+
+         recommendations = self._reweighing(recommendations)
+
+         weight_type = recommendations.schema["weight"].dataType
+         item_type = ground_truth.schema["item_idx"].dataType
+
+         recommendations = filter_sort(recommendations, "weight")
+
+         if ground_truth_users is not None:
+             true_items_by_users = true_items_by_users.join(ground_truth_users, on="user_idx", how="right")
+
+         recommendations = recommendations.join(true_items_by_users, how="right", on=["user_idx"])
+         return fill_na_with_empty_array(
+             fill_na_with_empty_array(recommendations, "pred", item_type),
+             "weight",
+             weight_type,
+         )
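
A standalone numeric check of the clipping rule implemented in ``_weigh_and_clip`` (invented values):

    def clip_weight(target, prev, threshold=10.0):
        lower, upper = 1 / threshold, threshold
        if prev == 0.0:  # item unseen by the previous policy -> maximal weight
            return upper
        return min(max(target / prev, lower), upper)

    assert clip_weight(0.8, 0.4) == 2.0    # plain ratio, already within bounds
    assert clip_weight(0.9, 0.01) == 10.0  # capped at threshold
    assert clip_weight(0.01, 0.9) == 0.1   # floored at 1/threshold
    assert clip_weight(0.5, 0.0) == 10.0   # zero previous relevance -> upper bound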