replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/metrics/offline_metrics.py
CHANGED
@@ -1,7 +1,7 @@
 import warnings
 from typing import Dict, List, Optional, Tuple, Union
 
-from replay.utils import PandasDataFrame,
+from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType
 from .coverage import Coverage
@@ -10,7 +10,6 @@ from .recall import Recall
 from .surprisal import Surprisal
 
 
-# pylint: disable=too-few-public-methods
 class OfflineMetrics:
     """
     Designed for efficient calculation of offline metrics provided by the RePlay.
@@ -146,7 +145,6 @@ class OfflineMetrics:
         "Recall": ["ground_truth"],
     }
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         metrics: List[Metric],
@@ -220,15 +218,11 @@ class OfflineMetrics:
             default_metric._check_duplicates_polars(recommendations)
         unchanged_recs = recommendations
 
-
-        result_dict["default"] = default_metric._get_enriched_recommendations(
-            recommendations, ground_truth
-        )
+        result_dict["default"] = default_metric._get_enriched_recommendations(recommendations, ground_truth)
 
         for metric in self.metrics:
             # find Coverage
             if metric.__class__.__name__ == "Coverage":
-                # pylint: disable=protected-access
                 result_dict["Coverage"] = Coverage(
                     topk=2,
                     query_column=query_column,
@@ -244,9 +238,7 @@ class OfflineMetrics:
                     item_column=item_column,
                     rating_column=rating_column,
                 )
-                cur_recs = novelty_metric._get_enriched_recommendations(
-                    unchanged_recs, train
-                )
+                cur_recs = novelty_metric._get_enriched_recommendations(unchanged_recs, train)
                 if is_spark:
                     cur_recs = cur_recs.withColumnRenamed("ground_truth", "train")
                 else:
@@ -265,12 +257,10 @@ class OfflineMetrics:
 
         return result_dict, train
 
-    # pylint: disable=no-self-use
     def _cache_dataframes(self, dataframes: Dict[str, SparkDataFrame]) -> None:
         for data in dataframes.values():
             data.cache()
 
-    # pylint: disable=no-self-use
     def _unpersist_dataframes(self, dataframes: Dict[str, SparkDataFrame]) -> None:
         for data in dataframes.values():
             data.unpersist()
@@ -294,22 +284,18 @@ class OfflineMetrics:
             else:
                 metric_args["recs"] = enriched_recs_dict["default"]
 
-            # pylint: disable=protected-access
             if is_spark:
                 result.update(metric._spark_compute(**metric_args))
             else:
                 result.update(metric._polars_compute(**metric_args))
         return result
 
-    # pylint: disable=no-self-use
     def _check_dataframes_types(
         self,
         recommendations: MetricsDataFrameLike,
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike],
-        base_recommendations: Optional[
-            Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]
-        ],
+        base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]],
     ) -> None:
         types = set()
         types.add(type(recommendations))
@@ -317,7 +303,7 @@ class OfflineMetrics:
         if train is not None:
             types.add(type(train))
         if isinstance(base_recommendations, dict):
-            for
+            for df in base_recommendations.values():
                 if not isinstance(df, list):
                     types.add(type(df))
         else:
@@ -327,7 +313,8 @@ class OfflineMetrics:
             types.add(type(base_recommendations))
 
         if len(types) != 1:
-
+            msg = "All given data frames must have the same type"
+            raise ValueError(msg)
 
     def _check_query_column_present(
         self,
@@ -350,7 +337,8 @@ class OfflineMetrics:
         dataset_names = dataset.columns
 
         if not isinstance(dataset, dict) and query_column not in dataset_names:
-
+            msg = f"Query column {query_column} is not present in {dataset_name} dataframe"
+            raise KeyError(msg)
 
     def _get_unique_queries(
         self,
@@ -386,14 +374,12 @@ class OfflineMetrics:
         if queries.issubset(other_queries) is False:
             warnings.warn(f"{dataset_name} contains queries that are not presented in recommendations")
 
-    def __call__(  #
+    def __call__(  # noqa: C901
        self,
        recommendations: MetricsDataFrameLike,
        ground_truth: MetricsDataFrameLike,
        train: Optional[MetricsDataFrameLike] = None,
-        base_recommendations: Optional[
-            Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]
-        ] = None,
+        base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]] = None,
     ) -> Dict[str, float]:
         """
         Compute metrics.
@@ -424,9 +410,7 @@ class OfflineMetrics:
 
         :return: metric values
         """
-        self._check_dataframes_types(
-            recommendations, ground_truth, train, base_recommendations
-        )
+        self._check_dataframes_types(recommendations, ground_truth, train, base_recommendations)
 
         if len(self.main_metrics) > 0:
             query_column = self.main_metrics[0].query_column
@@ -443,31 +427,22 @@ class OfflineMetrics:
 
         if train is not None:
             self._check_query_column_present(train, query_column, "train")
-            self._check_contains(
-                recs_queries,
-                self._get_unique_queries(train, query_column),
-                "train"
-            )
+            self._check_contains(recs_queries, self._get_unique_queries(train, query_column), "train")
         if base_recommendations is not None:
-            if
-
+            if not isinstance(base_recommendations, dict) or isinstance(
+                next(iter(base_recommendations.values())), list
+            ):
                 base_recommendations = {"base_recommendations": base_recommendations}
             for name, dataset in base_recommendations.items():
                 self._check_query_column_present(dataset, query_column, name)
-                self._check_contains(
-                    recs_queries,
-                    self._get_unique_queries(dataset, query_column),
-                    name
-                )
+                self._check_contains(recs_queries, self._get_unique_queries(dataset, query_column), name)
 
         result = {}
         if isinstance(recommendations, (SparkDataFrame, PolarsDataFrame)):
             is_spark = isinstance(recommendations, SparkDataFrame)
             assert isinstance(ground_truth, type(recommendations))
             assert train is None or isinstance(train, type(recommendations))
-            enriched_recs_dict, train = self._get_enriched_recommendations(
-                recommendations, ground_truth, train
-            )
+            enriched_recs_dict, train = self._get_enriched_recommendations(recommendations, ground_truth, train)
 
             if is_spark and self._allow_caching:
                 self._cache_dataframes(enriched_recs_dict)
@@ -480,12 +455,8 @@ class OfflineMetrics:
                 "train": train,
             }
             for metric in self.metrics:
-                args_to_call: Dict[str, Union[PandasDataFrame, Dict]] = {
-
-                }
-                for data_name in self._metrics_call_requirement_map[
-                    str(metric.__class__.__name__)
-                ]:
+                args_to_call: Dict[str, Union[PandasDataFrame, Dict]] = {"recommendations": recommendations}
+                for data_name in self._metrics_call_requirement_map[str(metric.__class__.__name__)]:
                     args_to_call[data_name] = current_map[data_name]
                 result.update(metric(**args_to_call))
         unexpectedness_result = {}
@@ -493,23 +464,17 @@ class OfflineMetrics:
 
         if len(self.unexpectedness_metric) != 0:
             if base_recommendations is None:
-
-
-
-            if isinstance(base_recommendations, dict) and not isinstance(
-                list(base_recommendations.values())[0], list
-            ):
+                msg = "Can not calculate Unexpectedness because base_recommendations is None"
+                raise ValueError(msg)
+            first_element = next(iter(base_recommendations.values()))
+            if isinstance(base_recommendations, dict) and not isinstance(first_element, list):
                 for unexp in self.unexpectedness_metric:
                     for model_name in base_recommendations:
-                        cur_result = unexp(
-                            recommendations, base_recommendations[model_name]
-                        )
+                        cur_result = unexp(recommendations, base_recommendations[model_name])
                         for metric_name in cur_result:
                             splitted = metric_name.split("@")
                             splitted[0] += "_" + model_name
-                            unexpectedness_result["@".join(splitted)] = cur_result[
-                                metric_name
-                            ]
+                            unexpectedness_result["@".join(splitted)] = cur_result[metric_name]
 
         if len(self.diversity_metric) != 0:
             for diversity in self.diversity_metric:
|
replay/metrics/precision.py
CHANGED
|
@@ -3,16 +3,15 @@ from typing import List
|
|
|
3
3
|
from .base_metric import Metric
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
# pylint: disable=too-few-public-methods
|
|
7
6
|
class Precision(Metric):
|
|
8
7
|
"""
|
|
9
8
|
Mean percentage of relevant items among top ``K`` recommendations.
|
|
10
9
|
|
|
11
10
|
.. math::
|
|
12
|
-
Precision@K(i) = \\frac {
|
|
11
|
+
Precision@K(i) = \\frac {\\sum_{j=1}^{K}\\mathbb{1}_{r_{ij}}}{K}
|
|
13
12
|
|
|
14
13
|
.. math::
|
|
15
|
-
Precision@K = \\frac {
|
|
14
|
+
Precision@K = \\frac {\\sum_{i=1}^{N}Precision@K(i)}{N}
|
|
16
15
|
|
|
17
16
|
:math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing that user :math:`i` interacted with item :math:`j`
|
|
18
17
|
|
|
@@ -62,9 +61,7 @@ class Precision(Metric):
|
|
|
62
61
|
"""
|
|
63
62
|
|
|
64
63
|
@staticmethod
|
|
65
|
-
def _get_metric_value_by_user(
|
|
66
|
-
ks: List[int], ground_truth: List, pred: List
|
|
67
|
-
) -> List[float]:
|
|
64
|
+
def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
|
|
68
65
|
if not ground_truth or not pred:
|
|
69
66
|
return [0.0 for _ in ks]
|
|
70
67
|
set_gt = set(ground_truth)
|
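The Precision@K formula restored in the docstring above can be sanity-checked with a small stand-alone sketch (plain Python, illustrative only, not RePlay's _get_metric_value_by_user; the function name and sample lists are made up):

    def precision_at_k(ground_truth, pred, k):
        # fraction of the top-k predictions that the user actually interacted with
        if not ground_truth or not pred:
            return 0.0
        return len(set(pred[:k]) & set(ground_truth)) / k

    # 2 of the top-4 recommended items are relevant -> 0.5
    assert precision_at_k([1, 2, 3], [2, 5, 1, 7], k=4) == 0.5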
replay/metrics/recall.py
CHANGED
@@ -3,7 +3,6 @@ from typing import List
 from .base_metric import Metric
 
 
-# pylint: disable=too-few-public-methods
 class Recall(Metric):
     """
     Recall measures the coverage of the recommended items, and is defined as:
@@ -11,10 +10,10 @@ class Recall(Metric):
     Mean percentage of relevant items, that was shown among top ``K`` recommendations.
 
     .. math::
-        Recall@K(i) = \\frac {
+        Recall@K(i) = \\frac {\\sum_{j=1}^{K}\\mathbb{1}_{r_{ij}}}{|Rel_i|}
 
     .. math::
-        Recall@K = \\frac {
+        Recall@K = \\frac {\\sum_{i=1}^{N}Recall@K(i)}{N}
 
     :math:`\\mathbb{1}_{r_{ij}}` -- indicator function showing that user :math:`i` interacted with item :math:`j`
 
@@ -66,9 +65,7 @@ class Recall(Metric):
     """
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], ground_truth: List, pred: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/rocauc.py
CHANGED
@@ -3,7 +3,6 @@ from typing import List
 from .base_metric import Metric
 
 
-# pylint: disable=too-few-public-methods
 class RocAuc(Metric):
     """
     Receiver Operating Characteristic/Area Under the Curve is the aggregated performance measure,
@@ -13,21 +12,21 @@ class RocAuc(Metric):
     The bigger the value of AUC, the better the classification model.
 
     .. math::
-        ROCAUC@K(i) = \\frac {
-
-
-        {
+        ROCAUC@K(i) = \\frac {\\sum_{s=1}^{K}\\sum_{t=1}^{K}
+        \\mathbb{1}_{r_{si}<r_{ti}}
+        \\mathbb{1}_{gt_{si}<gt_{ti}}}
+        {\\sum_{s=1}^{K}\\sum_{t=1}^{K} \\mathbb{1}_{gt_{si}<gt_{tj}}}
 
     :math:`\\mathbb{1}_{r_{si}<r_{ti}}` -- indicator function showing that recommendation score for
     user :math:`i` for item :math:`s` is bigger than for item :math:`t`
 
-    :math
+    :math:`\\mathbb{1}_{gt_{si}<gt_{ti}}` -- indicator function showing that
     user :math:`i` values item :math:`s` more than item :math:`t`.
 
     Metric is averaged by all users.
 
     .. math::
-        ROCAUC@K = \\frac {
+        ROCAUC@K = \\frac {\\sum_{i=1}^{N}ROCAUC@K(i)}{N}
 
     >>> recommendations
        query_id  item_id  rating
@@ -75,9 +74,7 @@ class RocAuc(Metric):
     """
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], ground_truth: List, pred: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/surprisal.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Dict, List, Union
 import numpy as np
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame,
+from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType
 
@@ -12,13 +12,12 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
 
-# pylint: disable=too-few-public-methods
 class Surprisal(Metric):
     """
     Measures how many surprising rare items are present in recommendations.
 
     .. math::
-        \\textit{Self-Information}(j)=
+        \\textit{Self-Information}(j)= -\\log_2 \\frac {u_j}{N}
 
     :math:`u_j` -- number of users that interacted with item :math:`j`.
     Cold items are treated as if they were rated by 1 user.
@@ -32,12 +31,12 @@ class Surprisal(Metric):
     Recommendation list surprisal is the average surprisal of items in it.
 
     .. math::
-        Surprisal@K(i) = \\frac {
+        Surprisal@K(i) = \\frac {\\sum_{j=1}^{K}Surprisal(j)} {K}
 
     Final metric is averaged by users.
 
     .. math::
-        Surprisal@K = \\frac {
+        Surprisal@K = \\frac {\\sum_{i=1}^{N}Surprisal@K(i)}{N}
 
     :math:`N` -- the number of users.
 
@@ -83,7 +82,6 @@ class Surprisal(Metric):
     <BLANKLINE>
     """
 
-    # pylint: disable=no-self-use
     def _get_weights(self, train: Dict) -> Dict:
         n_users = len(train.keys())
         items_counter = defaultdict(set)
@@ -102,7 +100,6 @@ class Surprisal(Metric):
             recs_with_weights[user] = [weights.get(i, 1) for i in items]
         return recs_with_weights
 
-    # pylint: disable=arguments-renamed
     def _get_enriched_recommendations(
         self,
         recommendations: Union[PolarsDataFrame, SparkDataFrame],
@@ -113,38 +110,28 @@ class Surprisal(Metric):
         else:
             return self._get_enriched_recommendations_polars(recommendations, train)
 
-    def _get_enriched_recommendations_spark(
+    def _get_enriched_recommendations_spark(
         self, recommendations: SparkDataFrame, train: SparkDataFrame
     ) -> SparkDataFrame:
         n_users = train.select(self.query_column).distinct().count()
         item_weights = train.groupby(self.item_column).agg(
-            (
-                sf.log2(n_users / sf.countDistinct(self.query_column)) / np.log2(n_users)
-            ).alias("weight")
+            (sf.log2(n_users / sf.countDistinct(self.query_column)) / np.log2(n_users)).alias("weight")
         )
-        recommendations = recommendations.join(
-            item_weights, on=self.item_column, how="left"
-        ).fillna(1.0)
+        recommendations = recommendations.join(item_weights, on=self.item_column, how="left").fillna(1.0)
 
-        sorted_by_score_recommendations = self._get_items_list_per_user(
-            recommendations, "weight"
-        )
+        sorted_by_score_recommendations = self._get_items_list_per_user(recommendations, "weight")
         return self._rearrange_columns(sorted_by_score_recommendations)
 
-    def _get_enriched_recommendations_polars(
+    def _get_enriched_recommendations_polars(
         self, recommendations: PolarsDataFrame, train: PolarsDataFrame
     ) -> PolarsDataFrame:
         n_users = train.select(self.query_column).n_unique()
         item_weights = train.group_by(self.item_column).agg(
             (np.log2(n_users / pl.col(self.query_column).n_unique()) / np.log2(n_users)).alias("weight")
         )
-        recommendations = recommendations.join(
-            item_weights, on=self.item_column, how="left"
-        ).fill_nan(1.0)
+        recommendations = recommendations.join(item_weights, on=self.item_column, how="left").fill_nan(1.0)
 
-        sorted_by_score_recommendations = self._get_items_list_per_user(
-            recommendations, "weight"
-        )
+        sorted_by_score_recommendations = self._get_items_list_per_user(recommendations, "weight")
         return self._rearrange_columns(sorted_by_score_recommendations)
 
     def __call__(
@@ -183,9 +170,7 @@ class Surprisal(Metric):
             else self._convert_dict_to_dict_with_score(recommendations)
         )
         self._check_duplicates_dict(recommendations)
-        train = (
-            self._convert_pandas_to_dict_without_score(train) if is_pandas else train
-        )
+        train = self._convert_pandas_to_dict_without_score(train) if is_pandas else train
         assert isinstance(train, dict)
 
         weights = self._get_recommendation_weights(recommendations, train)
@@ -196,9 +181,7 @@ class Surprisal(Metric):
         )
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], pred_item_ids: List, pred_weights: List
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], pred_item_ids: List, pred_weights: List) -> List[float]:
         if not pred_item_ids:
             return [0.0 for _ in ks]
         res = []

replay/metrics/torch_metrics_builder.py
CHANGED
@@ -28,7 +28,6 @@ DEFAULT_METRICS: List[MetricName] = [
 DEFAULT_KS: List[int] = [1, 5, 10, 20]
 
 
-# pylint: disable=too-many-instance-attributes
 @dataclass
 class _MetricRequirements:
     """
@@ -113,7 +112,6 @@ class _CoverageHelper:
         self._train_hist = torch.zeros(self.item_count)
         self._pred_hist: Dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}
 
-    # pylint: disable=attribute-defined-outside-init
     def _ensure_hists_on_device(self, device: torch.device) -> None:
         self._train_hist = self._train_hist.to(device)
         for k in self._top_k:
@@ -192,13 +190,11 @@ class _MetricBuilder(abc.ABC):
     """
 
 
-# pylint: disable=too-many-instance-attributes
 class TorchMetricsBuilder(_MetricBuilder):
     """
     Computes specified metrics over multiple batches
     """
 
-    # pylint: disable=dangerous-default-value
     def __init__(
         self,
         metrics: List[MetricName] = DEFAULT_METRICS,
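The Surprisal weight visible in the Polars branch above is the normalized self-information, log2(N / u_j) / log2(N). A minimal stand-alone sketch of that formula (function name and sample numbers are illustrative, not RePlay code):

    import numpy as np

    def item_weight(n_users, n_users_who_saw_item):
        # log2(N / u_j) rescaled by log2(N) so the weight stays in [0, 1]
        return np.log2(n_users / n_users_who_saw_item) / np.log2(n_users)

    print(item_weight(1000, 1000))  # 0.0, an item every user saw is not surprising
    print(item_weight(1000, 1))     # 1.0, an item seen by a single user is maximally surprising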
replay/metrics/unexpectedness.py
CHANGED
@@ -1,11 +1,10 @@
 from typing import List, Optional, Union
 
-from replay.utils import PandasDataFrame,
+from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType
 
 
-# pylint: disable=too-few-public-methods
 class Unexpectedness(Metric):
     """
     Fraction of recommended items that are not present in some baseline\
@@ -13,11 +12,12 @@ class Unexpectedness(Metric):
 
     .. math::
         Unexpectedness@K(i) = 1 -
-        \\frac {
+        \\frac {\\parallel R^{i}_{1..\\min(K, \\parallel R^{i} \\parallel)}
+        \\cap BR^{i}_{1..\\min(K, \\parallel BR^{i} \\parallel)} \\parallel}
         {K}
 
     .. math::
-        Unexpectedness@K = \\frac {1}{N}
+        Unexpectedness@K = \\frac {1}{N}\\sum_{i=1}^{N}Unexpectedness@K(i)
 
     :math:`R_{1..j}^{i}` -- the first :math:`j` recommendations for the :math:`i`-th user.
 
@@ -61,7 +61,7 @@ class Unexpectedness(Metric):
     'Unexpectedness-ConfidenceInterval@4': 0.0}
     <BLANKLINE>
     """
-
+
     def _get_enriched_recommendations(
         self,
         recommendations: Union[PolarsDataFrame, SparkDataFrame],
@@ -72,14 +72,14 @@ class Unexpectedness(Metric):
         else:
             return self._get_enriched_recommendations_polars(recommendations, base_recommendations)
 
-    def _get_enriched_recommendations_spark(
+    def _get_enriched_recommendations_spark(
         self, recommendations: SparkDataFrame, base_recommendations: SparkDataFrame
     ) -> SparkDataFrame:
         sorted_by_score_recommendations = self._get_items_list_per_user(recommendations)
 
-        sorted_by_score_base_recommendations = self._get_items_list_per_user(
-
-        )
+        sorted_by_score_base_recommendations = self._get_items_list_per_user(base_recommendations).withColumnRenamed(
+            "pred_item_id", "base_pred_item_id"
+        )
 
         enriched_recommendations = sorted_by_score_recommendations.join(
             sorted_by_score_base_recommendations, how="left", on=self.query_column
@@ -87,14 +87,14 @@ class Unexpectedness(Metric):
 
         return self._rearrange_columns(enriched_recommendations)
 
-    def _get_enriched_recommendations_polars(
+    def _get_enriched_recommendations_polars(
         self, recommendations: PolarsDataFrame, base_recommendations: PolarsDataFrame
     ) -> PolarsDataFrame:
         sorted_by_score_recommendations = self._get_items_list_per_user(recommendations)
 
-        sorted_by_score_base_recommendations = self._get_items_list_per_user(
-
-        )
+        sorted_by_score_base_recommendations = self._get_items_list_per_user(base_recommendations).rename(
+            {"pred_item_id": "base_pred_item_id"}
+        )
 
         enriched_recommendations = sorted_by_score_recommendations.join(
             sorted_by_score_base_recommendations, how="left", on=self.query_column
@@ -152,12 +152,7 @@ class Unexpectedness(Metric):
         )
 
     @staticmethod
-    def _get_metric_value_by_user(
-        ks: List[int], base_recs: Optional[List], recs: Optional[List]
-    ) -> List[float]:
+    def _get_metric_value_by_user(ks: List[int], base_recs: Optional[List], recs: Optional[List]) -> List[float]:
         if not base_recs or not recs:
             return [0.0 for _ in ks]
-
-        for k in ks:
-            res.append(1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k)
-        return res
+        return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]
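The rewritten per-user computation above reduces to a single comprehension over the cutoffs. A quick stand-alone check of the same logic (function name and sample lists are made up):

    def unexpectedness_at_k(recs, base_recs, ks):
        # share of the top-k recommendations that do not appear in the baseline top-k
        if not base_recs or not recs:
            return [0.0 for _ in ks]
        return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]

    print(unexpectedness_at_k([10, 20, 30], [20, 40, 50], ks=[2]))  # [0.5]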
replay/models/__init__.py
CHANGED
@@ -12,6 +12,7 @@ from .association_rules import AssociationRulesItemRec
 from .base_rec import Recommender
 from .cat_pop_rec import CatPopRec
 from .cluster import ClusterRec
+from .kl_ucb import KLUCB
 from .knn import ItemKNN
 from .pop_rec import PopRec
 from .query_pop_rec import QueryPopRec
@@ -19,7 +20,5 @@ from .random_rec import RandomRec
 from .slim import SLIM
 from .thompson_sampling import ThompsonSampling
 from .ucb import UCB
-# pylint: disable=cyclic-import
-from .kl_ucb import KLUCB
 from .wilson import Wilson
 from .word2vec import Word2VecRec
replay/models/als.py
CHANGED
@@ -2,9 +2,10 @@ from os.path import join
 from typing import Optional, Tuple
 
 from replay.data import Dataset
-from .base_rec import ItemVectorModel, Recommender
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_rec import ItemVectorModel, Recommender
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.ml.recommendation import ALS, ALSModel
@@ -13,7 +14,6 @@ if PYSPARK_AVAILABLE:
     from replay.utils.spark_utils import list_to_vector_udf
 
 
-# pylint: disable=too-many-instance-attributes
 class ALSWrap(Recommender, ItemVectorModel):
     """Wrapper for `Spark ALS
     <https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS>`_.
@@ -24,7 +24,6 @@ class ALSWrap(Recommender, ItemVectorModel):
         "rank": {"type": "loguniform_int", "args": [8, 256]},
     }
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         rank: int = 10,
@@ -98,7 +97,6 @@ class ALSWrap(Recommender, ItemVectorModel):
         self.model.itemFactors.unpersist()
         self.model.userFactors.unpersist()
 
-    # pylint: disable=too-many-arguments
     def _predict(
         self,
         dataset: Optional[Dataset],
@@ -107,10 +105,8 @@ class ALSWrap(Recommender, ItemVectorModel):
         items: SparkDataFrame,
         filter_seen_items: bool = True,
     ) -> SparkDataFrame:
-
         if (items.count() == self.fit_items.count()) and (
-            items.join(self.fit_items, on=self.item_column, how="inner").count()
-            == self.fit_items.count()
+            items.join(self.fit_items, on=self.item_column, how="inner").count() == self.fit_items.count()
         ):
             max_seen = 0
             if filter_seen_items and dataset is not None:
@@ -125,9 +121,7 @@ class ALSWrap(Recommender, ItemVectorModel):
 
         recs_als = self.model.recommendForUserSubset(queries, k + max_seen)
         return (
-            recs_als.withColumn(
-                "recommendations", sf.explode("recommendations")
-            )
+            recs_als.withColumn("recommendations", sf.explode("recommendations"))
             .withColumn(self.item_column, sf.col(f"recommendations.{self.item_column}"))
             .withColumn(
                 self.rating_column,
@@ -144,7 +138,7 @@ class ALSWrap(Recommender, ItemVectorModel):
     def _predict_pairs(
         self,
         pairs: SparkDataFrame,
-        dataset: Optional[Dataset] = None,
+        dataset: Optional[Dataset] = None,  # noqa: ARG002
     ) -> SparkDataFrame:
         return (
             self.model.transform(pairs)
@@ -153,15 +147,13 @@ class ALSWrap(Recommender, ItemVectorModel):
         )
 
     def _get_features(
-        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
+        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
     ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
         entity = "user" if self.query_column in ids.columns else "item"
         entity_col = self.query_column if self.query_column in ids.columns else self.item_column
 
         als_factors = getattr(self.model, f"{entity}Factors")
-        als_factors = als_factors.withColumnRenamed(
-            "id", entity_col
-        ).withColumnRenamed("features", f"{entity}_factors")
+        als_factors = als_factors.withColumnRenamed("id", entity_col).withColumnRenamed("features", f"{entity}_factors")
         return (
             als_factors.join(ids, how="right", on=entity_col),
             self.model.rank,