replay-rec 0.21.0rc0__py3-none-any.whl → 0.21.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. replay/__init__.py +1 -1
  2. replay/data/nn/parquet/parquet_module.py +1 -1
  3. replay/metrics/torch_metrics_builder.py +1 -1
  4. replay/models/nn/sequential/callbacks/validation_callback.py +14 -4
  5. replay/nn/lightning/callback/metrics_callback.py +18 -9
  6. replay/nn/lightning/callback/predictions_callback.py +2 -2
  7. replay/nn/loss/base.py +3 -3
  8. replay/nn/loss/login_ce.py +1 -1
  9. replay/nn/sequential/sasrec/model.py +1 -1
  10. replay/nn/sequential/twotower/reader.py +14 -5
  11. replay/nn/transform/template/sasrec.py +3 -3
  12. replay/nn/transform/template/twotower.py +1 -1
  13. {replay_rec-0.21.0rc0.dist-info → replay_rec-0.21.1.dist-info}/METADATA +17 -11
  14. {replay_rec-0.21.0rc0.dist-info → replay_rec-0.21.1.dist-info}/RECORD +17 -72
  15. replay/experimental/__init__.py +0 -0
  16. replay/experimental/metrics/__init__.py +0 -62
  17. replay/experimental/metrics/base_metric.py +0 -603
  18. replay/experimental/metrics/coverage.py +0 -97
  19. replay/experimental/metrics/experiment.py +0 -175
  20. replay/experimental/metrics/hitrate.py +0 -26
  21. replay/experimental/metrics/map.py +0 -30
  22. replay/experimental/metrics/mrr.py +0 -18
  23. replay/experimental/metrics/ncis_precision.py +0 -31
  24. replay/experimental/metrics/ndcg.py +0 -49
  25. replay/experimental/metrics/precision.py +0 -22
  26. replay/experimental/metrics/recall.py +0 -25
  27. replay/experimental/metrics/rocauc.py +0 -49
  28. replay/experimental/metrics/surprisal.py +0 -90
  29. replay/experimental/metrics/unexpectedness.py +0 -76
  30. replay/experimental/models/__init__.py +0 -50
  31. replay/experimental/models/admm_slim.py +0 -257
  32. replay/experimental/models/base_neighbour_rec.py +0 -200
  33. replay/experimental/models/base_rec.py +0 -1386
  34. replay/experimental/models/base_torch_rec.py +0 -234
  35. replay/experimental/models/cql.py +0 -454
  36. replay/experimental/models/ddpg.py +0 -932
  37. replay/experimental/models/dt4rec/__init__.py +0 -0
  38. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  39. replay/experimental/models/dt4rec/gpt1.py +0 -401
  40. replay/experimental/models/dt4rec/trainer.py +0 -127
  41. replay/experimental/models/dt4rec/utils.py +0 -264
  42. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  43. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  44. replay/experimental/models/hierarchical_recommender.py +0 -331
  45. replay/experimental/models/implicit_wrap.py +0 -131
  46. replay/experimental/models/lightfm_wrap.py +0 -303
  47. replay/experimental/models/mult_vae.py +0 -332
  48. replay/experimental/models/neural_ts.py +0 -986
  49. replay/experimental/models/neuromf.py +0 -406
  50. replay/experimental/models/scala_als.py +0 -293
  51. replay/experimental/models/u_lin_ucb.py +0 -115
  52. replay/experimental/nn/data/__init__.py +0 -1
  53. replay/experimental/nn/data/schema_builder.py +0 -102
  54. replay/experimental/preprocessing/__init__.py +0 -3
  55. replay/experimental/preprocessing/data_preparator.py +0 -839
  56. replay/experimental/preprocessing/padder.py +0 -229
  57. replay/experimental/preprocessing/sequence_generator.py +0 -208
  58. replay/experimental/scenarios/__init__.py +0 -1
  59. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  60. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  61. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  62. replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  63. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  64. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  65. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  66. replay/experimental/utils/__init__.py +0 -0
  67. replay/experimental/utils/logger.py +0 -24
  68. replay/experimental/utils/model_handler.py +0 -186
  69. replay/experimental/utils/session_handler.py +0 -44
  70. {replay_rec-0.21.0rc0.dist-info → replay_rec-0.21.1.dist-info}/WHEEL +0 -0
  71. {replay_rec-0.21.0rc0.dist-info → replay_rec-0.21.1.dist-info}/licenses/LICENSE +0 -0
  72. {replay_rec-0.21.0rc0.dist-info → replay_rec-0.21.1.dist-info}/licenses/NOTICE +0 -0
replay/__init__.py CHANGED
@@ -4,4 +4,4 @@
4
4
  # functionality removed in Python 3.12 is used in downstream packages (like lightfm)
5
5
  import setuptools as _
6
6
 
7
- __version__ = "0.21.0.preview"
7
+ __version__ = "0.21.1"
@@ -94,7 +94,7 @@ class ParquetModule(L.LightningDataModule):
94
94
  missing_splits = [split_name for split_name, split_path in self.datapaths.items() if split_path is None]
95
95
  if missing_splits:
96
96
  msg = (
97
- f"The following dataset paths aren't provided: {','.join(missing_splits)}."
97
+ f"The following dataset paths aren't provided: {','.join(missing_splits)}. "
98
98
  "Make sure to disable these stages in your Lightning Trainer configuration."
99
99
  )
100
100
  warnings.warn(msg, stacklevel=2)
@@ -400,7 +400,7 @@ def metrics_to_df(metrics: Mapping[str, float]) -> PandasDataFrame:
400
400
 
401
401
  metric_name_and_k = metrics_df["m"].str.split("@", expand=True)
402
402
  metrics_df["metric"] = metric_name_and_k[0]
403
- metrics_df["k"] = metric_name_and_k[1]
403
+ metrics_df["k"] = metric_name_and_k[1].astype(int)
404
404
 
405
405
  pivoted_metrics = metrics_df.pivot(index="metric", columns="k", values="v")
