replay-rec 0.19.0rc0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +6 -2
- replay/data/dataset.py +9 -9
- replay/data/nn/__init__.py +6 -6
- replay/data/nn/sequence_tokenizer.py +44 -38
- replay/data/nn/sequential_dataset.py +13 -8
- replay/data/nn/torch_sequential_dataset.py +14 -13
- replay/data/nn/utils.py +1 -1
- replay/metrics/base_metric.py +1 -1
- replay/metrics/coverage.py +7 -11
- replay/metrics/experiment.py +3 -3
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +19 -0
- replay/models/association_rules.py +1 -4
- replay/models/base_neighbour_rec.py +6 -9
- replay/models/base_rec.py +44 -293
- replay/models/cat_pop_rec.py +2 -1
- replay/models/common.py +69 -0
- replay/models/extensions/ann/ann_mixin.py +30 -25
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
- replay/models/extensions/ann/utils.py +4 -3
- replay/models/knn.py +18 -17
- replay/models/nn/sequential/bert4rec/dataset.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
- replay/models/nn/sequential/compiled/__init__.py +10 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
- replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
- replay/models/nn/sequential/sasrec/dataset.py +1 -1
- replay/models/nn/sequential/sasrec/model.py +1 -1
- replay/models/optimization/__init__.py +14 -0
- replay/models/optimization/optuna_mixin.py +279 -0
- replay/{optimization → models/optimization}/optuna_objective.py +13 -15
- replay/models/slim.py +2 -4
- replay/models/word2vec.py +7 -12
- replay/preprocessing/discretizer.py +1 -2
- replay/preprocessing/history_based_fp.py +1 -1
- replay/preprocessing/label_encoder.py +1 -1
- replay/splitters/cold_user_random_splitter.py +13 -7
- replay/splitters/last_n_splitter.py +17 -10
- replay/utils/__init__.py +6 -2
- replay/utils/common.py +4 -2
- replay/utils/model_handler.py +11 -31
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +2 -2
- replay/utils/types.py +28 -18
- replay/utils/warnings.py +26 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -40
- replay_rec-0.20.0.dist-info/RECORD +139 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -602
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -13
- replay/experimental/models/admm_slim.py +0 -205
- replay/experimental/models/base_neighbour_rec.py +0 -204
- replay/experimental/models/base_rec.py +0 -1340
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -923
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -265
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -302
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -296
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay/optimization/__init__.py +0 -5
- replay_rec-0.19.0rc0.dist-info/RECORD +0 -191
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/NOTICE +0 -0
|
@@ -1,76 +0,0 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
|
-
from replay.utils import DataFrameLike, SparkDataFrame
|
|
4
|
-
from replay.utils.spark_utils import convert2spark, get_top_k_recs
|
|
5
|
-
|
|
6
|
-
from .base_metric import RecOnlyMetric, fill_na_with_empty_array, filter_sort
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class Unexpectedness(RecOnlyMetric):
    """
    Fraction of recommended items that are not present in some baseline recommendations.

    >>> import pandas as pd
    >>> from replay.utils.session_handler import get_spark_session, State
    >>> spark = get_spark_session(1, 1)
    >>> state = State(spark)

    >>> log = pd.DataFrame({
    ...     "user_idx": [1, 1, 1],
    ...     "item_idx": [1, 2, 3],
    ...     "relevance": [5, 5, 5],
    ...     "timestamp": [1, 1, 1],
    ... })
    >>> recs = pd.DataFrame({
    ...     "user_idx": [1, 1, 1],
    ...     "item_idx": [0, 0, 1],
    ...     "relevance": [5, 5, 5],
    ...     "timestamp": [1, 1, 1],
    ... })
    >>> metric = Unexpectedness(log)
    >>> round(metric(recs, 3), 2)
    0.67
    """

    _scala_udf_name = "getUnexpectednessMetricValue"

    def __init__(self, pred: DataFrameLike, use_scala_udf: bool = False):
        """
        :param pred: baseline model predictions the evaluated recommendations
            are compared against; must contain ``user_idx`` and ``item_idx``
            columns. Converted to a spark dataframe.
        :param use_scala_udf: flag stored as ``self._use_scala_udf``;
            presumably makes the base metric use the Scala UDF named in
            ``_scala_udf_name`` instead of the Python implementation.
        """
        self._use_scala_udf = use_scala_udf
        self.pred = convert2spark(pred)

    @staticmethod
    def _get_metric_value_by_user(k, *args) -> float:
        """
        Unexpectedness of a single user's recommendation list at depth ``k``.

        :param k: cut-off depth
        :param args: ``(pred, base_pred)`` -- the user's sorted recommended
            items and the user's sorted baseline items
        :return: ``1 - |set(pred[:k]) & set(base_pred[:k])| / k``,
            or ``0.0`` when the user has no recommendations
        """
        pred = args[0]
        base_pred = args[1]
        if len(pred) == 0:
            # float, to match the annotated return type
            return 0.0
        common_items = set(pred[:k]) & set(base_pred[:k])
        return 1.0 - len(common_items) / k

    def _get_enriched_recommendations(
        self,
        recommendations: SparkDataFrame,
        ground_truth: SparkDataFrame,  # noqa: ARG002
        max_k: int,
        ground_truth_users: Optional[DataFrameLike] = None,
    ) -> SparkDataFrame:
        """
        Collect per-user ``pred`` and ``base_pred`` item arrays.

        :param recommendations: recommendations to evaluate
        :param ground_truth: not used -- the metric compares against ``self.pred``
        :param max_k: largest cut-off; recommendations are trimmed to it
        :param ground_truth_users: optional users to restrict the result to
        :return: spark dataframe with ``user_idx``, ``pred`` and ``base_pred``;
            null ``pred`` values are filled with empty arrays
        """
        recommendations = convert2spark(recommendations)
        ground_truth_users = convert2spark(ground_truth_users)
        base_pred = self.pred

        # TODO: preprocess base_recs once in __init__
        base_recs = filter_sort(base_pred).withColumnRenamed("pred", "base_pred")

        # if there are duplicates in recommendations,
        # we will leave fewer than k recommendations after sort_udf
        recommendations = get_top_k_recs(recommendations, k=max_k)
        recommendations = filter_sort(recommendations)
        # right join: keep every user that has baseline recommendations
        recommendations = recommendations.join(base_recs, how="right", on=["user_idx"])

        if ground_truth_users is not None:
            recommendations = recommendations.join(ground_truth_users, on="user_idx", how="right")

        return fill_na_with_empty_array(recommendations, "pred", base_pred.schema["item_idx"].dataType)
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
from replay.experimental.models.admm_slim import ADMMSLIM
|
|
2
|
-
from replay.experimental.models.base_torch_rec import TorchRecommender
|
|
3
|
-
from replay.experimental.models.cql import CQL
|
|
4
|
-
from replay.experimental.models.ddpg import DDPG
|
|
5
|
-
from replay.experimental.models.dt4rec.dt4rec import DT4Rec
|
|
6
|
-
from replay.experimental.models.hierarchical_recommender import HierarchicalRecommender
|
|
7
|
-
from replay.experimental.models.implicit_wrap import ImplicitWrap
|
|
8
|
-
from replay.experimental.models.lightfm_wrap import LightFMWrap
|
|
9
|
-
from replay.experimental.models.mult_vae import MultVAE
|
|
10
|
-
from replay.experimental.models.neural_ts import NeuralTS
|
|
11
|
-
from replay.experimental.models.neuromf import NeuroMF
|
|
12
|
-
from replay.experimental.models.scala_als import ScalaALSWrap
|
|
13
|
-
from replay.experimental.models.u_lin_ucb import ULinUCB
|
|
@@ -1,205 +0,0 @@
|
|
|
1
|
-
from typing import Any, Dict, Optional, Tuple
|
|
2
|
-
|
|
3
|
-
import numba as nb
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pandas as pd
|
|
6
|
-
from scipy.sparse import coo_matrix, csr_matrix
|
|
7
|
-
|
|
8
|
-
from replay.experimental.models.base_neighbour_rec import NeighbourRec
|
|
9
|
-
from replay.experimental.utils.session_handler import State
|
|
10
|
-
from replay.models.extensions.ann.index_builders.base_index_builder import IndexBuilder
|
|
11
|
-
from replay.utils import SparkDataFrame
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
@nb.njit(parallel=True)
def _main_iteration(
    inv_matrix,
    p_x,
    mat_b,
    mat_c,
    mat_gamma,
    rho,
    eps_abs,
    eps_rel,
    lambda_1,
    items_count,
    threshold,
    multiplicator,
):  # pragma: no cover
    """
    Perform one ADMM SLIM iteration (numba-compiled).

    Updates B (a dense solution forced to a zero diagonal), then C (an
    element-wise soft-thresholded copy of ``B + Gamma / rho``), then the dual
    variable Gamma. It then computes the primal/dual residuals with their
    stopping tolerances and adapts ``rho``.

    :param inv_matrix: precomputed inverse used in the B-update
        (``ADMMSLIM._fit`` builds it as ``inv(X^T X + (lambda_2 + rho) I)``)
    :param p_x: ``inv_matrix @ X^T X``, as computed in ``ADMMSLIM._fit``
    :param mat_b: previous B; its value is not read, B is recomputed
    :param mat_c: previous C
    :param mat_gamma: previous dual variable; modified in place by ``+=``
    :param rho: current ADMM penalty parameter
    :param eps_abs: absolute stopping tolerance
    :param eps_rel: relative stopping tolerance
    :param lambda_1: l1 regularization strength (soft-threshold is ``lambda_1 / rho``)
    :param items_count: number of items; scales ``eps_abs``
    :param threshold: residual ratio above which ``rho`` is changed
    :param multiplicator: factor by which ``rho`` is multiplied or divided
    :return: ``(mat_b, mat_c, mat_gamma, rho, r_primal, r_dual, eps_primal, eps_dual)``
    """
    # calculate mat_b
    mat_b = p_x + np.dot(inv_matrix, rho * mat_c - mat_gamma)
    # scale column j of inv_matrix by diag(B)[j] / inv[j, j] and subtract it,
    # which makes every diagonal element of B exactly zero
    vec_gamma = np.diag(mat_b) / np.diag(inv_matrix)
    mat_b -= inv_matrix * vec_gamma

    # calculate mat_c
    prev_mat_c = mat_c
    mat_c = mat_b + mat_gamma / rho
    coef = lambda_1 / rho
    # element-wise soft-thresholding: shrink toward zero by coef, clip at zero
    mat_c = np.maximum(mat_c - coef, 0.0) - np.maximum(-mat_c - coef, 0.0)

    # calculate mat_gamma (dual update; mutates the caller's array in place)
    mat_gamma += rho * (mat_b - mat_c)

    # calculate residuals
    r_primal = np.linalg.norm(mat_b - mat_c)
    r_dual = np.linalg.norm(-rho * (mat_c - prev_mat_c))
    # presumably items_count stands for sqrt(items_count ** 2), the
    # Frobenius-norm scale of an items_count x items_count matrix
    eps_primal = eps_abs * items_count + eps_rel * max(np.linalg.norm(mat_b), np.linalg.norm(mat_c))
    eps_dual = eps_abs * items_count + eps_rel * np.linalg.norm(mat_gamma)
    # adaptive penalty: increase rho when the primal residual dominates,
    # decrease it when the dual residual dominates
    # NOTE(review): inv_matrix was built for the initial rho and is not
    # recomputed after rho changes, so later B-updates use a stale inverse
    if r_primal > threshold * r_dual:
        rho *= multiplicator
    elif threshold * r_primal < r_dual:
        rho /= multiplicator

    return (
        mat_b,
        mat_c,
        mat_gamma,
        rho,
        r_primal,
        r_dual,
        eps_primal,
        eps_dual,
    )
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
class ADMMSLIM(NeighbourRec):
    """`ADMM SLIM: Sparse Recommendations for Many Users
    <http://www.cs.columbia.edu/~jebara/papers/wsdm20_ADMM.pdf>`_

    This is a modification for the basic SLIM model.
    Recommendations are improved with Alternating Direction Method of Multipliers.
    """

    def _get_ann_infer_params(self) -> Dict[str, Any]:
        return {
            "features_col": None,
        }

    # current ADMM penalty; reset to lambda_2 at the start of every fit
    rho: float
    # residual ratio that triggers a rho change
    threshold: float = 5
    # factor by which rho is increased/decreased
    multiplicator: float = 2
    eps_abs: float = 1.0e-3
    eps_rel: float = 1.0e-3
    max_iteration: int = 100
    _mat_c: np.ndarray
    _mat_b: np.ndarray
    _mat_gamma: np.ndarray
    _search_space = {
        "lambda_1": {"type": "loguniform", "args": [1e-9, 50]},
        "lambda_2": {"type": "loguniform", "args": [1e-9, 5000]},
    }

    def __init__(
        self,
        lambda_1: float = 5,
        lambda_2: float = 5000,
        seed: Optional[int] = None,
        index_builder: Optional[IndexBuilder] = None,
    ):
        """
        :param lambda_1: l1 regularization term
        :param lambda_2: l2 regularization term
        :param seed: random seed
        :param index_builder: `IndexBuilder` instance that adds ANN functionality.
            If not set, then ann will not be used.
        :raises ValueError: if ``lambda_1 < 0`` or ``lambda_2 <= 0``
        """
        if lambda_1 < 0 or lambda_2 <= 0:
            msg = "Invalid regularization parameters"
            raise ValueError(msg)
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.rho = lambda_2
        self.seed = seed
        if isinstance(index_builder, (IndexBuilder, type(None))):
            self.index_builder = index_builder
        elif isinstance(index_builder, dict):
            self.init_builder_from_dict(index_builder)

    @property
    def _init_args(self):
        return {
            "lambda_1": self.lambda_1,
            "lambda_2": self.lambda_2,
            "seed": self.seed,
        }

    def _fit(
        self,
        log: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
    ) -> None:
        """
        Solve the ADMM SLIM problem and store the sparse item-item matrix C
        as ``self.similarity`` ``[item_idx_one, item_idx_two, similarity]``.

        :param log: interactions with ``user_idx``, ``item_idx``, ``relevance``
        :param user_features: unused
        :param item_features: unused
        """
        self.logger.debug("Fitting ADMM SLIM")
        # The ADMM loop adapts self.rho. Restart from the initial penalty so a
        # refit (or a fit after lambda_2 was changed, e.g. during tuning) does
        # not inherit the previous run's rho or build the inverse with it.
        self.rho = self.lambda_2
        pandas_log = log.select("user_idx", "item_idx", "relevance").toPandas()
        interactions_matrix = csr_matrix(
            (
                pandas_log["relevance"],
                (pandas_log["user_idx"], pandas_log["item_idx"]),
            ),
            shape=(self._user_dim, self._item_dim),
        )
        self.logger.debug("Gram matrix")
        xtx = (interactions_matrix.T @ interactions_matrix).toarray()
        self.logger.debug("Inverse matrix")
        inv_matrix = np.linalg.inv(xtx + (self.lambda_2 + self.rho) * np.eye(self._item_dim))
        self.logger.debug("Main calculations")
        p_x = inv_matrix @ xtx
        mat_b, mat_c, mat_gamma = self._init_matrix(self._item_dim)
        r_primal = np.linalg.norm(mat_b - mat_c)
        r_dual = np.linalg.norm(self.rho * mat_c)
        # zero tolerances make the loop run at least once
        # (unless the initial residuals are exactly zero)
        eps_primal, eps_dual = 0.0, 0.0
        iteration = 0
        while (r_primal > eps_primal or r_dual > eps_dual) and iteration < self.max_iteration:
            iteration += 1
            (
                mat_b,
                mat_c,
                mat_gamma,
                self.rho,
                r_primal,
                r_dual,
                eps_primal,
                eps_dual,
            ) = _main_iteration(
                inv_matrix,
                p_x,
                mat_b,
                mat_c,
                mat_gamma,
                self.rho,
                self.eps_abs,
                self.eps_rel,
                self.lambda_1,
                self._item_dim,
                self.threshold,
                self.multiplicator,
            )
            result_message = (
                f"Iteration: {iteration}. primal gap: "
                f"{r_primal - eps_primal:.5}; dual gap: "
                f" {r_dual - eps_dual:.5}; rho: {self.rho}"
            )
            self.logger.debug(result_message)

        # keep only the non-zero entries of C as item-item similarities
        mat_c_sparse = coo_matrix(mat_c)
        mat_c_pd = pd.DataFrame(
            {
                "item_idx_one": mat_c_sparse.row.astype(np.int32),
                "item_idx_two": mat_c_sparse.col.astype(np.int32),
                "similarity": mat_c_sparse.data,
            }
        )
        self.similarity = State().session.createDataFrame(
            mat_c_pd,
            schema="item_idx_one int, item_idx_two int, similarity double",
        )
        self.similarity.cache().count()

    def _init_matrix(self, size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Randomly initialize the B, C and Gamma matrices, each ``size x size``.

        With a seed, a private ``RandomState`` produces exactly the values that
        ``np.random.seed(seed)`` followed by ``np.random.rand`` would, without
        reseeding numpy's global generator as a side effect. Without a seed,
        the global generator is used.
        """
        rng = np.random if self.seed is None else np.random.RandomState(self.seed)
        mat_b = rng.rand(size, size)
        mat_c = rng.rand(size, size)
        mat_gamma = rng.rand(size, size)
        return mat_b, mat_c, mat_gamma
|
|
@@ -1,204 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
NeighbourRec - base class that requires log at prediction time.
|
|
3
|
-
Part of set of abstract classes (from base_rec.py)
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
|
-
from abc import ABC
|
|
7
|
-
from typing import Any, Dict, Iterable, Optional, Union
|
|
8
|
-
|
|
9
|
-
from replay.experimental.models.base_rec import Recommender
|
|
10
|
-
from replay.models.extensions.ann.ann_mixin import ANNMixin
|
|
11
|
-
from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
|
|
12
|
-
|
|
13
|
-
if PYSPARK_AVAILABLE:
|
|
14
|
-
from pyspark.sql import functions as sf
|
|
15
|
-
from pyspark.sql.column import Column
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class NeighbourRec(Recommender, ANNMixin, ABC):
    """Base class that requires log at prediction time"""

    # item-item scores, presumably ``[item_idx_one, item_idx_two, similarity]``
    # as produced by subclasses' ``_fit``; may be absent before fitting
    similarity: Optional[SparkDataFrame]
    can_predict_item_to_item: bool = True
    can_predict_cold_users: bool = True
    can_change_metric: bool = False
    item_to_item_metrics = ["similarity"]
    _similarity_metric = "similarity"

    @property
    def _dataframes(self):
        return {"similarity": self.similarity}

    def _clear_cache(self):
        """Unpersist the cached similarity dataframe, if there is one."""
        # ``similarity`` is declared Optional, so it may exist but be None
        similarity = getattr(self, "similarity", None)
        if similarity is not None:
            similarity.unpersist()

    @property
    def similarity_metric(self):
        return self._similarity_metric

    @similarity_metric.setter
    def similarity_metric(self, value):
        """
        :raises ValueError: if the class cannot change metrics or ``value``
            is not one of ``item_to_item_metrics``
        """
        if not self.can_change_metric:
            msg = "This class does not support changing similarity metrics"
            raise ValueError(msg)
        if value not in self.item_to_item_metrics:
            msg = f"Select one of the valid metrics for predict: {self.item_to_item_metrics}"
            raise ValueError(msg)
        self._similarity_metric = value

    def _predict_pairs_inner(
        self,
        log: SparkDataFrame,
        filter_df: SparkDataFrame,
        condition: Column,
        users: SparkDataFrame,
    ) -> SparkDataFrame:
        """
        Get recommendations for all provided users
        and filter results with ``filter_df`` by ``condition``.
        It allows to implement both ``predict_pairs`` and usual ``predict``@k.

        Relevance of an item for a user is the sum of the ``similarity_metric``
        column over all pairs (item from the user's log, candidate item).

        :param log: historical interactions, SparkDataFrame
            ``[user_idx, item_idx, timestamp, relevance]``.
        :param filter_df: SparkDataFrame used to filter items:
            ``[item_idx_filter]`` or ``[user_idx_filter, item_idx_filter]``.
        :param condition: condition used for inner join with ``filter_df``
        :param users: users to calculate recommendations for
        :return: SparkDataFrame ``[user_idx, item_idx, relevance]``
        :raises ValueError: if ``log`` is None
        """
        if log is None:
            msg = "log is not provided, but it is required for prediction"
            raise ValueError(msg)

        recs = (
            log.join(users, how="inner", on="user_idx")
            .join(
                self.similarity,
                how="inner",
                on=sf.col("item_idx") == sf.col("item_idx_one"),
            )
            .join(
                filter_df,
                how="inner",
                on=condition,
            )
            .groupby("user_idx", "item_idx_two")
            .agg(sf.sum(self.similarity_metric).alias("relevance"))
            .withColumnRenamed("item_idx_two", "item_idx")
        )
        return recs

    def _predict(
        self,
        log: SparkDataFrame,
        k: int,  # noqa: ARG002
        users: SparkDataFrame,
        items: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        filter_seen_items: bool = True,  # noqa: ARG002
    ) -> SparkDataFrame:
        """Score every item in ``items`` for every user in ``users``."""
        return self._predict_pairs_inner(
            log=log,
            filter_df=items.withColumnRenamed("item_idx", "item_idx_filter"),
            condition=sf.col("item_idx_two") == sf.col("item_idx_filter"),
            users=users,
        )

    def _predict_pairs(
        self,
        pairs: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
    ) -> SparkDataFrame:
        """
        Score only the given ``[user_idx, item_idx]`` pairs.

        :raises ValueError: if ``log`` is None
        """
        if log is None:
            msg = "log is not provided, but it is required for prediction"
            raise ValueError(msg)

        return self._predict_pairs_inner(
            log=log,
            filter_df=(
                pairs.withColumnRenamed("user_idx", "user_idx_filter").withColumnRenamed("item_idx", "item_idx_filter")
            ),
            condition=(sf.col("user_idx") == sf.col("user_idx_filter"))
            & (sf.col("item_idx_two") == sf.col("item_idx_filter")),
            users=pairs.select("user_idx").distinct(),
        )

    def get_nearest_items(
        self,
        items: Union[SparkDataFrame, Iterable],
        k: int,
        metric: Optional[str] = None,
        candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
    ) -> SparkDataFrame:
        """
        Get k most similar items by the `metric` for each of the `items`.

        :param items: spark dataframe or list of item ids to find neighbors
        :param k: number of neighbors
        :param metric: metric is not used to find neighbours in NeighbourRec,
            the parameter is ignored
        :param candidates: spark dataframe or list of items
            to consider as similar, e.g. popular/new items. If None,
            all items presented during model training are used.
        :return: dataframe with the most similar items and distance,
            where bigger value means greater similarity.
            spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
        """

        if metric is not None:
            self.logger.debug(
                "Metric is not used to determine nearest items in %s model",
                str(self),
            )

        return self._get_nearest_items_wrap(
            items=items,
            k=k,
            metric=metric,
            candidates=candidates,
        )

    def _get_nearest_items(
        self,
        items: SparkDataFrame,
        metric: Optional[str] = None,
        candidates: Optional[SparkDataFrame] = None,
    ) -> SparkDataFrame:
        """
        Select similarity rows whose ``item_idx_one`` is in ``items``
        (and whose ``item_idx_two`` is in ``candidates``, when given).

        :return: ``[item_idx_one, item_idx_two, <metric or "similarity">]``
        """
        similarity_filtered = self.similarity.join(
            items.withColumnRenamed("item_idx", "item_idx_one"),
            on="item_idx_one",
        )

        if candidates is not None:
            similarity_filtered = similarity_filtered.join(
                candidates.withColumnRenamed("item_idx", "item_idx_two"),
                on="item_idx_two",
            )

        return similarity_filtered.select(
            "item_idx_one",
            "item_idx_two",
            "similarity" if metric is None else metric,
        )

    def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
        """Set the ANN index ``items_count`` to ``max(item_idx) + 1`` of ``interactions``."""
        self.index_builder.index_params.items_count = interactions.select(sf.max("item_idx")).first()[0] + 1
        return {
            "features_col": None,
        }

    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
        """Return the similarity table as the data the ANN index is built from."""
        similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
        return similarity_df

    def _get_vectors_to_infer_ann_inner(
        self, interactions: SparkDataFrame, queries: SparkDataFrame  # noqa: ARG002
    ) -> SparkDataFrame:
        """
        Collect each user's interacted items and their relevances into arrays
        ``vector_items`` / ``vector_relevances``; ``queries`` is not used.
        """
        user_vectors = interactions.groupBy("user_idx").agg(
            sf.collect_list("item_idx").alias("vector_items"), sf.collect_list("relevance").alias("vector_relevances")
        )
        return user_vectors
|