replay-rec 0.20.3__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +62 -0
  4. replay/experimental/metrics/base_metric.py +603 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +50 -0
  18. replay/experimental/models/admm_slim.py +257 -0
  19. replay/experimental/models/base_neighbour_rec.py +200 -0
  20. replay/experimental/models/base_rec.py +1386 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +454 -0
  23. replay/experimental/models/ddpg.py +932 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +264 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/hierarchical_recommender.py +331 -0
  32. replay/experimental/models/implicit_wrap.py +131 -0
  33. replay/experimental/models/lightfm_wrap.py +303 -0
  34. replay/experimental/models/mult_vae.py +332 -0
  35. replay/experimental/models/neural_ts.py +986 -0
  36. replay/experimental/models/neuromf.py +406 -0
  37. replay/experimental/models/scala_als.py +293 -0
  38. replay/experimental/models/u_lin_ucb.py +115 -0
  39. replay/experimental/nn/data/__init__.py +1 -0
  40. replay/experimental/nn/data/schema_builder.py +102 -0
  41. replay/experimental/preprocessing/__init__.py +3 -0
  42. replay/experimental/preprocessing/data_preparator.py +839 -0
  43. replay/experimental/preprocessing/padder.py +229 -0
  44. replay/experimental/preprocessing/sequence_generator.py +208 -0
  45. replay/experimental/scenarios/__init__.py +1 -0
  46. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  47. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  48. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  49. replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
  50. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  51. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  52. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  53. replay/experimental/utils/__init__.py +0 -0
  54. replay/experimental/utils/logger.py +24 -0
  55. replay/experimental/utils/model_handler.py +186 -0
  56. replay/experimental/utils/session_handler.py +44 -0
  57. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
  58. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +61 -6
  59. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
  60. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
  61. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
replay/experimental/metrics/unexpectedness.py
@@ -0,0 +1,76 @@
+ from typing import Optional
+
+ from replay.utils import DataFrameLike, SparkDataFrame
+ from replay.utils.spark_utils import convert2spark, get_top_k_recs
+
+ from .base_metric import RecOnlyMetric, fill_na_with_empty_array, filter_sort
+
+
+ class Unexpectedness(RecOnlyMetric):
+     """
+     Fraction of recommended items that are not present in some baseline recommendations.
+
+     >>> import pandas as pd
+     >>> from replay.utils.session_handler import get_spark_session, State
+     >>> spark = get_spark_session(1, 1)
+     >>> state = State(spark)
+
+     >>> log = pd.DataFrame({
+     ...     "user_idx": [1, 1, 1],
+     ...     "item_idx": [1, 2, 3],
+     ...     "relevance": [5, 5, 5],
+     ...     "timestamp": [1, 1, 1],
+     ... })
+     >>> recs = pd.DataFrame({
+     ...     "user_idx": [1, 1, 1],
+     ...     "item_idx": [0, 0, 1],
+     ...     "relevance": [5, 5, 5],
+     ...     "timestamp": [1, 1, 1],
+     ... })
+     >>> metric = Unexpectedness(log)
+     >>> round(metric(recs, 3), 2)
+     0.67
+     """
+
+     _scala_udf_name = "getUnexpectednessMetricValue"
+
+     def __init__(self, pred: DataFrameLike, use_scala_udf: bool = False):
+         """
+         :param pred: model predictions
+         """
+         self._use_scala_udf = use_scala_udf
+         self.pred = convert2spark(pred)
+
+     @staticmethod
+     def _get_metric_value_by_user(k, *args) -> float:
+         pred = args[0]
+         base_pred = args[1]
+         if len(pred) == 0:
+             return 0
+         return 1.0 - len(set(pred[:k]) & set(base_pred[:k])) / k
+
+     def _get_enriched_recommendations(
+         self,
+         recommendations: SparkDataFrame,
+         ground_truth: SparkDataFrame,  # noqa: ARG002
+         max_k: int,
+         ground_truth_users: Optional[DataFrameLike] = None,
+     ) -> SparkDataFrame:
+         recommendations = convert2spark(recommendations)
+         ground_truth_users = convert2spark(ground_truth_users)
+         base_pred = self.pred
+
+         # TO DO: preprocess base_recs once in __init__
+
+         base_recs = filter_sort(base_pred).withColumnRenamed("pred", "base_pred")
+
+         # if there are duplicates in recommendations,
+         # we will leave fewer than k recommendations after sort_udf
+         recommendations = get_top_k_recs(recommendations, k=max_k)
+         recommendations = filter_sort(recommendations)
+         recommendations = recommendations.join(base_recs, how="right", on=["user_idx"])
+
+         if ground_truth_users is not None:
+             recommendations = recommendations.join(ground_truth_users, on="user_idx", how="right")
+
+         return fill_na_with_empty_array(recommendations, "pred", base_pred.schema["item_idx"].dataType)
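
For reference, the doctest value above can be reproduced in plain Python from the formula in `_get_metric_value_by_user` (a hand-check on the toy data, not part of the package):

pred = [0, 0, 1]       # recommendations for user 1 (recs in the doctest)
base_pred = [1, 2, 3]  # baseline predictions passed to Unexpectedness (log in the doctest)
k = 3
# fraction of the top-k recommendations that do not appear in the baseline top-k
value = 1.0 - len(set(pred[:k]) & set(base_pred[:k])) / k
print(round(value, 2))  # 0.67, since only item 1 overlaps with the baseline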
replay/experimental/models/__init__.py
@@ -0,0 +1,50 @@
+ from typing import Any
+
+ from replay.experimental.models.admm_slim import ADMMSLIM
+ from replay.experimental.models.base_torch_rec import TorchRecommender
+ from replay.experimental.models.cql import CQL
+ from replay.experimental.models.ddpg import DDPG
+ from replay.experimental.models.dt4rec.dt4rec import DT4Rec
+ from replay.experimental.models.hierarchical_recommender import HierarchicalRecommender
+ from replay.experimental.models.implicit_wrap import ImplicitWrap
+ from replay.experimental.models.mult_vae import MultVAE
+ from replay.experimental.models.neural_ts import NeuralTS
+ from replay.experimental.models.neuromf import NeuroMF
+ from replay.experimental.models.scala_als import ScalaALSWrap
+ from replay.experimental.models.u_lin_ucb import ULinUCB
+
+ __all__ = [
+     "ADMMSLIM",
+     "CQL",
+     "DDPG",
+     "DT4Rec",
+     "HierarchicalRecommender",
+     "ImplicitWrap",
+     "MultVAE",
+     "NeuralTS",
+     "NeuroMF",
+     "ScalaALSWrap",
+     "TorchRecommender",
+     "ULinUCB",
+ ]
+
+ CONDITIONAL_IMPORTS = {"LightFMWrap": "replay.experimental.models.lightfm_wrap"}
+
+
+ class ConditionalAccessError(Exception):
+     """Raised when trying to access conditional elements from the parent module instead of via a direct import."""
+
+
+ def __getattr__(name: str) -> Any:
+     if name in CONDITIONAL_IMPORTS:
+         msg = (
+             f"{name} relies on manual dependency installation and cannot be accessed via higher-level modules. "
+             f"If you wish to use this attribute, import it directly from {CONDITIONAL_IMPORTS[name]}"
+         )
+
+         raise ConditionalAccessError(msg)
+
+     if name in __all__:
+         return globals()[name]
+     msg = f"module {__name__!r} has no attribute {name!r}"
+     raise AttributeError(msg)
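
The `__getattr__` hook above makes the optional LightFM wrapper reachable only through a direct import. A minimal sketch of the resulting behaviour (assuming the package is installed; the direct import additionally requires the optional lightfm dependency):

import replay.experimental.models as exp_models
from replay.experimental.models import ConditionalAccessError

try:
    exp_models.LightFMWrap  # not re-exported, so access is resolved through __getattr__
except ConditionalAccessError as err:
    print(err)  # the message points to replay.experimental.models.lightfm_wrap

# The direct import works once the optional dependency is installed:
# from replay.experimental.models.lightfm_wrap import LightFMWrap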
replay/experimental/models/admm_slim.py
@@ -0,0 +1,257 @@
+ from collections.abc import Iterable
+ from typing import Any, Optional, Union
+
+ import numba as nb
+ import numpy as np
+ import pandas as pd
+ from scipy.sparse import coo_matrix, csr_matrix
+
+ from replay.experimental.models.base_neighbour_rec import NeighbourRec
+ from replay.experimental.models.base_rec import Recommender
+ from replay.experimental.utils.session_handler import State
+ from replay.models.extensions.ann.index_builders.base_index_builder import IndexBuilder
+ from replay.utils import SparkDataFrame
+ from replay.utils.spark_utils import get_top_k_recs, return_recs
+
+
+ @nb.njit(parallel=True)
+ def _main_iteration(
+     inv_matrix,
+     p_x,
+     mat_b,
+     mat_c,
+     mat_gamma,
+     rho,
+     eps_abs,
+     eps_rel,
+     lambda_1,
+     items_count,
+     threshold,
+     multiplicator,
+ ):  # pragma: no cover
+     # calculate mat_b
+     mat_b = p_x + np.dot(inv_matrix, rho * mat_c - mat_gamma)
+     vec_gamma = np.diag(mat_b) / np.diag(inv_matrix)
+     mat_b -= inv_matrix * vec_gamma
+
+     # calculate mat_c
+     prev_mat_c = mat_c
+     mat_c = mat_b + mat_gamma / rho
+     coef = lambda_1 / rho
+     mat_c = np.maximum(mat_c - coef, 0.0) - np.maximum(-mat_c - coef, 0.0)
+
+     # calculate mat_gamma
+     mat_gamma += rho * (mat_b - mat_c)
+
+     # calculate residuals
+     r_primal = np.linalg.norm(mat_b - mat_c)
+     r_dual = np.linalg.norm(-rho * (mat_c - prev_mat_c))
+     eps_primal = eps_abs * items_count + eps_rel * max(np.linalg.norm(mat_b), np.linalg.norm(mat_c))
+     eps_dual = eps_abs * items_count + eps_rel * np.linalg.norm(mat_gamma)
+     if r_primal > threshold * r_dual:
+         rho *= multiplicator
+     elif threshold * r_primal < r_dual:
+         rho /= multiplicator
+
+     return (
+         mat_b,
+         mat_c,
+         mat_gamma,
+         rho,
+         r_primal,
+         r_dual,
+         eps_primal,
+         eps_dual,
+     )
+
+
+ class ADMMSLIM(NeighbourRec):
+     """`ADMM SLIM: Sparse Recommendations for Many Users
+     <http://www.cs.columbia.edu/~jebara/papers/wsdm20_ADMM.pdf>`_
+
+     This is a modification of the basic SLIM model.
+     Recommendations are improved with the Alternating Direction Method of Multipliers.
+     """
+
+     def _get_ann_infer_params(self) -> dict[str, Any]:
+         return {
+             "features_col": None,
+         }
+
+     rho: float
+     threshold: float = 5
+     multiplicator: float = 2
+     eps_abs: float = 1.0e-3
+     eps_rel: float = 1.0e-3
+     max_iteration: int = 100
+     _mat_c: np.ndarray
+     _mat_b: np.ndarray
+     _mat_gamma: np.ndarray
+     _search_space = {
+         "lambda_1": {"type": "loguniform", "args": [1e-9, 50]},
+         "lambda_2": {"type": "loguniform", "args": [1e-9, 5000]},
+     }
+
+     def __init__(
+         self,
+         lambda_1: float = 5,
+         lambda_2: float = 5000,
+         seed: Optional[int] = None,
+         index_builder: Optional[IndexBuilder] = None,
+     ):
+         """
+         :param lambda_1: l1 regularization term
+         :param lambda_2: l2 regularization term
+         :param seed: random seed
+         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
+             If not set, ANN will not be used.
+         """
+         self.init_index_builder(index_builder)
+
+         if lambda_1 < 0 or lambda_2 <= 0:
+             msg = "Invalid regularization parameters"
+             raise ValueError(msg)
+         self.lambda_1 = lambda_1
+         self.lambda_2 = lambda_2
+         self.rho = lambda_2
+         self.seed = seed
+
+     @property
+     def _init_args(self):
+         return {
+             "lambda_1": self.lambda_1,
+             "lambda_2": self.lambda_2,
+             "seed": self.seed,
+         }
+
+     def fit(self, log: "SparkDataFrame", user_features=None, item_features=None) -> None:
+         """Wrapper that extends `_fit_wrap` and additionally builds the ANN index when enabled.
+
+         Args:
+             log: historical interactions
+                 ``[user_idx, item_idx, timestamp, relevance]``
+         """
+         Recommender._fit_wrap(self, log, user_features, item_features)
+
+         if self._use_ann:
+             vectors, ann_params = self._configure_index_builder(log)
+             self.index_builder.build_index(vectors, **ann_params)
+
+     def _fit(
+         self,
+         log: SparkDataFrame,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+     ) -> None:
+         self.logger.debug("Fitting ADMM SLIM")
+         pandas_log = log.select("user_idx", "item_idx", "relevance").toPandas()
+         interactions_matrix = csr_matrix(
+             (
+                 pandas_log["relevance"],
+                 (pandas_log["user_idx"], pandas_log["item_idx"]),
+             ),
+             shape=(self._user_dim, self._item_dim),
+         )
+         self.logger.debug("Gram matrix")
+         xtx = (interactions_matrix.T @ interactions_matrix).toarray()
+         self.logger.debug("Inverse matrix")
+         inv_matrix = np.linalg.inv(xtx + (self.lambda_2 + self.rho) * np.eye(self._item_dim))
+         self.logger.debug("Main calculations")
+         p_x = inv_matrix @ xtx
+         mat_b, mat_c, mat_gamma = self._init_matrix(self._item_dim)
+         r_primal = np.linalg.norm(mat_b - mat_c)
+         r_dual = np.linalg.norm(self.rho * mat_c)
+         eps_primal, eps_dual = 0.0, 0.0
+         iteration = 0
+         while (r_primal > eps_primal or r_dual > eps_dual) and iteration < self.max_iteration:
+             iteration += 1
+             (
+                 mat_b,
+                 mat_c,
+                 mat_gamma,
+                 self.rho,
+                 r_primal,
+                 r_dual,
+                 eps_primal,
+                 eps_dual,
+             ) = _main_iteration(
+                 inv_matrix,
+                 p_x,
+                 mat_b,
+                 mat_c,
+                 mat_gamma,
+                 self.rho,
+                 self.eps_abs,
+                 self.eps_rel,
+                 self.lambda_1,
+                 self._item_dim,
+                 self.threshold,
+                 self.multiplicator,
+             )
+             result_message = (
+                 f"Iteration: {iteration}. primal gap: "
+                 f"{r_primal - eps_primal:.5}; dual gap: "
+                 f" {r_dual - eps_dual:.5}; rho: {self.rho}"
+             )
+             self.logger.debug(result_message)
+
+         mat_c_sparse = coo_matrix(mat_c)
+         mat_c_pd = pd.DataFrame(
+             {
+                 "item_idx_one": mat_c_sparse.row.astype(np.int32),
+                 "item_idx_two": mat_c_sparse.col.astype(np.int32),
+                 "similarity": mat_c_sparse.data,
+             }
+         )
+         self.similarity = State().session.createDataFrame(
+             mat_c_pd,
+             schema="item_idx_one int, item_idx_two int, similarity double",
+         )
+         self.similarity.cache().count()
+
+     def _predict_wrap(
+         self,
+         log: SparkDataFrame,
+         k: int,
+         users: Optional[Union[SparkDataFrame, Iterable]] = None,
+         items: Optional[Union[SparkDataFrame, Iterable]] = None,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         filter_seen_items: bool = True,
+         recs_file_path: Optional[str] = None,
+     ) -> Optional[SparkDataFrame]:
+         log, users, items = self._filter_interactions_queries_items_dataframes(log, k, users, items)
+
+         if self._use_ann:
+             vectors = self._get_vectors_to_infer_ann(log, users, filter_seen_items)
+             ann_params = self._get_ann_infer_params()
+             inferer = self.index_builder.produce_inferer(filter_seen_items)
+             recs = inferer.infer(vectors, ann_params["features_col"], k)
+         else:
+             recs = self._predict(
+                 log,
+                 k,
+                 users,
+                 items,
+                 filter_seen_items,
+             )
+
+         if not self._use_ann:
+             if filter_seen_items and log:
+                 recs = self._filter_seen(recs=recs, log=log, users=users, k=k)
+
+             recs = get_top_k_recs(recs, k=k).select("user_idx", "item_idx", "relevance")
+
+         output = return_recs(recs, recs_file_path)
+         self._clear_model_temp_view("filter_seen_queries_interactions")
+         self._clear_model_temp_view("filter_seen_num_seen")
+         return output
+
+     def _init_matrix(self, size: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+         """Matrix initialization"""
+         if self.seed is not None:
+             np.random.seed(self.seed)
+         mat_b = np.random.rand(size, size)
+         mat_c = np.random.rand(size, size)
+         mat_gamma = np.random.rand(size, size)
+         return mat_b, mat_c, mat_gamma
1
+ """
2
+ NeighbourRec - base class that requires log at prediction time.
3
+ Part of set of abstract classes (from base_rec.py)
4
+ """
5
+
6
+ from abc import ABC
7
+ from collections.abc import Iterable
8
+ from typing import Any, Optional, Union
9
+
10
+ from replay.experimental.models.base_rec import Recommender
11
+ from replay.models.extensions.ann.ann_mixin import ANNMixin
12
+ from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
13
+
14
+ if PYSPARK_AVAILABLE:
15
+ from pyspark.sql import functions as sf
16
+ from pyspark.sql.column import Column
17
+
18
+
19
+ class NeighbourRec(ANNMixin, Recommender, ABC):
20
+ """Base class that requires log at prediction time"""
21
+
22
+ similarity: Optional[SparkDataFrame]
23
+ can_predict_item_to_item: bool = True
24
+ can_predict_cold_users: bool = True
25
+ can_change_metric: bool = False
26
+ item_to_item_metrics = ["similarity"]
27
+ _similarity_metric = "similarity"
28
+
29
+ @property
30
+ def _dataframes(self):
31
+ return {"similarity": self.similarity}
32
+
33
+ def _clear_cache(self):
34
+ if hasattr(self, "similarity"):
35
+ self.similarity.unpersist()
36
+
37
+ @property
38
+ def similarity_metric(self):
39
+ return self._similarity_metric
40
+
41
+ @similarity_metric.setter
42
+ def similarity_metric(self, value):
43
+ if not self.can_change_metric:
44
+ msg = "This class does not support changing similarity metrics"
45
+ raise ValueError(msg)
46
+ if value not in self.item_to_item_metrics:
47
+ msg = f"Select one of the valid metrics for predict: {self.item_to_item_metrics}"
48
+ raise ValueError(msg)
49
+ self._similarity_metric = value
50
+
51
+ def _predict_pairs_inner(
52
+ self,
53
+ log: SparkDataFrame,
54
+ filter_df: SparkDataFrame,
55
+ condition: Column,
56
+ users: SparkDataFrame,
57
+ ) -> SparkDataFrame:
58
+ """
59
+ Get recommendations for all provided users
60
+ and filter results with ``filter_df`` by ``condition``.
61
+ It allows to implement both ``predict_pairs`` and usual ``predict``@k.
62
+
63
+ :param log: historical interactions, SparkDataFrame
64
+ ``[user_idx, item_idx, timestamp, relevance]``.
65
+ :param filter_df: SparkDataFrame use to filter items:
66
+ ``[item_idx_filter]`` or ``[user_idx_filter, item_idx_filter]``.
67
+ :param condition: condition used for inner join with ``filter_df``
68
+ :param users: users to calculate recommendations for
69
+ :return: SparkDataFrame ``[user_idx, item_idx, relevance]``
70
+ """
71
+ if log is None:
72
+ msg = "log is not provided, but it is required for prediction"
73
+ raise ValueError(msg)
74
+
75
+ recs = (
76
+ log.join(users, how="inner", on="user_idx")
77
+ .join(
78
+ self.similarity,
79
+ how="inner",
80
+ on=sf.col("item_idx") == sf.col("item_idx_one"),
81
+ )
82
+ .join(
83
+ filter_df,
84
+ how="inner",
85
+ on=condition,
86
+ )
87
+ .groupby("user_idx", "item_idx_two")
88
+ .agg(sf.sum(self.similarity_metric).alias("relevance"))
89
+ .withColumnRenamed("item_idx_two", "item_idx")
90
+ )
91
+ return recs
92
+
93
+ def _predict(
94
+ self,
95
+ log: SparkDataFrame,
96
+ k: int, # noqa: ARG002
97
+ users: SparkDataFrame,
98
+ items: SparkDataFrame,
99
+ user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
100
+ item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
101
+ filter_seen_items: bool = True, # noqa: ARG002
102
+ ) -> SparkDataFrame:
103
+ return self._predict_pairs_inner(
104
+ log=log,
105
+ filter_df=items.withColumnRenamed("item_idx", "item_idx_filter"),
106
+ condition=sf.col("item_idx_two") == sf.col("item_idx_filter"),
107
+ users=users,
108
+ )
109
+
110
+ def _predict_pairs(
111
+ self,
112
+ pairs: SparkDataFrame,
113
+ log: Optional[SparkDataFrame] = None,
114
+ user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
115
+ item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
116
+ ) -> SparkDataFrame:
117
+ if log is None:
118
+ msg = "log is not provided, but it is required for prediction"
119
+ raise ValueError(msg)
120
+
121
+ return self._predict_pairs_inner(
122
+ log=log,
123
+ filter_df=(
124
+ pairs.withColumnRenamed("user_idx", "user_idx_filter").withColumnRenamed("item_idx", "item_idx_filter")
125
+ ),
126
+ condition=(sf.col("user_idx") == sf.col("user_idx_filter"))
127
+ & (sf.col("item_idx_two") == sf.col("item_idx_filter")),
128
+ users=pairs.select("user_idx").distinct(),
129
+ )
130
+
131
+ def get_nearest_items(
132
+ self,
133
+ items: Union[SparkDataFrame, Iterable],
134
+ k: int,
135
+ metric: Optional[str] = None,
136
+ candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
137
+ ) -> SparkDataFrame:
138
+ """
139
+ Get k most similar items be the `metric` for each of the `items`.
140
+
141
+ :param items: spark dataframe or list of item ids to find neighbors
142
+ :param k: number of neighbors
143
+ :param metric: metric is not used to find neighbours in NeighbourRec,
144
+ the parameter is ignored
145
+ :param candidates: spark dataframe or list of items
146
+ to consider as similar, e.g. popular/new items. If None,
147
+ all items presented during model training are used.
148
+ :return: dataframe with the most similar items an distance,
149
+ where bigger value means greater similarity.
150
+ spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
151
+ """
152
+
153
+ if metric is not None:
154
+ self.logger.debug(
155
+ "Metric is not used to determine nearest items in %s model",
156
+ str(self),
157
+ )
158
+
159
+ return self._get_nearest_items_wrap(
160
+ items=items,
161
+ k=k,
162
+ metric=metric,
163
+ candidates=candidates,
164
+ )
165
+
166
+ def _get_nearest_items(
167
+ self,
168
+ items: SparkDataFrame,
169
+ metric: Optional[str] = None,
170
+ candidates: Optional[SparkDataFrame] = None,
171
+ ) -> SparkDataFrame:
172
+ similarity_filtered = self.similarity.join(
173
+ items.withColumnRenamed("item_idx", "item_idx_one"),
174
+ on="item_idx_one",
175
+ )
176
+
177
+ if candidates is not None:
178
+ similarity_filtered = similarity_filtered.join(
179
+ candidates.withColumnRenamed("item_idx", "item_idx_two"),
180
+ on="item_idx_two",
181
+ )
182
+
183
+ return similarity_filtered.select(
184
+ "item_idx_one",
185
+ "item_idx_two",
186
+ "similarity" if metric is None else metric,
187
+ )
188
+
189
+ def _configure_index_builder(self, interactions: SparkDataFrame) -> dict[str, Any]:
190
+ similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
191
+ self.index_builder.index_params.items_count = interactions.select(sf.max("item_idx")).first()[0] + 1
192
+ return similarity_df, {"features_col": None}
193
+
194
+ def _get_vectors_to_infer_ann_inner(
195
+ self, interactions: SparkDataFrame, queries: SparkDataFrame # noqa: ARG002
196
+ ) -> SparkDataFrame:
197
+ user_vectors = interactions.groupBy("user_idx").agg(
198
+ sf.collect_list("item_idx").alias("vector_items"), sf.collect_list("relevance").alias("vector_relevances")
199
+ )
200
+ return user_vectors
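
The join-and-aggregate in `_predict_pairs_inner` realises the standard item-based neighbour score: a candidate item's relevance for a user is the sum of similarities between that candidate and the items already in the user's log. A small pandas sketch of the same aggregation on toy data (illustrative only, not part of the package):

import pandas as pd

log = pd.DataFrame({"user_idx": [0, 0], "item_idx": [1, 2]})
similarity = pd.DataFrame({
    "item_idx_one": [1, 1, 2],
    "item_idx_two": [3, 4, 3],
    "similarity": [0.9, 0.2, 0.5],
})

# join the user's items to the similarity rows, then sum per candidate item
scores = (
    log.merge(similarity, left_on="item_idx", right_on="item_idx_one")
    .groupby(["user_idx", "item_idx_two"], as_index=False)["similarity"]
    .sum()
    .rename(columns={"item_idx_two": "item_idx", "similarity": "relevance"})
)
print(scores)  # user 0: item 3 -> 1.4 (0.9 + 0.5), item 4 -> 0.2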