406
406
  pivoted_metrics.index.name = None
@@ -162,14 +162,24 @@ class ValidationMetricsCallback(lightning.Callback):
162
162
  @rank_zero_only
163
163
  def print_metrics() -> None:
164
164
  metrics = {}
165
+
165
166
  for name, value in trainer.logged_metrics.items():
166
167
  if "@" in name:
167
168
  metrics[name] = value.item()
168
169
 
169
- if metrics:
170
- metrics_df = metrics_to_df(metrics)
170
+ if not metrics:
171
+ return
171
172
 
172
- print(metrics_df) # noqa: T201
173
- print() # noqa: T201
173
+ if len(self._dataloaders_size) > 1:
174
+ for i in range(len(self._dataloaders_size)):
175
+ suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
176
+ cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
177
+ metrics_df = metrics_to_df(cur_dataloader_metrics)
178
+
179
+ print(suffix) # noqa: T201
180
+ print(metrics_df, "\n") # noqa: T201
181
+ else:
182
+ metrics_df = metrics_to_df(metrics)
183
+ print(metrics_df, "\n") # noqa: T201
174
184
 
175
185
  print_metrics()
@@ -2,7 +2,6 @@ from typing import Any, Optional
2
2
 
3
3
  import lightning
4
4
  import torch
5
- from lightning.pytorch.utilities.combined_loader import CombinedLoader
6
5
  from lightning.pytorch.utilities.rank_zero import rank_zero_only
7
6
 
