replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
- /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
- /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
- /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
- /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
- /replay/{experimental → data}/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
from replay.utils import (
|
|
8
|
+
PYSPARK_AVAILABLE,
|
|
9
|
+
DataFrameLike,
|
|
10
|
+
PandasDataFrame,
|
|
11
|
+
PolarsDataFrame,
|
|
12
|
+
SparkDataFrame,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from .base_splitter import Splitter
|
|
16
|
+
|
|
17
|
+
if PYSPARK_AVAILABLE:
|
|
18
|
+
import pyspark.sql.functions as sf
|
|
19
|
+
from pyspark.sql import Window
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RandomNextNSplitter(Splitter):
    """
    Split interactions by a random position in the user sequence.

    For each user, a random cut index is sampled and the target part consists of
    the next ``N`` interactions starting from this cut; the train part contains
    all interactions before the cut. Interactions after the target window are
    discarded.

    Note: by changing the ``seed`` attribute on an existing splitter instance,
    you can obtain different splits without recreating the object. This is useful
    when you need to generate multiple randomized splits of the same dataset.

    >>> from datetime import datetime
    >>> import pandas as pd
    >>> columns = ["query_id", "item_id", "timestamp"]
    >>> data = [
    ...     (1, 1, "01-01-2020"),
    ...     (1, 2, "02-01-2020"),
    ...     (1, 3, "03-01-2020"),
    ...     (2, 1, "06-01-2020"),
    ...     (2, 2, "07-01-2020"),
    ...     (2, 3, "08-01-2020"),
    ... ]
    >>> dataset = pd.DataFrame(data, columns=columns)
    >>> dataset["timestamp"] = pd.to_datetime(dataset["timestamp"], format="%d-%m-%Y")
    >>> splitter = RandomNextNSplitter(
    ...     N=2,
    ...     divide_column="query_id",
    ...     seed=42,
    ...     query_column="query_id",
    ...     item_column="item_id",
    ... )
    >>> train, test = splitter.split(dataset)
    """

    # Constructor argument names used by the base class for (de)serialization
    # of the splitter configuration.
    _init_arg_names = [
        "N",
        "divide_column",
        "seed",
        "drop_cold_users",
        "drop_cold_items",
        "query_column",
        "item_column",
        "timestamp_column",
        "session_id_column",
        "session_id_processing_strategy",
    ]

    def __init__(
        self,
        N: Optional[int] = 1,  # noqa: N803
        divide_column: str = "query_id",
        seed: Optional[int] = None,
        query_column: str = "query_id",
        drop_cold_users: bool = False,
        drop_cold_items: bool = False,
        item_column: str = "item_id",
        timestamp_column: str = "timestamp",
        session_id_column: Optional[str] = None,
        session_id_processing_strategy: str = "test",
    ):
        """
        :param N: Optional window size. If None, the test set contains all interactions
            from the cut to the end; otherwise the next ``N`` interactions. Must be >= 1.
            Default: 1.
        :param divide_column: Name of the column used to group interactions
            for random cut sampling, default: ``query_id``.
        :param seed: Random seed used to sample cut indices, default: ``None``.
        :param query_column: Name of query interaction column.
        :param drop_cold_users: Drop users from test DataFrame which are not in
            the train DataFrame, default: ``False``.
        :param drop_cold_items: Drop items from test DataFrame which are not in
            the train DataFrame, default: ``False``.
        :param item_column: Name of item interaction column.
            If ``drop_cold_items`` is ``False``, then you can omit this
            parameter. Default: ``item_id``.
        :param timestamp_column: Name of time column, default: ``timestamp``.
        :param session_id_column: Name of session id column whose values cannot
            be split between train/test, default: ``None``.
        :param session_id_processing_strategy: Strategy to process a session if
            it crosses the boundary: ``train`` or ``test``. ``train`` means the
            whole session goes to train, ``test`` — the whole session goes to
            test. Default: ``test``.

        :raises ValueError: if ``N`` is not None and less than 1.
        """

        super().__init__(
            drop_cold_users=drop_cold_users,
            drop_cold_items=drop_cold_items,
            query_column=query_column,
            item_column=item_column,
            timestamp_column=timestamp_column,
            session_id_column=session_id_column,
            session_id_processing_strategy=session_id_processing_strategy,
        )
        self.N = N
        # Validate after assignment: N=None is allowed (unbounded window).
        if self.N is not None and self.N < 1:
            msg = "N must be >= 1"
            raise ValueError(msg)
        self.divide_column = divide_column
        self.seed = seed

    def _sample_cuts(self, counts: np.ndarray) -> np.ndarray:
        # One cut index per group, each uniform in [0, count) — `randint`
        # broadcasts the per-group upper bounds. A fresh RandomState is created
        # per call, so re-splitting with the same seed is reproducible; with
        # seed=None each call produces a different split.
        rng = np.random.RandomState(self.seed)
        return rng.randint(0, counts)

    def _partial_split_pandas(
        self,
        interactions: PandasDataFrame,
    ) -> tuple[PandasDataFrame, PandasDataFrame]:
        """Split a pandas DataFrame into (train, test) by per-group random cuts."""
        # Sort within groups by time, then assign a 0-based position per group.
        df = interactions.sort_values([self.divide_column, self.timestamp_column])
        df["_event_rank"] = df.groupby(self.divide_column, sort=False).cumcount()

        # `sort=False` keeps group order identical between `counts` and the
        # sampled cuts, so the Series index aligns cuts with the right groups.
        counts = df.groupby(self.divide_column, sort=False).size()
        cuts = pd.Series(self._sample_cuts(counts.values), index=counts.index)
        df["_cut_index"] = df[self.divide_column].map(cuts)

        # Drop everything after the N-interaction target window.
        if self.N is not None:
            df = df[df["_event_rank"] < df["_cut_index"] + self.N]

        # Events at or after the cut go to test; earlier events go to train.
        df["is_test"] = df["_event_rank"] >= df["_cut_index"]
        if self.session_id_column:
            # Keep whole sessions on one side of the boundary (base-class logic).
            df = self._recalculate_with_session_id_column(df)

        # Restrict to the caller's original columns, dropping the temp ones.
        train = df[~df["is_test"]][interactions.columns]
        test = df[df["is_test"]][interactions.columns]

        return train, test

    def _partial_split_polars(
        self,
        interactions: PolarsDataFrame,
    ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
        """Split a polars DataFrame into (train, test) by per-group random cuts."""
        # cum_count() over the group yields 1-based positions; subtract 1 for a
        # 0-based rank matching the pandas/Spark paths.
        # NOTE(review): assumes cum_count counts from 1 in the pinned polars
        # version — confirm against the project's polars constraint.
        df = interactions.sort([self.divide_column, self.timestamp_column]).with_columns(
            (pl.col(self.divide_column).cum_count().over(self.divide_column) - 1).alias("_event_rank")
        )

        # Sample one cut per group, then join the cuts back onto the rows.
        counts = df.group_by(self.divide_column).len()
        r_values = self._sample_cuts(counts["len"].to_numpy())
        cuts_df = pl.DataFrame(
            {
                self.divide_column: counts[self.divide_column],
                "_cut_index": r_values,
            }
        )
        df = df.join(cuts_df, on=self.divide_column, how="left")

        # Drop everything after the N-interaction target window.
        if self.N is not None:
            df = df.filter(pl.col("_event_rank") < pl.col("_cut_index") + self.N)

        # Events at or after the cut go to test; earlier events go to train.
        df = df.with_columns((pl.col("_event_rank") >= pl.col("_cut_index")).alias("is_test"))
        if self.session_id_column:
            # Keep whole sessions on one side of the boundary (base-class logic).
            df = self._recalculate_with_session_id_column(df)

        # Restrict to the caller's original columns, dropping the temp ones.
        train = df.filter(~pl.col("is_test")).select(interactions.columns)
        test = df.filter(pl.col("is_test")).select(interactions.columns)

        return train, test

    def _partial_split_spark(
        self,
        interactions: SparkDataFrame,
    ) -> tuple[SparkDataFrame, SparkDataFrame]:
        """Split a Spark DataFrame into (train, test) by per-group cuts.

        Unlike the pandas/polars paths, the cut index here is derived
        deterministically from a hash of the group key and the seed
        (``pmod(xxhash64(key, seed), count)``), which keeps the split stable
        across Spark re-computations of the lineage.
        """
        # row_number() is 1-based; subtract 1 for a 0-based rank matching the
        # other backends.
        w = Window.partitionBy(self.divide_column).orderBy(self.timestamp_column)
        df = interactions.withColumn("_event_rank", sf.row_number().over(w) - sf.lit(1))

        counts = df.groupBy(self.divide_column).agg(sf.count(sf.lit(1)).alias("_count"))
        # NOTE(review): seed=None falls back to lit(0), so the Spark split is
        # deterministic even without a seed, whereas the pandas/polars paths
        # are nondeterministic for seed=None — confirm this divergence is
        # intended.
        seed_lit = sf.lit(self.seed) if self.seed is not None else sf.lit(0)
        cuts = counts.select(
            self.divide_column,
            # pmod keeps the result non-negative, i.e. a valid index in
            # [0, count) — the same range as the numpy randint in _sample_cuts.
            sf.pmod(
                sf.xxhash64(sf.col(self.divide_column), seed_lit).cast("long"),
                sf.col("_count").cast("long"),
            )
            .cast("long")
            .alias("_cut_index"),
        )

        df = df.join(cuts, on=self.divide_column, how="left")

        # Drop everything after the N-interaction target window.
        if self.N is not None:
            df = df.filter(sf.col("_event_rank") < sf.col("_cut_index") + sf.lit(self.N))

        # Events at or after the cut go to test; earlier events go to train.
        df = df.withColumn("is_test", sf.col("_event_rank") >= sf.col("_cut_index"))
        if self.session_id_column:
            # Keep whole sessions on one side of the boundary (base-class logic).
            df = self._recalculate_with_session_id_column(df)

        # Restrict to the caller's original columns, dropping the temp ones.
        train = df.filter(~sf.col("is_test")).select(interactions.columns)
        test = df.filter(sf.col("is_test")).select(interactions.columns)

        return train, test

    def _partial_split(self, interactions: DataFrameLike) -> tuple[DataFrameLike, DataFrameLike]:
        # Dispatch to the backend-specific implementation by DataFrame type.
        if isinstance(interactions, PandasDataFrame):
            return self._partial_split_pandas(interactions)
        if isinstance(interactions, PolarsDataFrame):
            return self._partial_split_polars(interactions)
        if isinstance(interactions, SparkDataFrame):
            return self._partial_split_spark(interactions)
        msg = f"{self} is not implemented for {type(interactions)}"
        raise NotImplementedError(msg)

    def _core_split(self, interactions: DataFrameLike) -> tuple[DataFrameLike, DataFrameLike]:
        # Base-class hook: cold-user/item filtering is applied by Splitter.split
        # around this method.
        return self._partial_split(interactions)
|
replay/utils/common.py
CHANGED
|
@@ -2,9 +2,10 @@ import functools
|
|
|
2
2
|
import inspect
|
|
3
3
|
import json
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any, Callable, Union
|
|
5
|
+
from typing import Any, Callable, TypeVar, Union
|
|
6
6
|
|
|
7
7
|
from polars import from_pandas as pl_from_pandas
|
|
8
|
+
from typing_extensions import ParamSpec
|
|
8
9
|
|
|
9
10
|
from replay.data.dataset import Dataset
|
|
10
11
|
from replay.preprocessing import (
|
|
@@ -16,6 +17,7 @@ from replay.splitters import (
|
|
|
16
17
|
KFolds,
|
|
17
18
|
LastNSplitter,
|
|
18
19
|
NewUsersSplitter,
|
|
20
|
+
RandomNextNSplitter,
|
|
19
21
|
RandomSplitter,
|
|
20
22
|
RatioSplitter,
|
|
21
23
|
TimeSplitter,
|
|
@@ -37,6 +39,7 @@ SavableObject = Union[
|
|
|
37
39
|
KFolds,
|
|
38
40
|
LastNSplitter,
|
|
39
41
|
NewUsersSplitter,
|
|
42
|
+
RandomNextNSplitter,
|
|
40
43
|
RandomSplitter,
|
|
41
44
|
RatioSplitter,
|
|
42
45
|
TimeSplitter,
|
|
@@ -56,6 +59,9 @@ if TORCH_AVAILABLE:
|
|
|
56
59
|
PolarsSequentialDataset,
|
|
57
60
|
]
|
|
58
61
|
|
|
62
|
+
P = ParamSpec("P")
|
|
63
|
+
R = TypeVar("R")
|
|
64
|
+
|
|
59
65
|
|
|
60
66
|
def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
|
|
61
67
|
"""
|
|
@@ -87,10 +93,10 @@ def _check_if_dataframe(var: Any):
|
|
|
87
93
|
raise ValueError(msg)
|
|
88
94
|
|
|
89
95
|
|
|
90
|
-
def check_if_dataframe(*args_to_check: str) -> Callable[
|
|
91
|
-
def decorator_func(func: Callable[
|
|
96
|
+
def check_if_dataframe(*args_to_check: str) -> Callable[P, R]:
|
|
97
|
+
def decorator_func(func: Callable[P, R]) -> Callable[P, R]:
|
|
92
98
|
@functools.wraps(func)
|
|
93
|
-
def wrap_func(*args:
|
|
99
|
+
def wrap_func(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
94
100
|
extended_kwargs = {}
|
|
95
101
|
extended_kwargs.update(kwargs)
|
|
96
102
|
extended_kwargs.update(dict(zip(inspect.signature(func).parameters.keys(), args)))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: replay-rec
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.21.0
|
|
4
4
|
Summary: RecSys Library
|
|
5
5
|
License-Expression: Apache-2.0
|
|
6
6
|
License-File: LICENSE
|
|
@@ -14,23 +14,29 @@ Classifier: Intended Audience :: Developers
|
|
|
14
14
|
Classifier: Intended Audience :: Science/Research
|
|
15
15
|
Classifier: Natural Language :: English
|
|
16
16
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
Requires-Dist: lightning (
|
|
21
|
-
Requires-Dist:
|
|
17
|
+
Provides-Extra: spark
|
|
18
|
+
Provides-Extra: torch
|
|
19
|
+
Provides-Extra: torch-cpu
|
|
20
|
+
Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
|
|
21
|
+
Requires-Dist: lightning ; extra == "torch"
|
|
22
|
+
Requires-Dist: lightning ; extra == "torch-cpu"
|
|
22
23
|
Requires-Dist: numpy (>=1.20.0,<2)
|
|
23
24
|
Requires-Dist: pandas (>=1.3.5,<2.4.0)
|
|
24
25
|
Requires-Dist: polars (<2.0)
|
|
25
|
-
Requires-Dist: psutil (<=7.0.0)
|
|
26
|
+
Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
|
|
27
|
+
Requires-Dist: psutil ; extra == "spark"
|
|
26
28
|
Requires-Dist: pyarrow (<22.0)
|
|
27
|
-
Requires-Dist: pyspark (>=3.0,<3.5)
|
|
28
|
-
Requires-Dist:
|
|
29
|
-
Requires-Dist:
|
|
29
|
+
Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
|
|
30
|
+
Requires-Dist: pyspark ; extra == "spark"
|
|
31
|
+
Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
|
|
32
|
+
Requires-Dist: pytorch-optimizer ; extra == "torch"
|
|
33
|
+
Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
|
|
30
34
|
Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
|
|
31
35
|
Requires-Dist: scipy (>=1.8.1,<2.0.0)
|
|
32
36
|
Requires-Dist: setuptools
|
|
33
|
-
Requires-Dist: torch (>=1.8,<
|
|
37
|
+
Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
|
|
38
|
+
Requires-Dist: torch ; extra == "torch"
|
|
39
|
+
Requires-Dist: torch ; extra == "torch-cpu"
|
|
34
40
|
Requires-Dist: tqdm (>=4.67,<5)
|
|
35
41
|
Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
|
|
36
42
|
Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
|
|
@@ -231,7 +237,7 @@ pip install optuna
|
|
|
231
237
|
|
|
232
238
|
2) Model compilation via OpenVINO:
|
|
233
239
|
```bash
|
|
234
|
-
pip install openvino onnx
|
|
240
|
+
pip install openvino onnx onnxscript
|
|
235
241
|
```
|
|
236
242
|
|
|
237
243
|
3) Vector database and hierarchical search support:
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
replay/__init__.py,sha256=v3mrDhnKFg0X1ZQBAAyAMlOgyZDPiRd01VsfpkOu9bo,225
|
|
2
|
+
replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
|
|
3
|
+
replay/data/dataset.py,sha256=yBl-yJVIokgN4prFY949tHe2UVJC_j5xdaulIoSPvQI,31252
|
|
4
|
+
replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
|
|
5
|
+
replay/data/dataset_utils/dataset_label_encoder.py,sha256=bxuJPhShFZBok7bQZYGNMV1etCLNTJUpyKO5MIwWack,9823
|
|
6
|
+
replay/data/nn/__init__.py,sha256=Dpso6tN10moj92_NrXCVWBEAMhnGXewGC12H9fTCg0E,1228
|
|
7
|
+
replay/data/nn/parquet/__init__.py,sha256=e6FDBPzlv9SMduGJOtn2EarxPXk3_wHKWConS__SmWk,786
|
|
8
|
+
replay/data/nn/parquet/collate.py,sha256=tOArGUnJCILdAHEHELW7o3iuKCVD4w8BEbxNYXv7yJc,984
|
|
9
|
+
replay/data/nn/parquet/constants/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
+
replay/data/nn/parquet/constants/batches.py,sha256=2VJwk_W9wOk6C1P4IMvlO5bWuP1i14TCFGHaL487TkI,271
|
|
11
|
+
replay/data/nn/parquet/constants/device.py,sha256=EV25_HKMiPyAx7pops1Vr3YVR-9CmW_cpmxlymmyg9Q,51
|
|
12
|
+
replay/data/nn/parquet/constants/filesystem.py,sha256=v23OWKtTDFnCqCYuhL4d9o-PDjihfYIMVkatWzhqoiQ,67
|
|
13
|
+
replay/data/nn/parquet/constants/metadata.py,sha256=UQdTtnMPwGkpwgogrIij5C9G_HdKCZXKoGur0KFIdCM,133
|
|
14
|
+
replay/data/nn/parquet/fixed_batch_dataset.py,sha256=SFfyUkFaleZ4W_oskrl_6ws8f10Dkqo71U7C6g7yuD8,5150
|
|
15
|
+
replay/data/nn/parquet/impl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
+
replay/data/nn/parquet/impl/array_1d_column.py,sha256=HgjJ4Dz1za_X_WfqrrXbjV_pPo5B_P3OJ3NXw8cjcBY,4823
|
|
17
|
+
replay/data/nn/parquet/impl/array_2d_column.py,sha256=2pjjmlF8Kbqi11uTtMYa29mWR1Aiieajvx3w96YrW50,5633
|
|
18
|
+
replay/data/nn/parquet/impl/column_protocol.py,sha256=Tjcbo3b2I834OMsaVz312AzvFMAiSydWDtN7_uMNzc0,340
|
|
19
|
+
replay/data/nn/parquet/impl/indexing.py,sha256=_ETICbsn-q70iEvUAIwgoZFqrIB2UxQVvqQ4kD3F8DY,4945
|
|
20
|
+
replay/data/nn/parquet/impl/masking.py,sha256=NBq6klPCAUO-Zm-VvCf1t6E_yLGbo03KvAq7Bl64ZsI,627
|
|
21
|
+
replay/data/nn/parquet/impl/named_columns.py,sha256=LUlI7tsh-6kfcVAAKjpPnZvGUwaCvHOj-Zqgqh73A14,3117
|
|
22
|
+
replay/data/nn/parquet/impl/numeric_column.py,sha256=A1jKct3YJegzzu8BoHHcypAgZUPFM3QR03FkdMQfxnI,3940
|
|
23
|
+
replay/data/nn/parquet/impl/utils.py,sha256=MqZcSC4fQnRrrxHFg21ukrntTQkASpEh6SKftOP20Ds,446
|
|
24
|
+
replay/data/nn/parquet/info/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
+
replay/data/nn/parquet/info/distributed_info.py,sha256=po7pl0m24pqjzlWKlfrdDWKZ4w0FEPoYLfW0KqqFCVY,850
|
|
26
|
+
replay/data/nn/parquet/info/partitioning.py,sha256=2XuiRlHaQbiRHrmPrmn2JNogiHvknmvIHB8IIK3pf78,4525
|
|
27
|
+
replay/data/nn/parquet/info/replicas.py,sha256=L2YnD6nvp_hjqhtptzq5KLdm8WlUskFnxbvL_06_AYQ,2480
|
|
28
|
+
replay/data/nn/parquet/info/worker_info.py,sha256=sIqBqHSeFdO00dDg_Mc_6UNXDQXGcu0iQVinRy84RUE,947
|
|
29
|
+
replay/data/nn/parquet/iterable_dataset.py,sha256=mQe2xvrpOU3vrVdy_tCxaFk45fqJv8mnQQS48-sQcqU,4246
|
|
30
|
+
replay/data/nn/parquet/iterator.py,sha256=X5KXtjdY_uSfMlP9IXBqMzSimBqlAZbYX_Y483q_3U8,2577
|
|
31
|
+
replay/data/nn/parquet/metadata/__init__.py,sha256=UZX60ANtjo6zX0p43hU9q8fBldVJNCEmGzXjHqz0MJQ,341
|
|
32
|
+
replay/data/nn/parquet/metadata/metadata.py,sha256=jJOL8mieXhX18FO9lgaP95MOtO1l7tY63ldxoOAvzwA,3459
|
|
33
|
+
replay/data/nn/parquet/parquet_dataset.py,sha256=pKthRppp0MstfNwOk9wMrE6wFvDecCtbTKWIri4HGr0,8017
|
|
34
|
+
replay/data/nn/parquet/parquet_module.py,sha256=g53lgb-bydDg5P27I4MODnnMcRi1qjpvAw3_QQ9UgxQ,8208
|
|
35
|
+
replay/data/nn/parquet/partitioned_iterable_dataset.py,sha256=BZEh2EiBKMZxi822-doyTbjDkZQQ62SxAp_NhZVZdmk,1938
|
|
36
|
+
replay/data/nn/parquet/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
37
|
+
replay/data/nn/parquet/utils/compute_length.py,sha256=VWabulpRICy-_Z0ZBXpEmhAIlpXVwTwe9kX2L2XCdbE,2492
|
|
38
|
+
replay/data/nn/schema.py,sha256=vLSDj4ZprOL9aurdcpOZ78KgNRXXuwt4wuTq5feiAvA,17115
|
|
39
|
+
replay/data/nn/sequence_tokenizer.py,sha256=zh026PRsTzPhUhW1SqPOvAZOdrIDbDyBRwdkgwtvTh0,37745
|
|
40
|
+
replay/data/nn/sequential_dataset.py,sha256=BcLkM_w3yF7F0EgPK5_jcreurh8k0fVJBoA9KJpp1fM,11800
|
|
41
|
+
replay/data/nn/torch_sequential_dataset.py,sha256=VQ3l3SQBFxIuXKr5FpVJNE-As3MgJ7SAa4Aeb0S2yNA,11874
|
|
42
|
+
replay/data/nn/utils.py,sha256=Ic3G4yZRIzBYXLmwP1VstlZXPNR7AYGCc5EyZAERp5c,3297
|
|
43
|
+
replay/data/schema.py,sha256=JmYLCrNgBS5oq4O_PT724Gr1pDurHEykcqV8Xaj0XTw,15922
|
|
44
|
+
replay/data/spark_schema.py,sha256=4o0Kn_fjwz2-9dBY3q46F9PL0F3E7jdVpIlX7SG3OZI,1111
|
|
45
|
+
replay/data/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
|
+
replay/data/utils/batching.py,sha256=jBNhRC5jqNe2pVVlmvFLvjTo86Ud0e_Lj2P0W2yNcKY,2268
|
|
47
|
+
replay/data/utils/typing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
48
|
+
replay/data/utils/typing/dtype.py,sha256=QJigLH7fv-xIb_s-R_70KTZxOgl2ZJkhEhf_txziRAY,1590
|
|
49
|
+
replay/metrics/__init__.py,sha256=j0PGvUehaPEZMNo9SQwJsnvzrS4bam9eHrRMQFLnMjY,2813
|
|
50
|
+
replay/metrics/base_metric.py,sha256=ejtwFHktN4J8Fi1HIM3w0zlMAd8nO7-XpFi2D1iHXUQ,16010
|
|
51
|
+
replay/metrics/categorical_diversity.py,sha256=3tp8n457Ob4gjM-UTB5N19u9WAF7fLDkWKk-Mth-Vzc,10769
|
|
52
|
+
replay/metrics/coverage.py,sha256=e6vPItrRlI-mLNuOT5uoo5lMAAzkYGKZRxvupi21dMk,8528
|
|
53
|
+
replay/metrics/descriptors.py,sha256=BHORyGKfJgPeUjgLO0u2urSTe16UQbb-HHh8soqnwDE,3893
|
|
54
|
+
replay/metrics/experiment.py,sha256=6Sw8PyItn3E2R-BBa_YwrmtBV3n0uAGHHOvkhHYgMz4,8125
|
|
55
|
+
replay/metrics/hitrate.py,sha256=LcOJLMs3_Dq4_pbKx95qdCdjGrX52dyWyuWUFXCyaDw,2314
|
|
56
|
+
replay/metrics/map.py,sha256=dIZcmUxd2XnNC7d_d7gmq0cjNaI1hlNMaJTSHGCokQE,2572
|
|
57
|
+
replay/metrics/mrr.py,sha256=qM8tVMSoyYR-kTx0mnBGppoC53SxNlZKm7JKMUmSv9U,2163
|
|
58
|
+
replay/metrics/ndcg.py,sha256=izajmD243ZIK3KLm9M-NtLwxb9N3Ktj58__AAfwF6Vc,3110
|
|
59
|
+
replay/metrics/novelty.py,sha256=j3p1fbUVi2QQgEre42jeQx73PYYDUhy5gYlrL4BL5b8,5488
|
|
60
|
+
replay/metrics/offline_metrics.py,sha256=f_U4Tk3Ke5sR0_OYvoE2_nD6wrOCveg3DM3B9pStVUI,20454
|
|
61
|
+
replay/metrics/precision.py,sha256=DRlsgY_b4bJCOSZjCA58N41REMiDt-dbagRSXxfXyvY,2256
|
|
62
|
+
replay/metrics/recall.py,sha256=fzpASDiH88zcpXJZTbStQ3nuzzSdhd9k1wjF27rM4wc,2447
|
|
63
|
+
replay/metrics/rocauc.py,sha256=1vaVEK7DQTL8BX-i7A64hTFWyO38aNycscPGrdWKwbA,3282
|
|
64
|
+
replay/metrics/surprisal.py,sha256=HkmYrOuw3jydxFrkidjdcpAcKz2DeOnMsKqwB2g9pwY,7526
|
|
65
|
+
replay/metrics/torch_metrics_builder.py,sha256=mnHrmRTOKZ_edrTrTKs7IPzKt5DkQYRd2B_8b3bB9yU,14071
|
|
66
|
+
replay/metrics/unexpectedness.py,sha256=LSi-z50l3_yrvLnmToHQzm6Ygf2QpNt_zhk6jdg7QUo,6882
|
|
67
|
+
replay/models/__init__.py,sha256=kECYluQZ83zRUWaHVvnt7Tg3BerHrJy9v8XfRxsqyYY,1123
|
|
68
|
+
replay/models/als.py,sha256=1MFAbcx64tv0MX1wE9CM1NxKD3F3ZDhZUrmt6dvHu74,6220
|
|
69
|
+
replay/models/association_rules.py,sha256=shBNsKjlii0YK-XA6bSl5Ov0ZFTnjxZbgKJU9PFYptY,14507
|
|
70
|
+
replay/models/base_neighbour_rec.py,sha256=SdGb2ejpYjHmxFNTk5zwEo0RWdfPAj1vKGP_oj7IrQo,7783
|
|
71
|
+
replay/models/base_rec.py,sha256=aNIEbSy8G5q92NOpDlSJbp0Z-lAkazFLa9eDAajl1wI,56067
|
|
72
|
+
replay/models/cat_pop_rec.py,sha256=ed1X1PDQY41hFJ1cO3Q5OWy0rXhV5_n23hJ-QHWONtE,11968
|
|
73
|
+
replay/models/cluster.py,sha256=9JcpGnbfgFa4UsyxPAa4WMuJFa3rsuAxiKoy-s_UfyE,4970
|
|
74
|
+
replay/models/common.py,sha256=rFmfwwzWCWED2HaDVuSN7ZUAgaNPGPawUudgn4IApbo,2121
|
|
75
|
+
replay/models/extensions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
76
|
+
replay/models/extensions/ann/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
|
+
replay/models/extensions/ann/ann_mixin.py,sha256=Ua1fuwrvtISNDQ8iPV-ln8S1LDKz8-rIU2UYsMExAiU,7782
|
|
78
|
+
replay/models/extensions/ann/entities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
79
|
+
replay/models/extensions/ann/entities/base_hnsw_param.py,sha256=5GRdcQj4-zhNXfJ7ko2WHGHgRuXCzSHCRcRxljl1V4c,776
|
|
80
|
+
replay/models/extensions/ann/entities/hnswlib_param.py,sha256=j3V4JXM_yfR6s2TjYakIXMg-zS1-MrP6an930DEIWGM,2104
|
|
81
|
+
replay/models/extensions/ann/entities/nmslib_hnsw_param.py,sha256=WeEhRR4jKqgvWK_zDK8fx6kEqc2e_bc0kubvqK3iV8c,2162
|
|
82
|
+
replay/models/extensions/ann/index_builders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
83
|
+
replay/models/extensions/ann/index_builders/base_index_builder.py,sha256=Ul25G0FaNLOXUjrDXxZDTg7tLXlv1N6wR8kWjWICtZ0,2110
|
|
84
|
+
replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py,sha256=U8-3lRahyWmWkZ7tYuO-Avd1jX-lGh7JukC140wJ-WQ,1600
|
|
85
|
+
replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py,sha256=1NLWyAJGYgp46uUBhUYQyd0stmG6DhLh7U4JEne5TFw,1308
|
|
86
|
+
replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py,sha256=cf3LhBCRRN-lBYGlJbv8vnY-KVeHAleN5cVjvd58Ibs,2476
|
|
87
|
+
replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py,sha256=0DPJ3WAt0cZ5dmtZv87fmMEgYXWf8rM35f7CA_DgWZY,2618
|
|
88
|
+
replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py,sha256=AIkVnobesnTM5lrBSWf9gd0CySwFQ0vH_DjemfLS4Cs,1925
|
|
89
|
+
replay/models/extensions/ann/index_inferers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
90
|
+
replay/models/extensions/ann/index_inferers/base_inferer.py,sha256=I39aqEc2somfndrCd-KC3XYZnYSrJ2hGpR9y6wO93NA,2524
|
|
91
|
+
replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py,sha256=JjT4l_XAjzUOsTAE7OS88zAgPd_h_O44oUnn2kVr8E0,2477
|
|
92
|
+
replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py,sha256=CoY_oMfdcwnh87ceuSpHXu4Czle9xxeMisO8XJUuJLE,1717
|
|
93
|
+
replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py,sha256=tjuqbkztWBU4K6qp5LPFU_GOGJf2f4oXneExtUEVUzw,3128
|
|
94
|
+
replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py,sha256=S5eCBZlTXxEAeX6yeZGC7j56gOcJ7lMNb4Cs_5PEj9E,2203
|
|
95
|
+
replay/models/extensions/ann/index_inferers/utils.py,sha256=6IST2FPSY3nuYu5KqzRpd4FgdaV3GnQRQlxp9LN_yyA,641
|
|
96
|
+
replay/models/extensions/ann/index_stores/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
97
|
+
replay/models/extensions/ann/index_stores/base_index_store.py,sha256=u4l2ybAXX92ZMGK7NqqosbKF75QgFqhAMVadd5ePj6Y,910
|
|
98
|
+
replay/models/extensions/ann/index_stores/hdfs_index_store.py,sha256=0zDq9PdDOiD6HvtZlfjTbuJHfeTOWBTQ_HiuqZmoxtA,3090
|
|
99
|
+
replay/models/extensions/ann/index_stores/shared_disk_index_store.py,sha256=thl4T1uYU4Gtk4nBao_qK8CbFBdX1xmXNishxgfCd-I,2030
|
|
100
|
+
replay/models/extensions/ann/index_stores/spark_files_index_store.py,sha256=QP_8mE7EIBbePIe0AB-IWuJLRA5MR3wswCEt8oHzn-0,3617
|
|
101
|
+
replay/models/extensions/ann/index_stores/utils.py,sha256=6r2GP_EFCaCguolW857pb4lRS8rh6_Nv_Edso9_j5no,3756
|
|
102
|
+
replay/models/extensions/ann/utils.py,sha256=AgQvThi_DvEtakQeTno9hVZVWiWMFHKTjRcQ2wLa5vk,1222
|
|
103
|
+
replay/models/kl_ucb.py,sha256=L6vC2KsTBTTx4ckmGhWybOiLa5Wt54N7cgl7jS2FQRg,6731
|
|
104
|
+
replay/models/knn.py,sha256=HEiGHHQg9pV1_EIWZHfK-XD0BNAm1bj1c0ND9rYnj3k,8992
|
|
105
|
+
replay/models/lin_ucb.py,sha256=iAR3PbbaQKqmisOKEx9ZyfpxnxcZomr6YauG4mvSakU,18800
|
|
106
|
+
replay/models/nn/__init__.py,sha256=AT3o1qXaxUq4_QIGlcGuSs54ZpueOo-SbpZwuGI-6os,41
|
|
107
|
+
replay/models/nn/loss/__init__.py,sha256=s3iO9QTZvLz_ony2b5K0hEmDmitrXQnAe9j6BRxLpR4,53
|
|
108
|
+
replay/models/nn/loss/sce.py,sha256=mRJZYmwQNT-kMi66pXrE1-OdM7y_YEQFHzC37odnEo4,5165
|
|
109
|
+
replay/models/nn/optimizer_utils/__init__.py,sha256=9aiEk662v9-qJgzi8TZYaaqQSiZWr4ZleFHwcLOZX14,219
|
|
110
|
+
replay/models/nn/optimizer_utils/optimizer_factory.py,sha256=OJPX4XD_uG0iZKzxmhzT60uS66swGtpZnAV5A98vcgs,3439
|
|
111
|
+
replay/models/nn/sequential/__init__.py,sha256=CI2n0cxs_amqJrwBMq6n0Z_uBOu7CGXfagqvE4Jlmjw,128
|
|
112
|
+
replay/models/nn/sequential/bert4rec/__init__.py,sha256=JfZqHOGxcvOkICl5cWmZbZhaKXpkIvua-Wj57VWWEhw,399
|
|
113
|
+
replay/models/nn/sequential/bert4rec/dataset.py,sha256=Y63LESZYqKDG3OSvrWRy4Tgkib24VKJ9c9qJsGarr5k,12110
|
|
114
|
+
replay/models/nn/sequential/bert4rec/lightning.py,sha256=vxAf1H1VfLqgZhOz9fxEMmw4L3wfOr_wFnWHn_vPE34,28351
|
|
115
|
+
replay/models/nn/sequential/bert4rec/model.py,sha256=C1AKcQ8KF0XMXERwrFneW9kg7hzc-9FIqhCc-t91F7o,17469
|
|
116
|
+
replay/models/nn/sequential/callbacks/__init__.py,sha256=Q7mSZ_RB6iyD7QZaBL_NJ0uh8cRfgxq7gtPHbkSyhoo,282
|
|
117
|
+
replay/models/nn/sequential/callbacks/prediction_callbacks.py,sha256=UtEzO9_f5Jwku9dbz7twr4o2_cV3L-viC4lQuce5l1c,10808
|
|
118
|
+
replay/models/nn/sequential/callbacks/validation_callback.py,sha256=ydcNkUhaFD78ogqZWySzzKg4BaPyEkaRqmLiD4qFDzM,6583
|
|
119
|
+
replay/models/nn/sequential/compiled/__init__.py,sha256=eSVcCaUH5cDJQRbC7K99X7uMNR-Z-KR4TmYOGKWWJCI,531
|
|
120
|
+
replay/models/nn/sequential/compiled/base_compiled_model.py,sha256=f4AuTyx5tufQOtOWUSEgj1cWvMZzSL7YN2Z-PtURgTY,10478
|
|
121
|
+
replay/models/nn/sequential/compiled/bert4rec_compiled.py,sha256=woGI3qk4J2Rb5FyaDwpSCuG-AMfyH34F6Bt5pV-wqk0,6798
|
|
122
|
+
replay/models/nn/sequential/compiled/sasrec_compiled.py,sha256=eCRpxPdu94KyvczYJx2jgt6xaTZ3RpNYvcfyyyTYuiA,6170
|
|
123
|
+
replay/models/nn/sequential/postprocessors/__init__.py,sha256=89LGzkNHukcuC2-rfpiz7vmv1zyk6MNY-8zaXrvtn0M,164
|
|
124
|
+
replay/models/nn/sequential/postprocessors/_base.py,sha256=Pi8vWcaiqj3XddaxbiOYd5ME7ZfIkk0GPoCgpFKdO0g,1300
|
|
125
|
+
replay/models/nn/sequential/postprocessors/postprocessors.py,sha256=fclLmGkJbWAnNBw-Rvc_kKQsw0rUF2jfJ7s6VF8ge4I,8021
|
|
126
|
+
replay/models/nn/sequential/sasrec/__init__.py,sha256=c6130lRpPkcbuGgkM7slagBIgH7Uk5zUtSzFDEwAsik,250
|
|
127
|
+
replay/models/nn/sequential/sasrec/dataset.py,sha256=Le_rG-MoCpWoSKdrEJOyLo3S617FLMEMI8Ix51YEQx0,9452
|
|
128
|
+
replay/models/nn/sequential/sasrec/lightning.py,sha256=nJthkJvgp-nBy6mtt_5PvzUFihArXYTnZdAih85F01U,27067
|
|
129
|
+
replay/models/nn/sequential/sasrec/model.py,sha256=xLPz2HvPkDGMaXiWcyug7auQgBG-ai37OPFb7_jvorU,27876
|
|
130
|
+
replay/models/optimization/__init__.py,sha256=N8xCuzu0jQGwHrIBjuTRf-ZcZuBJ6FB0d9C5a7izJQU,338
|
|
131
|
+
replay/models/optimization/optuna_mixin.py,sha256=pKu-Vw9l2LsDycubpdJiLkC1eE4pKrDG0T2lhUgRUB4,11960
|
|
132
|
+
replay/models/optimization/optuna_objective.py,sha256=UHWOJwBngPA3IRz9yAMEWPg00oyb7Wq9PXuRPYHIiLE,7538
|
|
133
|
+
replay/models/pop_rec.py,sha256=Ju9y2rU2vW_jFU9-W15fbbr5_ZzYGihSjSxsqKsAf0Q,4964
|
|
134
|
+
replay/models/query_pop_rec.py,sha256=UNsHtf3eQpJom73ZmEO5us4guI4SnCLJYTfuUpRgqes,4086
|
|
135
|
+
replay/models/random_rec.py,sha256=9SC012_X3sNzrAjDG1CPGhjisZb6gnv4VCW7yIMSNpk,8066
|
|
136
|
+
replay/models/slim.py,sha256=OAdTS64bObZujzHkq8vfP1kkoLMSWxk1KLg6lCCA0N8,4551
|
|
137
|
+
replay/models/thompson_sampling.py,sha256=gcjlVl1mPiEVt70y8frA762O-eCZzd3SVg1lnDRCEHk,1939
|
|
138
|
+
replay/models/ucb.py,sha256=b2qFgvOAZcyv5triPk18duqF_jt-ty7mypenjRLNWwQ,6952
|
|
139
|
+
replay/models/wilson.py,sha256=o7aUWjq3648dAfgGBoWD5Gu-HzdyobPMaH2lzCLijiA,4558
|
|
140
|
+
replay/models/word2vec.py,sha256=atfj6GjR_L-TdurRFr1yi7B3BicJ3ZdFxixW9RfojJs,8882
|
|
141
|
+
replay/nn/__init__.py,sha256=Bd_Xi9s5g1zWSjMwk50ztG9oezhs37r2L4-mfB-gEsg,256
|
|
142
|
+
replay/nn/agg.py,sha256=JneTgVlo00cEg5FxzIp6NvNVOXqvL45e9vsXPP_5ztg,3799
|
|
143
|
+
replay/nn/attention.py,sha256=RR_KsqvnrZ1ZYr51KTBA9q5gB-0sqhmakjH1JdIo9dE,7812
|
|
144
|
+
replay/nn/embedding.py,sha256=xY_zPpC055cTXAZ8TShUYP3ZrBUA2HQwn4dkOCKXYJ0,11876
|
|
145
|
+
replay/nn/ffn.py,sha256=ivOFu14289URepyEFxYov_XNYMUrINjU-2rEqoXxbnU,4618
|
|
146
|
+
replay/nn/head.py,sha256=csjwQrcA7M7FebgSL1tKDbjfaoni52CymQR0Zt8RhWg,2084
|
|
147
|
+
replay/nn/lightning/__init__.py,sha256=jHiwtYuboGUY4Of18zrkvdWD0xXJ_zuo83-XgiqxSfY,36
|
|
148
|
+
replay/nn/lightning/callback/__init__.py,sha256=ImNEJeIK-wJnqdkZgP8tWTDQHaS9xYqzTEf3FEM0XAw,253
|
|
149
|
+
replay/nn/lightning/callback/metrics_callback.py,sha256=dIu1wDtqjXH6ogFGsh2L-dpkgz7OKjtTrVbBLrI4pjg,6986
|
|
150
|
+
replay/nn/lightning/callback/predictions_callback.py,sha256=e9PeXNyyGz-m46FEaafgCToPEVC9T5Cb8Q4sFArnpLY,11347
|
|
151
|
+
replay/nn/lightning/module.py,sha256=jFvevwiriY9alZMBw6KAiRMsJv-dJ8fEVrenVRiuWeI,5246
|
|
152
|
+
replay/nn/lightning/optimizer.py,sha256=1tXhz9RIBHLpEQtZ1PUzCAc4mn6Q_E38zR0nf5km6U8,1778
|
|
153
|
+
replay/nn/lightning/postprocessor/__init__.py,sha256=LhUeOWDD5vRBDXF2tQEjvPKH1rNIlrf5KPbcV66AdtQ,77
|
|
154
|
+
replay/nn/lightning/postprocessor/_base.py,sha256=X0LtYItmxlXt4Sxk3cOdyIK3FG5dijQzyh5Kv6s5FjE,1592
|
|
155
|
+
replay/nn/lightning/postprocessor/seen_items.py,sha256=h-sfD3vmNCdS7lYvqCfqw9oPqutmaSIuZ0CIidG0Y30,2922
|
|
156
|
+
replay/nn/lightning/scheduler.py,sha256=CUuynPTFrKBrkpmbWR-xpfAkHZ0Vfz_THUDo3uoZi8k,2714
|
|
157
|
+
replay/nn/loss/__init__.py,sha256=YXAXQIN0coj8MxeK5isTGXgvMxhH5pUO6j1D3d7jl3A,471
|
|
158
|
+
replay/nn/loss/base.py,sha256=oD1vATWoQDi45zG9EPjg3hgDrfpr4ue_rQFfArn1dFs,8871
|
|
159
|
+
replay/nn/loss/bce.py,sha256=cPlxdJTBZ0b22K6V9ve4qo7xkp99CjEsnl3_vVGphqs,8373
|
|
160
|
+
replay/nn/loss/ce.py,sha256=jOmhLtKD_E0jX8tUfXpsmaaQVHKKiwXW9USB_GyN3ZU,13218
|
|
161
|
+
replay/nn/loss/login_ce.py,sha256=NER_Hbs_H3IXn_bkgwG25VQNQ6ZjjDcxq-aMI7pC2eM,16498
|
|
162
|
+
replay/nn/loss/logout_ce.py,sha256=KhcYyCnUzLZR1sFpxM6_QliLoxmC6MJoLkPOgf_ZYzU,10306
|
|
163
|
+
replay/nn/mask.py,sha256=Jbx7sulGZYfasNaD9CZzJma0cEVaDlxdpzs295507II,3329
|
|
164
|
+
replay/nn/normalization.py,sha256=Z86t5WCr4KfVR9qCCe-EIAwwomnIIxb11PP88WHA1JI,187
|
|
165
|
+
replay/nn/output.py,sha256=6uecMOMN4FGoQ-NzKGacZnlrk_9TwQswpC-x3G_DMTY,1291
|
|
166
|
+
replay/nn/sequential/__init__.py,sha256=jet_ueMz5Bm087JDph7ln87NID7DbCb0WENj-tjoOGg,229
|
|
167
|
+
replay/nn/sequential/sasrec/__init__.py,sha256=8crj-JL8xeP-cCOCnxCSVF_-R6feKhj0YRHOcaMsqrU,213
|
|
168
|
+
replay/nn/sequential/sasrec/agg.py,sha256=e-IkIO-MMbei2UGxTUopWvloguJoVaZiN31sXkdUVag,2004
|
|
169
|
+
replay/nn/sequential/sasrec/diff_transformer.py,sha256=4ehM5EMizajmWBAzmcj3CYSFl21V1R2b7RDRJlx3O4Q,4790
|
|
170
|
+
replay/nn/sequential/sasrec/model.py,sha256=sQ2FvfDyZ3G6PjbNME--fMboqUt66z9J8t8YYlJ9J6Q,14803
|
|
171
|
+
replay/nn/sequential/sasrec/transformer.py,sha256=sJf__IPnhbJWDPuFTPSbBGSSntznQtS-hJtJo3iFBkw,4037
|
|
172
|
+
replay/nn/sequential/twotower/__init__.py,sha256=-rEASPqKCbS55MTTgeDZ5irfWfM9or1vNTHZnJN2AcU,124
|
|
173
|
+
replay/nn/sequential/twotower/model.py,sha256=VxUUjldHndCkDjrXGqmxGnTi5fh8vmnr7XNBpYjsqW8,28659
|
|
174
|
+
replay/nn/sequential/twotower/reader.py,sha256=j4mlKx5Lf3hFnSgaxMLkuqWLZd3dkLchDI4JEuZHLGY,3674
|
|
175
|
+
replay/nn/transform/__init__.py,sha256=9PeaDHmftb0s1gEEgJRNWw6Bl2wfE_-lImatipaHUQ0,705
|
|
176
|
+
replay/nn/transform/copy.py,sha256=ZfNXbMJYTwXDMJ5T8ib9Dh5XOGLjj7gGB4NbBExFZiM,1302
|
|
177
|
+
replay/nn/transform/grouping.py,sha256=XOJoVBk234DI6x05Kqr7KOjLetDaLp2NMAJWHecQcsI,1384
|
|
178
|
+
replay/nn/transform/negative_sampling.py,sha256=R5di5-IuNtpbjcjHYcBTZYd6Lk2R5_I77PVioaL5s5w,7475
|
|
179
|
+
replay/nn/transform/next_token.py,sha256=UONG8_J-UxZdRCOEcz7fvU40k-hvE_h7ff014L9Ukpg,4491
|
|
180
|
+
replay/nn/transform/rename.py,sha256=_uD2e1UmtBRyOTVpHUnZ5xhePmClaGQsc0g7Es-rupE,1026
|
|
181
|
+
replay/nn/transform/reshape.py,sha256=sgswIogWHUwOVp02k13Qopn84LofqLoA4M7U1GAfmio,1359
|
|
182
|
+
replay/nn/transform/sequence_roll.py,sha256=7jf42SgWHU1L7SirqQWXx0h9a6VQQ29kehE4LmdUt9o,1531
|
|
183
|
+
replay/nn/transform/template/__init__.py,sha256=lYzAekZUXwncGR66Nq8YypplGOtL00GFfm0PalGiY5g,106
|
|
184
|
+
replay/nn/transform/template/sasrec.py,sha256=FoOhroe-S0JPaxIQ3Ba-3_gyslgj47RoLL2geOxNAO4,1906
|
|
185
|
+
replay/nn/transform/template/twotower.py,sha256=BIlbqTfKEMcyx2Ksr4qzAD0h0mdhiTLa1xcmZ2e8Ksc,896
|
|
186
|
+
replay/nn/transform/token_mask.py,sha256=WcalZkY2UCoNiq2mBtu8fqYFOUfqCh21XyDMgvIpeB4,2529
|
|
187
|
+
replay/nn/transform/trim.py,sha256=mPn6LPxu3c3yE14heMSRsDEU4h94tkFiRr62mOa3lKg,1608
|
|
188
|
+
replay/nn/utils.py,sha256=GumtN-QRP9ljXYti3YvuNk13e0Q92xvkYuCJBhaViCI,801
|
|
189
|
+
replay/preprocessing/__init__.py,sha256=c6wFPAc6lATyp0lE-ZDjHMsXyEMPKX7Usuqylv6H5XQ,597
|
|
190
|
+
replay/preprocessing/converter.py,sha256=JQ-4u5x0eXtswl1iH-bZITBXQov1nebnZ6XcvpD8Twk,4417
|
|
191
|
+
replay/preprocessing/discretizer.py,sha256=jzYqvoSVmiL-oS-ri9Om0vSDoU8bCQimjUoe7FiPfLU,27024
|
|
192
|
+
replay/preprocessing/filters.py,sha256=cCX8BikKNqcAGFpJkYssQkR_6tUjjktSlpZOxK1ezUw,49930
|
|
193
|
+
replay/preprocessing/history_based_fp.py,sha256=oEu1CkCz7xcGbPdSTHfhTe1NimnFo50Arn8qngRBgE8,18702
|
|
194
|
+
replay/preprocessing/label_encoder.py,sha256=puedlFGitjI_yi4uxRIR6L4Wz6oZ93gIEPeylC-jCtI,41459
|
|
195
|
+
replay/preprocessing/sessionizer.py,sha256=G6i0K3FwqtweRxvcSYraJ-tBWAT2HnV-bWHHlIZJF-s,12217
|
|
196
|
+
replay/preprocessing/utils.py,sha256=e-JRoadbeTe3Qvp_NXMZNQkmgedeR6iJLyO_82xKPd0,7109
|
|
197
|
+
replay/scenarios/__init__.py,sha256=XXAKEQPTLlve-0O6NPwFgahFrb4oGcIq3HaYaaGxG2E,94
|
|
198
|
+
replay/scenarios/fallback.py,sha256=dO3s9jqYup4rbgMaY6Z6HGm1r7SXkm7jOvNZDr5zm_U,7138
|
|
199
|
+
replay/splitters/__init__.py,sha256=9vhrZ8nCgq_NYJkv4wn0JYqhKURZH6Z8IyRNN1BX6AI,510
|
|
200
|
+
replay/splitters/base_splitter.py,sha256=zvYVEHBYrK8Y2qPv3kYachfLFwR9-kUAiU1UJSNGS8A,7749
|
|
201
|
+
replay/splitters/cold_user_random_splitter.py,sha256=32VgAHiwk9Emkofu1KqwGZrrFiyrYtSQ3YPdt5p_XoQ,4423
|
|
202
|
+
replay/splitters/k_folds.py,sha256=RDDL3gE6M5qfK5Ig-pxxJeq3O4uxsWJjLFQRRzQ2Ssg,6211
|
|
203
|
+
replay/splitters/last_n_splitter.py,sha256=hMWIGYFg17LioT08VBXut5Ic-w9oXsKd739cy2xuwYs,15368
|
|
204
|
+
replay/splitters/new_users_splitter.py,sha256=NksAdl_wL9zwHj3cY5NqrrnkOajgyUDloSsRZ9HUE48,9160
|
|
205
|
+
replay/splitters/random_next_n_splitter.py,sha256=aRqRe1jll7o5Hj-si-jyr341T4nXLfpX39crwVpLl-Y,8713
|
|
206
|
+
replay/splitters/random_splitter.py,sha256=0DO0qulT0jp_GXswmFh3BMJ7utS-z9e-r5jIrmTKGC4,2989
|
|
207
|
+
replay/splitters/ratio_splitter.py,sha256=rFWN-nKBYx1qKrmtYzjYf08DWFiKOCo5ZRUz-NHJFfs,17506
|
|
208
|
+
replay/splitters/time_splitter.py,sha256=0ZAMK26b--1wjrfzCuNVBh7gMPTa8SGf4LMEgACiUxA,9013
|
|
209
|
+
replay/splitters/two_stage_splitter.py,sha256=8Zn6BTJmZg04CD4l2jmil2dEu6xtglJaSS5mkotIXRc,17823
|
|
210
|
+
replay/utils/__init__.py,sha256=3Skc9bUISEPPMMxdUCCT_S1q-i7cAT3KT0nExe-VMrw,343
|
|
211
|
+
replay/utils/common.py,sha256=_sBKR1hlZavXll8NN0hyGDIEdLakccPofu8JskHpBgk,5488
|
|
212
|
+
replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
|
|
213
|
+
replay/utils/distributions.py,sha256=UuhaC9HI6HnUXW97fEd-TsyDk4JT8t7k1T_6l5FpOMs,1203
|
|
214
|
+
replay/utils/model_handler.py,sha256=6WRyd39B-UXTtKTHWD_ssYN1vMmkjd417bwKb50uqJY,5754
|
|
215
|
+
replay/utils/session_handler.py,sha256=fQo2wseow8yuzKnEXT-aYAXcQIgRbTTXp0v7g1VVi0w,5138
|
|
216
|
+
replay/utils/spark_utils.py,sha256=GbRp-MuUoO3Pc4chFvlmo9FskSlRLeNlC3Go5pEJ6Ok,27411
|
|
217
|
+
replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
|
|
218
|
+
replay/utils/types.py,sha256=rD9q9CqEXgF4yy512Hv2nXclvwcnfodOnhBZ1HSUI4c,1260
|
|
219
|
+
replay_rec-0.21.0.dist-info/METADATA,sha256=9KaxfPOyxMV7l4O3L3qy59ACnvB1-ZbhwynJGKKlXzw,13573
|
|
220
|
+
replay_rec-0.21.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
221
|
+
replay_rec-0.21.0.dist-info/licenses/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
|
|
222
|
+
replay_rec-0.21.0.dist-info/licenses/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
|
|
223
|
+
replay_rec-0.21.0.dist-info/RECORD,,
|
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Most metrics require dataframe with recommendations
|
|
3
|
-
and dataframe with ground truth values —
|
|
4
|
-
which objects each user interacted with.
|
|
5
|
-
|
|
6
|
-
- recommendations (Union[pandas.DataFrame, spark.DataFrame]):
|
|
7
|
-
predictions of a recommender system,
|
|
8
|
-
DataFrame with columns ``[user_id, item_id, relevance]``
|
|
9
|
-
- ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
|
|
10
|
-
test data, DataFrame with columns
|
|
11
|
-
``[user_id, item_id, timestamp, relevance]``
|
|
12
|
-
|
|
13
|
-
Metric is calculated for all users, presented in ``ground_truth``
|
|
14
|
-
for accurate metric calculation in case when the recommender system generated
|
|
15
|
-
recommendation not for all users. It is assumed, that all users,
|
|
16
|
-
we want to calculate metric for, have positive interactions.
|
|
17
|
-
|
|
18
|
-
But if we have users, who observed the recommendations, but have not responded,
|
|
19
|
-
those users will be ignored and metric will be overestimated.
|
|
20
|
-
For such case we propose additional optional parameter ``ground_truth_users``,
|
|
21
|
-
the dataframe with all users, which should be considered during the metric calculation.
|
|
22
|
-
|
|
23
|
-
- ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
|
|
24
|
-
full list of users to calculate metric for, DataFrame with ``user_id`` column
|
|
25
|
-
|
|
26
|
-
Every metric is calculated using top ``K`` items for each user.
|
|
27
|
-
It is also possible to calculate metrics
|
|
28
|
-
using multiple values for ``K`` simultaneously.
|
|
29
|
-
In this case the result will be a dictionary and not a number.
|
|
30
|
-
|
|
31
|
-
Make sure your recommendations do not contain user-item duplicates
|
|
32
|
-
as duplicates could lead to the wrong calculation results.
|
|
33
|
-
|
|
34
|
-
- k (Union[Iterable[int], int]):
|
|
35
|
-
a single number or a list, specifying the
|
|
36
|
-
truncation length for recommendation list for each user
|
|
37
|
-
|
|
38
|
-
By default, metrics are averaged by users,
|
|
39
|
-
but you can alternatively use method ``metric.median``.
|
|
40
|
-
Also, you can get the lower bound
|
|
41
|
-
of ``conf_interval`` for a given ``alpha``.
|
|
42
|
-
|
|
43
|
-
Diversity metrics require extra parameters on initialization stage,
|
|
44
|
-
but do not use ``ground_truth`` parameter.
|
|
45
|
-
|
|
46
|
-
For each metric, a formula for its calculation is given, because this is
|
|
47
|
-
important for the correct comparison of algorithms, as mentioned in our
|
|
48
|
-
`article <https://arxiv.org/abs/2206.12858>`_.
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
from replay.experimental.metrics.base_metric import Metric, NCISMetric
|
|
52
|
-
from replay.experimental.metrics.coverage import Coverage
|
|
53
|
-
from replay.experimental.metrics.hitrate import HitRate
|
|
54
|
-
from replay.experimental.metrics.map import MAP
|
|
55
|
-
from replay.experimental.metrics.mrr import MRR
|
|
56
|
-
from replay.experimental.metrics.ncis_precision import NCISPrecision
|
|
57
|
-
from replay.experimental.metrics.ndcg import NDCG
|
|
58
|
-
from replay.experimental.metrics.precision import Precision
|
|
59
|
-
from replay.experimental.metrics.recall import Recall
|
|
60
|
-
from replay.experimental.metrics.rocauc import RocAuc
|
|
61
|
-
from replay.experimental.metrics.surprisal import Surprisal
|
|
62
|
-
from replay.experimental.metrics.unexpectedness import Unexpectedness
|