replay-rec 0.19.0__py3-none-any.whl → 0.19.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +62 -0
  4. replay/experimental/metrics/base_metric.py +602 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +13 -0
  18. replay/experimental/models/admm_slim.py +205 -0
  19. replay/experimental/models/base_neighbour_rec.py +204 -0
  20. replay/experimental/models/base_rec.py +1340 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +454 -0
  23. replay/experimental/models/ddpg.py +923 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +265 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/hierarchical_recommender.py +331 -0
  32. replay/experimental/models/implicit_wrap.py +131 -0
  33. replay/experimental/models/lightfm_wrap.py +302 -0
  34. replay/experimental/models/mult_vae.py +332 -0
  35. replay/experimental/models/neural_ts.py +986 -0
  36. replay/experimental/models/neuromf.py +406 -0
  37. replay/experimental/models/scala_als.py +296 -0
  38. replay/experimental/models/u_lin_ucb.py +115 -0
  39. replay/experimental/nn/data/__init__.py +1 -0
  40. replay/experimental/nn/data/schema_builder.py +102 -0
  41. replay/experimental/preprocessing/__init__.py +3 -0
  42. replay/experimental/preprocessing/data_preparator.py +839 -0
  43. replay/experimental/preprocessing/padder.py +229 -0
  44. replay/experimental/preprocessing/sequence_generator.py +208 -0
  45. replay/experimental/scenarios/__init__.py +1 -0
  46. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  47. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  48. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  49. replay/experimental/scenarios/obp_wrapper/utils.py +87 -0
  50. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  51. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  52. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  53. replay/experimental/utils/__init__.py +0 -0
  54. replay/experimental/utils/logger.py +24 -0
  55. replay/experimental/utils/model_handler.py +186 -0
  56. replay/experimental/utils/session_handler.py +44 -0
  57. {replay_rec-0.19.0.dist-info → replay_rec-0.19.0rc0.dist-info}/METADATA +11 -3
  58. replay_rec-0.19.0rc0.dist-info/NOTICE +41 -0
  59. {replay_rec-0.19.0.dist-info → replay_rec-0.19.0rc0.dist-info}/RECORD +61 -5
  60. {replay_rec-0.19.0.dist-info → replay_rec-0.19.0rc0.dist-info}/WHEEL +1 -1
  61. {replay_rec-0.19.0.dist-info → replay_rec-0.19.0rc0.dist-info}/LICENSE +0 -0
replay/experimental/metrics/unexpectedness.py
@@ -0,0 +1,76 @@
+ from typing import Optional
+
+ from replay.utils import DataFrameLike, SparkDataFrame
+ from replay.utils.spark_utils import convert2spark, get_top_k_recs
+
+ from .base_metric import RecOnlyMetric, fill_na_with_empty_array, filter_sort
+
+
+ class Unexpectedness(RecOnlyMetric):
+     """
+     Fraction of recommended items that are not present in some baseline recommendations.
+
+     >>> import pandas as pd
+     >>> from replay.utils.session_handler import get_spark_session, State
+     >>> spark = get_spark_session(1, 1)
+     >>> state = State(spark)
+
+     >>> log = pd.DataFrame({
+     ...     "user_idx": [1, 1, 1],
+     ...     "item_idx": [1, 2, 3],
+     ...     "relevance": [5, 5, 5],
+     ...     "timestamp": [1, 1, 1],
+     ... })
+     >>> recs = pd.DataFrame({
+     ...     "user_idx": [1, 1, 1],
+     ...     "item_idx": [0, 0, 1],
+     ...     "relevance": [5, 5, 5],
+     ...     "timestamp": [1, 1, 1],
+     ... })
+     >>> metric = Unexpectedness(log)
+     >>> round(metric(recs, 3), 2)
+     0.67
+     """
+
+     _scala_udf_name = "getUnexpectednessMetricValue"
+
+     def __init__(self, pred: DataFrameLike, use_scala_udf: bool = False):
+         """
+         :param pred: model predictions
+         """
+         self._use_scala_udf = use_scala_udf
+         self.pred = convert2spark(pred)
+
+     @staticmethod
+     def _get_metric_value_by_user(k, *args) -> float:
+         pred = args[0]
+         base_pred = args[1]
+         if len(pred) == 0:
+             return 0
+         return 1.0 - len(set(pred[:k]) & set(base_pred[:k])) / k
+
+     def _get_enriched_recommendations(
+         self,
+         recommendations: SparkDataFrame,
+         ground_truth: SparkDataFrame,  # noqa: ARG002
+         max_k: int,
+         ground_truth_users: Optional[DataFrameLike] = None,
+     ) -> SparkDataFrame:
+         recommendations = convert2spark(recommendations)
+         ground_truth_users = convert2spark(ground_truth_users)
+         base_pred = self.pred
+
+         # TO DO: preprocess base_recs once in __init__
+
+         base_recs = filter_sort(base_pred).withColumnRenamed("pred", "base_pred")
+
+         # if there are duplicates in recommendations,
+         # we will leave fewer than k recommendations after sort_udf
+         recommendations = get_top_k_recs(recommendations, k=max_k)
+         recommendations = filter_sort(recommendations)
+         recommendations = recommendations.join(base_recs, how="right", on=["user_idx"])
+
+         if ground_truth_users is not None:
+             recommendations = recommendations.join(ground_truth_users, on="user_idx", how="right")
+
+         return fill_na_with_empty_array(recommendations, "pred", base_pred.schema["item_idx"].dataType)
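
For orientation, the per-user value computed by `_get_metric_value_by_user` above can be reproduced with a few lines of plain Python. The item ids below reuse the doctest data; this is an illustrative sketch, not part of the package.

pred = [0, 0, 1]       # recommendations produced by the evaluated model for one user
base_pred = [1, 2, 3]  # baseline recommendations passed to Unexpectedness(...)
k = 3

# 1 - |top-k(pred) ∩ top-k(base_pred)| / k, exactly as in _get_metric_value_by_user
unexpectedness = 1.0 - len(set(pred[:k]) & set(base_pred[:k])) / k
print(round(unexpectedness, 2))  # 0.67, matching the doctest value
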
replay/experimental/models/__init__.py
@@ -0,0 +1,13 @@
+ from replay.experimental.models.admm_slim import ADMMSLIM
+ from replay.experimental.models.base_torch_rec import TorchRecommender
+ from replay.experimental.models.cql import CQL
+ from replay.experimental.models.ddpg import DDPG
+ from replay.experimental.models.dt4rec.dt4rec import DT4Rec
+ from replay.experimental.models.hierarchical_recommender import HierarchicalRecommender
+ from replay.experimental.models.implicit_wrap import ImplicitWrap
+ from replay.experimental.models.lightfm_wrap import LightFMWrap
+ from replay.experimental.models.mult_vae import MultVAE
+ from replay.experimental.models.neural_ts import NeuralTS
+ from replay.experimental.models.neuromf import NeuroMF
+ from replay.experimental.models.scala_als import ScalaALSWrap
+ from replay.experimental.models.u_lin_ucb import ULinUCB
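
Because of the re-exports above, the experimental models can be imported directly from the subpackage root. A minimal sketch (constructor arguments are the defaults shown in the admm_slim.py diff below):

from replay.experimental.models import ADMMSLIM

model = ADMMSLIM(lambda_1=5, lambda_2=5000, seed=42)
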
replay/experimental/models/admm_slim.py
@@ -0,0 +1,205 @@
+ from typing import Any, Dict, Optional, Tuple
+
+ import numba as nb
+ import numpy as np
+ import pandas as pd
+ from scipy.sparse import coo_matrix, csr_matrix
+
+ from replay.experimental.models.base_neighbour_rec import NeighbourRec
+ from replay.experimental.utils.session_handler import State
+ from replay.models.extensions.ann.index_builders.base_index_builder import IndexBuilder
+ from replay.utils import SparkDataFrame
+
+
+ @nb.njit(parallel=True)
+ def _main_iteration(
+     inv_matrix,
+     p_x,
+     mat_b,
+     mat_c,
+     mat_gamma,
+     rho,
+     eps_abs,
+     eps_rel,
+     lambda_1,
+     items_count,
+     threshold,
+     multiplicator,
+ ):  # pragma: no cover
+     # calculate mat_b
+     mat_b = p_x + np.dot(inv_matrix, rho * mat_c - mat_gamma)
+     vec_gamma = np.diag(mat_b) / np.diag(inv_matrix)
+     mat_b -= inv_matrix * vec_gamma
+
+     # calculate mat_c
+     prev_mat_c = mat_c
+     mat_c = mat_b + mat_gamma / rho
+     coef = lambda_1 / rho
+     mat_c = np.maximum(mat_c - coef, 0.0) - np.maximum(-mat_c - coef, 0.0)
+
+     # calculate mat_gamma
+     mat_gamma += rho * (mat_b - mat_c)
+
+     # calculate residuals
+     r_primal = np.linalg.norm(mat_b - mat_c)
+     r_dual = np.linalg.norm(-rho * (mat_c - prev_mat_c))
+     eps_primal = eps_abs * items_count + eps_rel * max(np.linalg.norm(mat_b), np.linalg.norm(mat_c))
+     eps_dual = eps_abs * items_count + eps_rel * np.linalg.norm(mat_gamma)
+     if r_primal > threshold * r_dual:
+         rho *= multiplicator
+     elif threshold * r_primal < r_dual:
+         rho /= multiplicator
+
+     return (
+         mat_b,
+         mat_c,
+         mat_gamma,
+         rho,
+         r_primal,
+         r_dual,
+         eps_primal,
+         eps_dual,
+     )
+
+
+ class ADMMSLIM(NeighbourRec):
+     """`ADMM SLIM: Sparse Recommendations for Many Users
+     <http://www.cs.columbia.edu/~jebara/papers/wsdm20_ADMM.pdf>`_
+
+     This is a modification for the basic SLIM model.
+     Recommendations are improved with Alternating Direction Method of Multipliers.
+     """
+
+     def _get_ann_infer_params(self) -> Dict[str, Any]:
+         return {
+             "features_col": None,
+         }
+
+     rho: float
+     threshold: float = 5
+     multiplicator: float = 2
+     eps_abs: float = 1.0e-3
+     eps_rel: float = 1.0e-3
+     max_iteration: int = 100
+     _mat_c: np.ndarray
+     _mat_b: np.ndarray
+     _mat_gamma: np.ndarray
+     _search_space = {
+         "lambda_1": {"type": "loguniform", "args": [1e-9, 50]},
+         "lambda_2": {"type": "loguniform", "args": [1e-9, 5000]},
+     }
+
+     def __init__(
+         self,
+         lambda_1: float = 5,
+         lambda_2: float = 5000,
+         seed: Optional[int] = None,
+         index_builder: Optional[IndexBuilder] = None,
+     ):
+         """
+         :param lambda_1: l1 regularization term
+         :param lambda_2: l2 regularization term
+         :param seed: random seed
+         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
+             If not set, then ann will not be used.
+         """
+         if lambda_1 < 0 or lambda_2 <= 0:
+             msg = "Invalid regularization parameters"
+             raise ValueError(msg)
+         self.lambda_1 = lambda_1
+         self.lambda_2 = lambda_2
+         self.rho = lambda_2
+         self.seed = seed
+         if isinstance(index_builder, (IndexBuilder, type(None))):
+             self.index_builder = index_builder
+         elif isinstance(index_builder, dict):
+             self.init_builder_from_dict(index_builder)
+
+     @property
+     def _init_args(self):
+         return {
+             "lambda_1": self.lambda_1,
+             "lambda_2": self.lambda_2,
+             "seed": self.seed,
+         }
+
+     def _fit(
+         self,
+         log: SparkDataFrame,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+     ) -> None:
+         self.logger.debug("Fitting ADMM SLIM")
+         pandas_log = log.select("user_idx", "item_idx", "relevance").toPandas()
+         interactions_matrix = csr_matrix(
+             (
+                 pandas_log["relevance"],
+                 (pandas_log["user_idx"], pandas_log["item_idx"]),
+             ),
+             shape=(self._user_dim, self._item_dim),
+         )
+         self.logger.debug("Gram matrix")
+         xtx = (interactions_matrix.T @ interactions_matrix).toarray()
+         self.logger.debug("Inverse matrix")
+         inv_matrix = np.linalg.inv(xtx + (self.lambda_2 + self.rho) * np.eye(self._item_dim))
+         self.logger.debug("Main calculations")
+         p_x = inv_matrix @ xtx
+         mat_b, mat_c, mat_gamma = self._init_matrix(self._item_dim)
+         r_primal = np.linalg.norm(mat_b - mat_c)
+         r_dual = np.linalg.norm(self.rho * mat_c)
+         eps_primal, eps_dual = 0.0, 0.0
+         iteration = 0
+         while (r_primal > eps_primal or r_dual > eps_dual) and iteration < self.max_iteration:
+             iteration += 1
+             (
+                 mat_b,
+                 mat_c,
+                 mat_gamma,
+                 self.rho,
+                 r_primal,
+                 r_dual,
+                 eps_primal,
+                 eps_dual,
+             ) = _main_iteration(
+                 inv_matrix,
+                 p_x,
+                 mat_b,
+                 mat_c,
+                 mat_gamma,
+                 self.rho,
+                 self.eps_abs,
+                 self.eps_rel,
+                 self.lambda_1,
+                 self._item_dim,
+                 self.threshold,
+                 self.multiplicator,
+             )
+             result_message = (
+                 f"Iteration: {iteration}. primal gap: "
+                 f"{r_primal - eps_primal:.5}; dual gap: "
+                 f" {r_dual - eps_dual:.5}; rho: {self.rho}"
+             )
+             self.logger.debug(result_message)
+
+         mat_c_sparse = coo_matrix(mat_c)
+         mat_c_pd = pd.DataFrame(
+             {
+                 "item_idx_one": mat_c_sparse.row.astype(np.int32),
+                 "item_idx_two": mat_c_sparse.col.astype(np.int32),
+                 "similarity": mat_c_sparse.data,
+             }
+         )
+         self.similarity = State().session.createDataFrame(
+             mat_c_pd,
+             schema="item_idx_one int, item_idx_two int, similarity double",
+         )
+         self.similarity.cache().count()
+
+     def _init_matrix(self, size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         """Matrix initialization"""
+         if self.seed is not None:
+             np.random.seed(self.seed)
+         mat_b = np.random.rand(size, size)
+         mat_c = np.random.rand(size, size)
+         mat_gamma = np.random.rand(size, size)
+         return mat_b, mat_c, mat_gamma
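
To make the solver easier to follow, here is a NumPy-only sketch of the update steps performed by `_main_iteration` above on a toy interaction matrix. The adaptive rho update and the primal/dual stopping criterion from `_fit` are omitted, so this is an illustration rather than a drop-in replacement.

import numpy as np

rng = np.random.default_rng(0)
n_items = 4
interactions = rng.integers(0, 2, size=(10, n_items)).astype(float)  # toy user-item matrix
lambda_1, lambda_2 = 5.0, 5000.0
rho = lambda_2

xtx = interactions.T @ interactions
inv_matrix = np.linalg.inv(xtx + (lambda_2 + rho) * np.eye(n_items))
p_x = inv_matrix @ xtx
mat_b = rng.random((n_items, n_items))
mat_c = rng.random((n_items, n_items))
mat_gamma = rng.random((n_items, n_items))

for _ in range(100):
    # B-update; the vec_gamma correction zeroes the diagonal of B
    mat_b = p_x + inv_matrix @ (rho * mat_c - mat_gamma)
    vec_gamma = np.diag(mat_b) / np.diag(inv_matrix)
    mat_b -= inv_matrix * vec_gamma
    # C-update: soft-thresholding (l1 proximal step)
    mat_c = mat_b + mat_gamma / rho
    coef = lambda_1 / rho
    mat_c = np.maximum(mat_c - coef, 0.0) - np.maximum(-mat_c - coef, 0.0)
    # dual variable update
    mat_gamma += rho * (mat_b - mat_c)

print("non-zero similarities:", np.count_nonzero(mat_c))
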
replay/experimental/models/base_neighbour_rec.py
@@ -0,0 +1,204 @@
+ """
+ NeighbourRec - base class that requires log at prediction time.
+ Part of set of abstract classes (from base_rec.py)
+ """
+
+ from abc import ABC
+ from typing import Any, Dict, Iterable, Optional, Union
+
+ from replay.experimental.models.base_rec import Recommender
+ from replay.models.extensions.ann.ann_mixin import ANNMixin
+ from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+
+ if PYSPARK_AVAILABLE:
+     from pyspark.sql import functions as sf
+     from pyspark.sql.column import Column
+
+
+ class NeighbourRec(Recommender, ANNMixin, ABC):
+     """Base class that requires log at prediction time"""
+
+     similarity: Optional[SparkDataFrame]
+     can_predict_item_to_item: bool = True
+     can_predict_cold_users: bool = True
+     can_change_metric: bool = False
+     item_to_item_metrics = ["similarity"]
+     _similarity_metric = "similarity"
+
+     @property
+     def _dataframes(self):
+         return {"similarity": self.similarity}
+
+     def _clear_cache(self):
+         if hasattr(self, "similarity"):
+             self.similarity.unpersist()
+
+     @property
+     def similarity_metric(self):
+         return self._similarity_metric
+
+     @similarity_metric.setter
+     def similarity_metric(self, value):
+         if not self.can_change_metric:
+             msg = "This class does not support changing similarity metrics"
+             raise ValueError(msg)
+         if value not in self.item_to_item_metrics:
+             msg = f"Select one of the valid metrics for predict: {self.item_to_item_metrics}"
+             raise ValueError(msg)
+         self._similarity_metric = value
+
+     def _predict_pairs_inner(
+         self,
+         log: SparkDataFrame,
+         filter_df: SparkDataFrame,
+         condition: Column,
+         users: SparkDataFrame,
+     ) -> SparkDataFrame:
+         """
+         Get recommendations for all provided users
+         and filter results with ``filter_df`` by ``condition``.
+         It allows to implement both ``predict_pairs`` and usual ``predict``@k.
+
+         :param log: historical interactions, SparkDataFrame
+             ``[user_idx, item_idx, timestamp, relevance]``.
+         :param filter_df: SparkDataFrame use to filter items:
+             ``[item_idx_filter]`` or ``[user_idx_filter, item_idx_filter]``.
+         :param condition: condition used for inner join with ``filter_df``
+         :param users: users to calculate recommendations for
+         :return: SparkDataFrame ``[user_idx, item_idx, relevance]``
+         """
+         if log is None:
+             msg = "log is not provided, but it is required for prediction"
+             raise ValueError(msg)
+
+         recs = (
+             log.join(users, how="inner", on="user_idx")
+             .join(
+                 self.similarity,
+                 how="inner",
+                 on=sf.col("item_idx") == sf.col("item_idx_one"),
+             )
+             .join(
+                 filter_df,
+                 how="inner",
+                 on=condition,
+             )
+             .groupby("user_idx", "item_idx_two")
+             .agg(sf.sum(self.similarity_metric).alias("relevance"))
+             .withColumnRenamed("item_idx_two", "item_idx")
+         )
+         return recs
+
+     def _predict(
+         self,
+         log: SparkDataFrame,
+         k: int,  # noqa: ARG002
+         users: SparkDataFrame,
+         items: SparkDataFrame,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         filter_seen_items: bool = True,  # noqa: ARG002
+     ) -> SparkDataFrame:
+         return self._predict_pairs_inner(
+             log=log,
+             filter_df=items.withColumnRenamed("item_idx", "item_idx_filter"),
+             condition=sf.col("item_idx_two") == sf.col("item_idx_filter"),
+             users=users,
+         )
+
+     def _predict_pairs(
+         self,
+         pairs: SparkDataFrame,
+         log: Optional[SparkDataFrame] = None,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+     ) -> SparkDataFrame:
+         if log is None:
+             msg = "log is not provided, but it is required for prediction"
+             raise ValueError(msg)
+
+         return self._predict_pairs_inner(
+             log=log,
+             filter_df=(
+                 pairs.withColumnRenamed("user_idx", "user_idx_filter").withColumnRenamed("item_idx", "item_idx_filter")
+             ),
+             condition=(sf.col("user_idx") == sf.col("user_idx_filter"))
+             & (sf.col("item_idx_two") == sf.col("item_idx_filter")),
+             users=pairs.select("user_idx").distinct(),
+         )
+
+     def get_nearest_items(
+         self,
+         items: Union[SparkDataFrame, Iterable],
+         k: int,
+         metric: Optional[str] = None,
+         candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
+     ) -> SparkDataFrame:
+         """
+         Get k most similar items be the `metric` for each of the `items`.
+
+         :param items: spark dataframe or list of item ids to find neighbors
+         :param k: number of neighbors
+         :param metric: metric is not used to find neighbours in NeighbourRec,
+             the parameter is ignored
+         :param candidates: spark dataframe or list of items
+             to consider as similar, e.g. popular/new items. If None,
+             all items presented during model training are used.
+         :return: dataframe with the most similar items an distance,
+             where bigger value means greater similarity.
+             spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
+         """
+
+         if metric is not None:
+             self.logger.debug(
+                 "Metric is not used to determine nearest items in %s model",
+                 str(self),
+             )
+
+         return self._get_nearest_items_wrap(
+             items=items,
+             k=k,
+             metric=metric,
+             candidates=candidates,
+         )
+
+     def _get_nearest_items(
+         self,
+         items: SparkDataFrame,
+         metric: Optional[str] = None,
+         candidates: Optional[SparkDataFrame] = None,
+     ) -> SparkDataFrame:
+         similarity_filtered = self.similarity.join(
+             items.withColumnRenamed("item_idx", "item_idx_one"),
+             on="item_idx_one",
+         )
+
+         if candidates is not None:
+             similarity_filtered = similarity_filtered.join(
+                 candidates.withColumnRenamed("item_idx", "item_idx_two"),
+                 on="item_idx_two",
+             )
+
+         return similarity_filtered.select(
+             "item_idx_one",
+             "item_idx_two",
+             "similarity" if metric is None else metric,
+         )
+
+     def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+         self.index_builder.index_params.items_count = interactions.select(sf.max("item_idx")).first()[0] + 1
+         return {
+             "features_col": None,
+         }
+
+     def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
+         similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
+         return similarity_df
+
+     def _get_vectors_to_infer_ann_inner(
+         self, interactions: SparkDataFrame, queries: SparkDataFrame  # noqa: ARG002
+     ) -> SparkDataFrame:
+         user_vectors = interactions.groupBy("user_idx").agg(
+             sf.collect_list("item_idx").alias("vector_items"), sf.collect_list("relevance").alias("vector_relevances")
+         )
+         return user_vectors
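
The scoring rule implemented by `_predict_pairs_inner` above is a sum of item-item similarities over each user's history. A small pandas sketch of the same aggregation (illustrative only; the class itself operates on Spark dataframes):

import pandas as pd

log = pd.DataFrame({"user_idx": [0, 0], "item_idx": [1, 2]})  # user 0 interacted with items 1 and 2
similarity = pd.DataFrame({
    "item_idx_one": [1, 1, 2],
    "item_idx_two": [3, 4, 3],
    "similarity": [0.5, 0.2, 0.4],
})

recs = (
    log.merge(similarity, left_on="item_idx", right_on="item_idx_one")
    .groupby(["user_idx", "item_idx_two"], as_index=False)["similarity"].sum()
    .rename(columns={"item_idx_two": "item_idx", "similarity": "relevance"})
)
print(recs)  # item 3 gets relevance 0.9 (= 0.5 + 0.4), item 4 gets 0.2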