8
7
  from replay.metrics.torch_metrics_builder import (
@@ -64,8 +63,8 @@ class ComputeMetricsCallback(lightning.Callback):
64
63
  self._train_column = train_column
65
64
 
66
65
  def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
67
- if isinstance(dataloaders, CombinedLoader):
68
- return [len(dataloader) for dataloader in dataloaders.flattened] # pragma: no cover
66
+ if isinstance(dataloaders, list):
67
+ return [len(dataloader) for dataloader in dataloaders]
69
68
  return [len(dataloaders)]
70
69
 
71
70
  def on_validation_epoch_start(
@@ -123,7 +122,7 @@ class ComputeMetricsCallback(lightning.Callback):
123
122
  batch: dict,
124
123
  batch_idx: int,
125
124
  dataloader_idx: int = 0,
126
- ) -> None: # pragma: no cover
125
+ ) -> None:
127
126
  self._batch_end(
128
127
  trainer,
129
128
  pl_module,
@@ -159,7 +158,7 @@ class ComputeMetricsCallback(lightning.Callback):
159
158
  def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
160
159
  self._epoch_end(trainer, pl_module)
161
160
 
162
- def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None: # pragma: no cover
161
+ def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
163
162
  self._epoch_end(trainer, pl_module)
164
163
 
165
164
  def _epoch_end(
@@ -170,14 +169,24 @@ class ComputeMetricsCallback(lightning.Callback):
170
169
  @rank_zero_only
171
170
  def print_metrics() -> None:
172
171
  metrics = {}
172
+
173
173
  for name, value in trainer.logged_metrics.items():
174
174
  if "@" in name:
175
175
  metrics[name] = value.item()
176
176
 
177
- if metrics:
178
- metrics_df = metrics_to_df(metrics)
177
+ if not metrics:
178
+ return
179
179
 
180
- print(metrics_df) # noqa: T201
181
- print() # noqa: T201
180
+ if len(self._dataloaders_size) > 1:
181
+ for i in range(len(self._dataloaders_size)):
182
+ suffix = trainer._results.DATALOADER_SUFFIX.format(i)[1:]
183
+ cur_dataloader_metrics = {k.split("/")[0]: v for k, v in metrics.items() if suffix in k}
184
+ metrics_df = metrics_to_df(cur_dataloader_metrics)
185
+
186
+ print(suffix) # noqa: T201
187
+ print(metrics_df, "\n") # noqa: T201
188
+ else:
189
+ metrics_df = metrics_to_df(metrics)
190
+ print(metrics_df, "\n") # noqa: T201
182
191
 
183
192
  print_metrics()
@@ -15,11 +15,11 @@ from replay.utils import (
15
15
  SparkDataFrame,
16
16
  )
17
17
 
18
- if PYSPARK_AVAILABLE: # pragma: no cover
18
+ if PYSPARK_AVAILABLE:
19
19
  import pyspark.sql.functions as sf
20
20
  from pyspark.sql import SparkSession
21
21
  from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
22
- else: # pragma: no cover
22
+ else:
23
23
  SparkSession = MissingImport
24
24
 
25
25
 
replay/nn/loss/base.py CHANGED
@@ -85,7 +85,7 @@ class SampledLossBase(torch.nn.Module):
85
85
  # [batch_size, num_negatives] -> [batch_size, 1, num_negatives]
86
86
  negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)
87
87
 
88
- if negative_labels.dim() == 3: # pragma: no cover
88
+ if negative_labels.dim() == 3:
89
89
  # [batch_size, seq_len, num_negatives] -> [batch_size, seq_len, 1, num_negatives]
90
90
  negative_labels = negative_labels.unsqueeze(-2)
91
91
  if num_positives != 1:
@@ -119,7 +119,7 @@ class SampledLossBase(torch.nn.Module):
119
119
  positive_labels = positive_labels[target_padding_mask].unsqueeze(-1)
120
120
  assert positive_labels.size() == (masked_batch_size, 1)
121
121
 
122
- if negative_labels.dim() != 1: # pragma: no cover
122
+ if negative_labels.dim() != 1:
123
123
  # [batch_size, seq_len, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
124
124
  negative_labels = negative_labels[target_padding_mask]
125
125
  assert negative_labels.size() == (masked_batch_size, num_negatives)
@@ -183,7 +183,7 @@ def mask_negative_logits(
183
183
  if negative_labels_ignore_index >= 0:
184
184
  negative_logits.masked_fill_(negative_labels == negative_labels_ignore_index, -1e9)
185
185
 
186
- if negative_labels.dim() > 1: # pragma: no cover
186
+ if negative_labels.dim() > 1:
187
187
  # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
188
188
  negative_labels = negative_labels.unsqueeze(-2)
189
189
 
@@ -74,7 +74,7 @@ class LogInCEBase(SampledLossBase):
74
74
  positive_labels = positive_labels[masked_target_padding_mask]
75
75
  assert positive_labels.size() == (masked_batch_size, num_positives)
76
76
 
77
- if negative_labels.dim() > 1: # pragma: no cover
77
+ if negative_labels.dim() > 1:
78
78
  # [batch_size, seq_len, num_negatives] -> [masked_batch_size, num_negatives]
79
79
  negative_labels = negative_labels[masked_target_padding_mask]
80
80
  assert negative_labels.size() == (masked_batch_size, num_negatives)
@@ -141,7 +141,7 @@ class SasRec(torch.nn.Module):
141
141
  feature_type=FeatureType.CATEGORICAL,
142
142
  embedding_dim=256,
143
143
  padding_value=NUM_UNIQUE_ITEMS,
144
- cardinality=NUM_UNIQUE_ITEMS+1,
144
+ cardinality=NUM_UNIQUE_ITEMS,
145
145
  feature_hint=FeatureHint.ITEM_ID,
146
146
  feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
147
147
  ),
@@ -22,7 +22,6 @@ class FeaturesReader:
22
22
  :param schema: the same tensor schema used in TwoTower model.
23
23
  :param metadata: A dictionary of feature names that
24
24
  associated with its shape and padding_value.\n
25
- Example: {"item_id" : {"shape": 100, "padding": 7657}}.\n
26
25
  For details, see the section :ref:`parquet-processing`.
27
26
  :param path: path to parquet with dataframe of item features.\n
28
27
  **Note:**\n
@@ -30,8 +29,8 @@ class FeaturesReader:
30
29
  2. Every feature for item "tower" in `schema` must contain ``feature_sources`` with the names
31
30
  of the source features to create correct inverse mapping.
32
31
  Also, for each such feature one of the requirements must be met: the ``schema`` for the feature must
33
- contain ``feature_sources`` with a source of type FeatureSource.ITEM_FEATURES
34
- or hint type FeatureHint.ITEM_ID.
32
+ contain ``feature_sources`` with a source of type ``FeatureSource.ITEM_FEATURES``
33
+ or hint type ``FeatureHint.ITEM_ID``.
35
34
 
36
35
  """
37
36
  item_feature_names = [
@@ -81,8 +80,18 @@ class FeaturesReader:
81
80
  self._features = {}
82
81
 
83
82
  for k in features.columns:
84
- dtype = torch.float32 if schema[k].is_num else torch.int64
85
- feature_tensor = torch.asarray(features[k], dtype=dtype)
83
+ dtype = np.float32 if schema[k].is_num else np.int64
84
+ if schema[k].is_list:
85
+ feature = np.asarray(
86
+ features[k].to_list(),
87
+ dtype=dtype,
88
+ )
89
+ else:
90
+ feature = features[k].to_numpy(dtype=dtype)
91
+ feature_tensor = torch.asarray(
92
+ feature,
93
+ dtype=torch.float32 if schema[k].is_num else torch.int64,
94
+ )
86
95
  self._features[k] = feature_tensor
87
96
 
88
97
  def __getitem__(self, key: str) -> torch.Tensor:
@@ -14,7 +14,7 @@ def make_default_sasrec_transforms(
14
14
 
15
15
  Generated pipeline expects input dataset to contain the following columns:
16
16
  1) Query ID column, specified by ``query_column``.
17
- 2) Item ID column, specified in the tensor schema.
17
+ 2) All features specified in the ``tensor_schema``.
18
18
 
19
19
  :param tensor_schema: TensorSchema used to infer feature columns.
20
20
  :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
@@ -32,12 +32,12 @@ def make_default_sasrec_transforms(
32
32
  ),
33
33
  UnsqueezeTransform("target_padding_mask", -1),
34
34
  UnsqueezeTransform("positive_labels", -1),
35
- GroupTransform({"feature_tensors": [item_column]}),
35
+ GroupTransform({"feature_tensors": tensor_schema.names}),
36
36
  ]
37
37
 
38
38
  val_transforms = [
39
39
  RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
40
- GroupTransform({"feature_tensors": [item_column]}),
40
+ GroupTransform({"feature_tensors": tensor_schema.names}),
41
41
  ]
42
42
  test_transforms = copy.deepcopy(val_transforms)
43
43
 
@@ -13,7 +13,7 @@ def make_default_twotower_transforms(
13
13
 
14
14
  Generated pipeline expects input dataset to contain the following columns:
15
15
  1) Query ID column, specified by ``query_column``.
16
- 2) Item ID column, specified in the tensor schema.
16
+ 2) All features specified in the ``tensor_schema``.
17
17
 
18
18
  :param tensor_schema: TensorSchema used to infer feature columns.
19
19
  :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: replay-rec
3
- Version: 0.21.0rc0
3
+ Version: 0.21.1
4
4
  Summary: RecSys Library
5
5
  License-Expression: Apache-2.0
6
6
  License-File: LICENSE
@@ -14,23 +14,29 @@ Classifier: Intended Audience :: Developers
14
14
  Classifier: Intended Audience :: Science/Research
15
15
  Classifier: Natural Language :: English
16
16
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
- Requires-Dist: d3rlpy (>=2.8.1,<2.9)
18
- Requires-Dist: implicit (>=0.7.2,<0.8)
19
- Requires-Dist: lightautoml (>=0.4.1,<0.5)
20
- Requires-Dist: lightning (>=2.0.2,<=2.4.0)
21
- Requires-Dist: numba (>=0.50,<1)
17
+ Provides-Extra: spark
18
+ Provides-Extra: torch
19
+ Provides-Extra: torch-cpu
20
+ Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
21
+ Requires-Dist: lightning ; extra == "torch"
22
+ Requires-Dist: lightning ; extra == "torch-cpu"
22
23
  Requires-Dist: numpy (>=1.20.0,<2)
23
24
  Requires-Dist: pandas (>=1.3.5,<2.4.0)
24
25
  Requires-Dist: polars (<2.0)
25
- Requires-Dist: psutil (<=7.0.0)
26
+ Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
27
+ Requires-Dist: psutil ; extra == "spark"
26
28
  Requires-Dist: pyarrow (<22.0)
27
- Requires-Dist: pyspark (>=3.0,<3.5)
28
- Requires-Dist: pytorch-optimizer (>=3.8.0,<4)
29
- Requires-Dist: sb-obp (>=0.5.10,<0.6)
29
+ Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
30
+ Requires-Dist: pyspark ; extra == "spark"
31
+ Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
32
+ Requires-Dist: pytorch-optimizer ; extra == "torch"
33
+ Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
30
34
  Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
31
35
  Requires-Dist: scipy (>=1.8.1,<2.0.0)
32
36
  Requires-Dist: setuptools
33
- Requires-Dist: torch (>=1.8,<3.0.0)
37
+ Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
38
+ Requires-Dist: torch ; extra == "torch"
39
+ Requires-Dist: torch ; extra == "torch-cpu"
34
40
  Requires-Dist: tqdm (>=4.67,<5)
35
41
  Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
36
42
  Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
@@ -1,4 +1,4 @@
1
- replay/__init__.py,sha256=60XVma5C1iIgad7iEMfhh4Mr8aPrS2yXXo1uA4Dn_zY,233
1
+ replay/__init__.py,sha256=2kRxqt2GF_2mTRxcddaKhR1p-tGp_fVjPLBFC2gI4os,225
2
2
  replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
3
3
  replay/data/dataset.py,sha256=yBl-yJVIokgN4prFY949tHe2UVJC_j5xdaulIoSPvQI,31252
4
4
  replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
@@ -31,7 +31,7 @@ replay/data/nn/parquet/iterator.py,sha256=X5KXtjdY_uSfMlP9IXBqMzSimBqlAZbYX_Y483
31
31
  replay/data/nn/parquet/metadata/__init__.py,sha256=UZX60ANtjo6zX0p43hU9q8fBldVJNCEmGzXjHqz0MJQ,341
32
32
  replay/data/nn/parquet/metadata/metadata.py,sha256=jJOL8mieXhX18FO9lgaP95MOtO1l7tY63ldxoOAvzwA,3459
33
33
  replay/data/nn/parquet/parquet_dataset.py,sha256=pKthRppp0MstfNwOk9wMrE6wFvDecCtbTKWIri4HGr0,8017
34
- replay/data/nn/parquet/parquet_module.py,sha256=g53lgb-bydDg5P27I4MODnnMcRi1qjpvAw3_QQ9UgxQ,8208
34
+ replay/data/nn/parquet/parquet_module.py,sha256=BSf_ev-XtFTsPV9R3y9YO2qa1JHU4Z1Wp7jsXL6GFjM,8209
35
35
  replay/data/nn/parquet/partitioned_iterable_dataset.py,sha256=BZEh2EiBKMZxi822-doyTbjDkZQQ62SxAp_NhZVZdmk,1938
36
36
  replay/data/nn/parquet/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  replay/data/nn/parquet/utils/compute_length.py,sha256=VWabulpRICy-_Z0ZBXpEmhAIlpXVwTwe9kX2L2XCdbE,2492
@@ -46,61 +46,6 @@ replay/data/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
46
46
  replay/data/utils/batching.py,sha256=jBNhRC5jqNe2pVVlmvFLvjTo86Ud0e_Lj2P0W2yNcKY,2268
47
47
  replay/data/utils/typing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
48
  replay/data/utils/typing/dtype.py,sha256=QJigLH7fv-xIb_s-R_70KTZxOgl2ZJkhEhf_txziRAY,1590
49
- replay/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
- replay/experimental/metrics/__init__.py,sha256=bdQogGbEDVAeH7Ejbb6vpw7bP6CYhftTu_DQuoFRuCA,2861
51
- replay/experimental/metrics/base_metric.py,sha256=0ro9VoSnPtPAximnlcgmQaMNg9zoUN2AHAH_2WgfZiQ,22663
52
- replay/experimental/metrics/coverage.py,sha256=UqYm-WtAlBFZ3kqv8PyLo4qqKiIXmR_CQFAl6H_YdqA,3150
53
- replay/experimental/metrics/experiment.py,sha256=pD2Dyyg4PM1HjbrNrhAspZJP3B-i2So205qBChRGwwc,7337
54
- replay/experimental/metrics/hitrate.py,sha256=TfWJrUyZXabdMr4tn8zqUPGDcYy2yphVCzXmLSHCxY0,675
55
- replay/experimental/metrics/map.py,sha256=S4dKiMpYR0_pu0bqioGMT0kIC1s2aojFP4rddBqMPtM,921
56
- replay/experimental/metrics/mrr.py,sha256=q6I1Cndlwr716mMuYtTMu0lN8Rrp9khxhb49OM2IpV8,530
57
- replay/experimental/metrics/ncis_precision.py,sha256=yrErOhBZvZdNpQPx_AXyktDJatqdWRIHNMyei0QDJtQ,1088
58
- replay/experimental/metrics/ndcg.py,sha256=q3KTsyZCrfvcpEjEnR_kWVB9ZaTFRxnoNRAr2WD0TrU,1538
59
- replay/experimental/metrics/precision.py,sha256=U9pD9yRGeT8uH32BTyQ-W5qsAnbFWu-pqy4XfkcXfCM,664
60
- replay/experimental/metrics/recall.py,sha256=5xRPGxfbVoDFEI5E6dVlZpT4RvnDlWzaktyoqh3a8mc,774
61
- replay/experimental/metrics/rocauc.py,sha256=yq4vW2_bXO8HCjREBZVrHMKeZ054LYvjJmLJTXWPfQA,1675
62
- replay/experimental/metrics/surprisal.py,sha256=CK4_zed2bSMDwC7ZBCS8d8RwGEqt8bh3w3fTpjKiK6Y,3052
63
- replay/experimental/metrics/unexpectedness.py,sha256=JQQXEYHtQM8nqp7X2He4E9ZYwbpdENaK8oQG7sUQT3s,2621
64
- replay/experimental/models/__init__.py,sha256=yeu0PAkqWNqNLDnUYpg0_vpkWT8tG8KmRMybodVFkZ4,1709
65
- replay/experimental/models/admm_slim.py,sha256=dDg2c_5Lk8acykirtsv38Jg1l6kgAoBhRvPHPv5Vfis,8654
66
- replay/experimental/models/base_neighbour_rec.py,sha256=Q2C4rle9FeVIncqgMuhLV6qZbPj2Bz8W_Ao8iQu31TU,7387
67
- replay/experimental/models/base_rec.py,sha256=AmN6-PgIaNzD-sMIndMuRA3TJ0WZBbowCjaSTTgiYrY,54150
68
- replay/experimental/models/base_torch_rec.py,sha256=mwbbsR-sQuQAFC1d8X2k0zP3iJeEP-X5nAaR3IV7Sqg,8105
69
- replay/experimental/models/cql.py,sha256=ItTukqhH3V-PItVPawET9zO9tG4D8R4xKzz3tqKMjSc,19619
70
- replay/experimental/models/ddpg.py,sha256=bzX4KvkuIecYA4bkFB1BnLKE3zqteujhpvsxAXEnKoM,32266
71
- replay/experimental/models/dt4rec/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
72
- replay/experimental/models/dt4rec/dt4rec.py,sha256=zcxn2MjrJg8eYqfGwfK80UjH2-uwNDg4PBbmQZz7Le0,5895
73
- replay/experimental/models/dt4rec/gpt1.py,sha256=T3buFtYyF6Fh6sW6f9dUZFcFEnQdljItbRa22CiKb0w,14044
74
- replay/experimental/models/dt4rec/trainer.py,sha256=YeaJ8mnoYZqnPwm1P9qOYb8GzgFC5At-JeSDcvG2V2o,3859
75
- replay/experimental/models/dt4rec/utils.py,sha256=UF--cukjFB3uwzqaVHdCS3ik2qTtw97tzbSFGPkDfE8,8153
76
- replay/experimental/models/extensions/spark_custom_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
- replay/experimental/models/extensions/spark_custom_models/als_extension.py,sha256=KrPr9M9LeasjAvGzcm6B76vx792FLr5ONfvO3zYxFH8,25844
78
- replay/experimental/models/hierarchical_recommender.py,sha256=BqnEFBppKawt8Xx5lzBWk6qnmdCrZ7c2hpKj3mi1GvU,11441
79
- replay/experimental/models/implicit_wrap.py,sha256=8F-f-CaStmlNHwphu-yu8o4Aft08NKDD_SqqH0zp1Uo,4655
80
- replay/experimental/models/lightfm_wrap.py,sha256=rA9T2vGjrbt_GJV1XccYYsrs9qtgDtqVJCWBHFYrm4k,11329
81
- replay/experimental/models/mult_vae.py,sha256=l-6g-2fIs80vxBl9VGY4FrJannAXrzsQOyGNuHU8tDs,11601
82
- replay/experimental/models/neural_ts.py,sha256=oCqStgGg5CpGFAv1dC-3ODmK9nI05evzJ3XKBDQhgAo,42535
83
- replay/experimental/models/neuromf.py,sha256=acC50kxYlctriNGqyOEkq57Iu4icUvZasyWFeRUJans,14386
84
- replay/experimental/models/scala_als.py,sha256=6aMl8hUFR2J_nI5U8Z_-5BxfeATiWnC8zdj1C0AFbm4,10751
85
- replay/experimental/models/u_lin_ucb.py,sha256=-gu6meOYeSwP6N8ILtwasWYj4Mbs6EJEFQXUHE8N_lY,3750
86
- replay/experimental/nn/data/__init__.py,sha256=5EAF-FNd7xhkUpTq_5MyVcPXBD81mJCwYrcbhdGOWjE,48
87
- replay/experimental/nn/data/schema_builder.py,sha256=nfE0-bVgYUwzyhNTTcXUWhfNBAZQLHWenM6-zEglqps,3301
88
- replay/experimental/preprocessing/__init__.py,sha256=uMyeyQ_GKqjLhVGwhrEk3NLhhzS0DKi5xGo3VF4WkiA,130
89
- replay/experimental/preprocessing/data_preparator.py,sha256=-yqWZT06iEYsY7rCSGRAgLcp6o7jvlsU431HspHQ2o4,35940
90
- replay/experimental/preprocessing/padder.py,sha256=uxE6WlmYNd9kbACMEidxG1L19G5Rk0gQbvpN_TosMZ4,9558
91
- replay/experimental/preprocessing/sequence_generator.py,sha256=vFtLkq9MuLGThPsa67103qlcMLYLfnAkR_HI1FXPwjw,9047
92
- replay/experimental/scenarios/__init__.py,sha256=gWFLCkLyOmOppvbRMK7C3UMlMpcbIgiGVolSH6LPgWA,91
93
- replay/experimental/scenarios/obp_wrapper/__init__.py,sha256=ZOJgpjRsmhXTpzGumk3AALKmstNBachtu_hOXUIPY5s,434
94
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py,sha256=swwcot05a8GzIVhEKpfmjG16CuciItVuddPaOjCKo9o,2543
95
- replay/experimental/scenarios/obp_wrapper/replay_offline.py,sha256=9ZP17steBiTh_KO37NnXWyN5LuPpABPhL_QG4JJHf7I,9622
96
- replay/experimental/scenarios/obp_wrapper/utils.py,sha256=Uv_fqyJDt69vIdrw-Y9orLLzyHG0ko8svza0Hs_a87Q,3233
97
- replay/experimental/scenarios/two_stages/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
98
- replay/experimental/scenarios/two_stages/reranker.py,sha256=NQhooA3OXLAh_PwydBNU2DGRRGPq2j2R0SSHtDM7hlg,4238
99
- replay/experimental/scenarios/two_stages/two_stages_scenario.py,sha256=u41ymdhx0MS1I08VDjJ2UhXpSqsfTA1x9Hbz1tOaWLY,29822
100
- replay/experimental/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
101
- replay/experimental/utils/logger.py,sha256=UwLowaeOG17sDEe32LiZel8MnjSTzeW7J3uLG1iwLuA,639
102
- replay/experimental/utils/model_handler.py,sha256=Rfj57E1R_XMEEigHNZa9a-rzEsyLWSDsgKfXoRzWWdg,6426
103
- replay/experimental/utils/session_handler.py,sha256=H0C-Q2pqrs_5aDvoAkRMZuS5qu07uhu6g5FEL3NJiic,1305
104
49
  replay/metrics/__init__.py,sha256=j0PGvUehaPEZMNo9SQwJsnvzrS4bam9eHrRMQFLnMjY,2813
105
50
  replay/metrics/base_metric.py,sha256=ejtwFHktN4J8Fi1HIM3w0zlMAd8nO7-XpFi2D1iHXUQ,16010
106
51
  replay/metrics/categorical_diversity.py,sha256=3tp8n457Ob4gjM-UTB5N19u9WAF7fLDkWKk-Mth-Vzc,10769
@@ -117,7 +62,7 @@ replay/metrics/precision.py,sha256=DRlsgY_b4bJCOSZjCA58N41REMiDt-dbagRSXxfXyvY,2
117
62
  replay/metrics/recall.py,sha256=fzpASDiH88zcpXJZTbStQ3nuzzSdhd9k1wjF27rM4wc,2447
118
63
  replay/metrics/rocauc.py,sha256=1vaVEK7DQTL8BX-i7A64hTFWyO38aNycscPGrdWKwbA,3282
119
64
  replay/metrics/surprisal.py,sha256=HkmYrOuw3jydxFrkidjdcpAcKz2DeOnMsKqwB2g9pwY,7526
120
- replay/metrics/torch_metrics_builder.py,sha256=mnHrmRTOKZ_edrTrTKs7IPzKt5DkQYRd2B_8b3bB9yU,14071
65
+ replay/metrics/torch_metrics_builder.py,sha256=cf0cRGQnBtR9OUUFOzLvOk3pX9rX2613nw5L28J4DDw,14083
121
66
  replay/metrics/unexpectedness.py,sha256=LSi-z50l3_yrvLnmToHQzm6Ygf2QpNt_zhk6jdg7QUo,6882
122
67
  replay/models/__init__.py,sha256=kECYluQZ83zRUWaHVvnt7Tg3BerHrJy9v8XfRxsqyYY,1123
123
68
  replay/models/als.py,sha256=1MFAbcx64tv0MX1wE9CM1NxKD3F3ZDhZUrmt6dvHu74,6220
@@ -170,7 +115,7 @@ replay/models/nn/sequential/bert4rec/lightning.py,sha256=vxAf1H1VfLqgZhOz9fxEMmw
170
115
  replay/models/nn/sequential/bert4rec/model.py,sha256=C1AKcQ8KF0XMXERwrFneW9kg7hzc-9FIqhCc-t91F7o,17469
171
116
  replay/models/nn/sequential/callbacks/__init__.py,sha256=Q7mSZ_RB6iyD7QZaBL_NJ0uh8cRfgxq7gtPHbkSyhoo,282
172
117
  replay/models/nn/sequential/callbacks/prediction_callbacks.py,sha256=UtEzO9_f5Jwku9dbz7twr4o2_cV3L-viC4lQuce5l1c,10808
173
- replay/models/nn/sequential/callbacks/validation_callback.py,sha256=ydcNkUhaFD78ogqZWySzzKg4BaPyEkaRqmLiD4qFDzM,6583
118
+ replay/models/nn/sequential/callbacks/validation_callback.py,sha256=VDIa8c6Wpekz_AvzdtETajvdkqi2aiJBokE9JEOY3rI,7071
174
119
  replay/models/nn/sequential/compiled/__init__.py,sha256=eSVcCaUH5cDJQRbC7K99X7uMNR-Z-KR4TmYOGKWWJCI,531
175
120
  replay/models/nn/sequential/compiled/base_compiled_model.py,sha256=f4AuTyx5tufQOtOWUSEgj1cWvMZzSL7YN2Z-PtURgTY,10478
176
121
  replay/models/nn/sequential/compiled/bert4rec_compiled.py,sha256=woGI3qk4J2Rb5FyaDwpSCuG-AMfyH34F6Bt5pV-wqk0,6798
@@ -201,8 +146,8 @@ replay/nn/ffn.py,sha256=ivOFu14289URepyEFxYov_XNYMUrINjU-2rEqoXxbnU,4618
201
146
  replay/nn/head.py,sha256=csjwQrcA7M7FebgSL1tKDbjfaoni52CymQR0Zt8RhWg,2084
202
147
  replay/nn/lightning/__init__.py,sha256=jHiwtYuboGUY4Of18zrkvdWD0xXJ_zuo83-XgiqxSfY,36
203
148
  replay/nn/lightning/callback/__init__.py,sha256=ImNEJeIK-wJnqdkZgP8tWTDQHaS9xYqzTEf3FEM0XAw,253
204
- replay/nn/lightning/callback/metrics_callback.py,sha256=dIu1wDtqjXH6ogFGsh2L-dpkgz7OKjtTrVbBLrI4pjg,6986
205
- replay/nn/lightning/callback/predictions_callback.py,sha256=e9PeXNyyGz-m46FEaafgCToPEVC9T5Cb8Q4sFArnpLY,11347
149
+ replay/nn/lightning/callback/metrics_callback.py,sha256=AzDsxvNHfjrJdhcgZsMtnKju1TYO86Pc2Knv_tS7HBA,7323
150
+ replay/nn/lightning/callback/predictions_callback.py,sha256=4iS3QwRRFolAwizxDp2guBDJNRvitgOOHhYmMPL8ub0,11307
206
151
  replay/nn/lightning/module.py,sha256=jFvevwiriY9alZMBw6KAiRMsJv-dJ8fEVrenVRiuWeI,5246
207
152
  replay/nn/lightning/optimizer.py,sha256=1tXhz9RIBHLpEQtZ1PUzCAc4mn6Q_E38zR0nf5km6U8,1778
208
153
  replay/nn/lightning/postprocessor/__init__.py,sha256=LhUeOWDD5vRBDXF2tQEjvPKH1rNIlrf5KPbcV66AdtQ,77
@@ -210,10 +155,10 @@ replay/nn/lightning/postprocessor/_base.py,sha256=X0LtYItmxlXt4Sxk3cOdyIK3FG5dij
210
155
  replay/nn/lightning/postprocessor/seen_items.py,sha256=h-sfD3vmNCdS7lYvqCfqw9oPqutmaSIuZ0CIidG0Y30,2922
211
156
  replay/nn/lightning/scheduler.py,sha256=CUuynPTFrKBrkpmbWR-xpfAkHZ0Vfz_THUDo3uoZi8k,2714
212
157
  replay/nn/loss/__init__.py,sha256=YXAXQIN0coj8MxeK5isTGXgvMxhH5pUO6j1D3d7jl3A,471
213
- replay/nn/loss/base.py,sha256=oD1vATWoQDi45zG9EPjg3hgDrfpr4ue_rQFfArn1dFs,8871
158
+ replay/nn/loss/base.py,sha256=XM2ASulAW8Kyg2Vw43I8Tqv1d8cij9NcNirP_RTk4b8,8811
214
159
  replay/nn/loss/bce.py,sha256=cPlxdJTBZ0b22K6V9ve4qo7xkp99CjEsnl3_vVGphqs,8373
215
160
  replay/nn/loss/ce.py,sha256=jOmhLtKD_E0jX8tUfXpsmaaQVHKKiwXW9USB_GyN3ZU,13218
216
- replay/nn/loss/login_ce.py,sha256=NER_Hbs_H3IXn_bkgwG25VQNQ6ZjjDcxq-aMI7pC2eM,16498
161
+ replay/nn/loss/login_ce.py,sha256=ri4KvHQXOVMB5o_vqGY2u8ayatYH9MLZwsXwp6cpDhI,16478
217
162
  replay/nn/loss/logout_ce.py,sha256=KhcYyCnUzLZR1sFpxM6_QliLoxmC6MJoLkPOgf_ZYzU,10306
218
163
  replay/nn/mask.py,sha256=Jbx7sulGZYfasNaD9CZzJma0cEVaDlxdpzs295507II,3329
219
164
  replay/nn/normalization.py,sha256=Z86t5WCr4KfVR9qCCe-EIAwwomnIIxb11PP88WHA1JI,187
@@ -222,11 +167,11 @@ replay/nn/sequential/__init__.py,sha256=jet_ueMz5Bm087JDph7ln87NID7DbCb0WENj-tjo
222
167
  replay/nn/sequential/sasrec/__init__.py,sha256=8crj-JL8xeP-cCOCnxCSVF_-R6feKhj0YRHOcaMsqrU,213
223
168
  replay/nn/sequential/sasrec/agg.py,sha256=e-IkIO-MMbei2UGxTUopWvloguJoVaZiN31sXkdUVag,2004
224
169
  replay/nn/sequential/sasrec/diff_transformer.py,sha256=4ehM5EMizajmWBAzmcj3CYSFl21V1R2b7RDRJlx3O4Q,4790
225
- replay/nn/sequential/sasrec/model.py,sha256=sQ2FvfDyZ3G6PjbNME--fMboqUt66z9J8t8YYlJ9J6Q,14803
170
+ replay/nn/sequential/sasrec/model.py,sha256=Db4IcI4EHzQoO7Vij_ItGvvs8aOJ6ANyHuXp_9v84zs,14801
226
171
  replay/nn/sequential/sasrec/transformer.py,sha256=sJf__IPnhbJWDPuFTPSbBGSSntznQtS-hJtJo3iFBkw,4037
227
172
  replay/nn/sequential/twotower/__init__.py,sha256=-rEASPqKCbS55MTTgeDZ5irfWfM9or1vNTHZnJN2AcU,124
228
173
  replay/nn/sequential/twotower/model.py,sha256=VxUUjldHndCkDjrXGqmxGnTi5fh8vmnr7XNBpYjsqW8,28659
229
- replay/nn/sequential/twotower/reader.py,sha256=j4mlKx5Lf3hFnSgaxMLkuqWLZd3dkLchDI4JEuZHLGY,3674
174
+ replay/nn/sequential/twotower/reader.py,sha256=8z-R0oZDbOaw6eFL3ffyt7yuc3q7qoKFhFrIdiwwJ10,3938
230
175
  replay/nn/transform/__init__.py,sha256=9PeaDHmftb0s1gEEgJRNWw6Bl2wfE_-lImatipaHUQ0,705
231
176
  replay/nn/transform/copy.py,sha256=ZfNXbMJYTwXDMJ5T8ib9Dh5XOGLjj7gGB4NbBExFZiM,1302
232
177
  replay/nn/transform/grouping.py,sha256=XOJoVBk234DI6x05Kqr7KOjLetDaLp2NMAJWHecQcsI,1384
@@ -236,8 +181,8 @@ replay/nn/transform/rename.py,sha256=_uD2e1UmtBRyOTVpHUnZ5xhePmClaGQsc0g7Es-rupE
236
181
  replay/nn/transform/reshape.py,sha256=sgswIogWHUwOVp02k13Qopn84LofqLoA4M7U1GAfmio,1359
237
182
  replay/nn/transform/sequence_roll.py,sha256=7jf42SgWHU1L7SirqQWXx0h9a6VQQ29kehE4LmdUt9o,1531
238
183
  replay/nn/transform/template/__init__.py,sha256=lYzAekZUXwncGR66Nq8YypplGOtL00GFfm0PalGiY5g,106
239
- replay/nn/transform/template/sasrec.py,sha256=FoOhroe-S0JPaxIQ3Ba-3_gyslgj47RoLL2geOxNAO4,1906
240
- replay/nn/transform/template/twotower.py,sha256=BIlbqTfKEMcyx2Ksr4qzAD0h0mdhiTLa1xcmZ2e8Ksc,896
184
+ replay/nn/transform/template/sasrec.py,sha256=0bhb8EhyTafiM9Eh1GfPEFS_syyUCPY3JOzTBr17I-o,1919
185
+ replay/nn/transform/template/twotower.py,sha256=qutfN3iUwGHn1BG2yLEc7hHQoxEwQiFdvXIoe-ej4tM,897
241
186
  replay/nn/transform/token_mask.py,sha256=WcalZkY2UCoNiq2mBtu8fqYFOUfqCh21XyDMgvIpeB4,2529
242
187
  replay/nn/transform/trim.py,sha256=mPn6LPxu3c3yE14heMSRsDEU4h94tkFiRr62mOa3lKg,1608
243
188
  replay/nn/utils.py,sha256=GumtN-QRP9ljXYti3YvuNk13e0Q92xvkYuCJBhaViCI,801
@@ -271,8 +216,8 @@ replay/utils/session_handler.py,sha256=fQo2wseow8yuzKnEXT-aYAXcQIgRbTTXp0v7g1VVi
271
216
  replay/utils/spark_utils.py,sha256=GbRp-MuUoO3Pc4chFvlmo9FskSlRLeNlC3Go5pEJ6Ok,27411
272
217
  replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
273
218
  replay/utils/types.py,sha256=rD9q9CqEXgF4yy512Hv2nXclvwcnfodOnhBZ1HSUI4c,1260
274
- replay_rec-0.21.0rc0.dist-info/METADATA,sha256=AaAA9pE-AnldOt_PSPiOYnzGaqkGWnuoPnPPiSy5AcI,13166
275
- replay_rec-0.21.0rc0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
276
- replay_rec-0.21.0rc0.dist-info/licenses/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
277
- replay_rec-0.21.0rc0.dist-info/licenses/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
278
- replay_rec-0.21.0rc0.dist-info/RECORD,,
219
+ replay_rec-0.21.1.dist-info/METADATA,sha256=atVJNoBxihnIh3r6q0pORPlykCDaDD5t-HPKJ6ufNIw,13573
220
+ replay_rec-0.21.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
221
+ replay_rec-0.21.1.dist-info/licenses/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
222
+ replay_rec-0.21.1.dist-info/licenses/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
223
+ replay_rec-0.21.1.dist-info/RECORD,,
File without changes
@@ -1,62 +0,0 @@
1
- """
2
- Most metrics require dataframe with recommendations
3
- and dataframe with ground truth values —
4
- which objects each user interacted with.
5
-
6
- - recommendations (Union[pandas.DataFrame, spark.DataFrame]):
7
- predictions of a recommender system,
8
- DataFrame with columns ``[user_id, item_id, relevance]``
9
- - ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
10
- test data, DataFrame with columns
11
- ``[user_id, item_id, timestamp, relevance]``
12
-
13
- Metric is calculated for all users, presented in ``ground_truth``
14
- for accurate metric calculation in case when the recommender system generated
15
- recommendation not for all users. It is assumed, that all users,
16
- we want to calculate metric for, have positive interactions.
17
-
18
- But if we have users, who observed the recommendations, but have not responded,
19
- those users will be ignored and metric will be overestimated.
20
- For such case we propose additional optional parameter ``ground_truth_users``,
21
- the dataframe with all users, which should be considered during the metric calculation.
22
-
23
- - ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
24
- full list of users to calculate metric for, DataFrame with ``user_id`` column
25
-
26
- Every metric is calculated using top ``K`` items for each user.
27
- It is also possible to calculate metrics
28
- using multiple values for ``K`` simultaneously.
29
- In this case the result will be a dictionary and not a number.
30
-
31
- Make sure your recommendations do not contain user-item duplicates
32
- as duplicates could lead to the wrong calculation results.
33
-
34
- - k (Union[Iterable[int], int]):
35
- a single number or a list, specifying the
36
- truncation length for recommendation list for each user
37
-
38
- By default, metrics are averaged by users,
39
- but you can alternatively use method ``metric.median``.
40
- Also, you can get the lower bound
41
- of ``conf_interval`` for a given ``alpha``.
42
-
43
- Diversity metrics require extra parameters on initialization stage,
44
- but do not use ``ground_truth`` parameter.
45
-
46
- For each metric, a formula for its calculation is given, because this is
47
- important for the correct comparison of algorithms, as mentioned in our
48
- `article <https://arxiv.org/abs/2206.12858>`_.
49
- """
50
-
51
- from replay.experimental.metrics.base_metric import Metric, NCISMetric
52
- from replay.experimental.metrics.coverage import Coverage
53
- from replay.experimental.metrics.hitrate import HitRate
54
- from replay.experimental.metrics.map import MAP
55
- from replay.experimental.metrics.mrr import MRR
56
- from replay.experimental.metrics.ncis_precision import NCISPrecision
57
- from replay.experimental.metrics.ndcg import NDCG
58
- from replay.experimental.metrics.precision import Precision
59
- from replay.experimental.metrics.recall import Recall
60
- from replay.experimental.metrics.rocauc import RocAuc
61
- from replay.experimental.metrics.surprisal import Surprisal
62
- from replay.experimental.metrics.unexpectedness import Unexpectedness