replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/experimental/metrics/unexpectedness.py (deleted)
@@ -1,74 +0,0 @@
- from typing import Optional
-
- from replay.utils import DataFrameLike, SparkDataFrame
- from replay.utils.spark_utils import convert2spark, get_top_k_recs
-
- from .base_metric import RecOnlyMetric, fill_na_with_empty_array, filter_sort
-
-
- # pylint: disable=too-few-public-methods
- class Unexpectedness(RecOnlyMetric):
-     """
-     Fraction of recommended items that are not present in some baseline recommendations.
-
-     >>> import pandas as pd
-     >>> from replay.utils.session_handler import get_spark_session, State
-     >>> spark = get_spark_session(1, 1)
-     >>> state = State(spark)
-
-     >>> log = pd.DataFrame({"user_idx": [1, 1, 1], "item_idx": [1, 2, 3], "relevance": [5, 5, 5], "timestamp": [1, 1, 1]})
-     >>> recs = pd.DataFrame({"user_idx": [1, 1, 1], "item_idx": [0, 0, 1], "relevance": [5, 5, 5], "timestamp": [1, 1, 1]})
-     >>> metric = Unexpectedness(log)
-     >>> round(metric(recs, 3), 2)
-     0.67
-     """
-
-     _scala_udf_name = "getUnexpectednessMetricValue"
-
-     def __init__(
-         self, pred: DataFrameLike,
-         use_scala_udf: bool = False
-     ):  # pylint: disable=super-init-not-called
-         """
-         :param pred: model predictions
-         """
-         self._use_scala_udf = use_scala_udf
-         self.pred = convert2spark(pred)
-
-     @staticmethod
-     def _get_metric_value_by_user(k, *args) -> float:
-         pred = args[0]
-         base_pred = args[1]
-         if len(pred) == 0:
-             return 0
-         return 1.0 - len(set(pred[:k]) & set(base_pred[:k])) / k
-
-     def _get_enriched_recommendations(
-         self,
-         recommendations: SparkDataFrame,
-         ground_truth: SparkDataFrame,
-         max_k: int,
-         ground_truth_users: Optional[DataFrameLike] = None,
-     ) -> SparkDataFrame:
-         recommendations = convert2spark(recommendations)
-         ground_truth_users = convert2spark(ground_truth_users)
-         base_pred = self.pred
-
-         # TO DO: preprocess base_recs once in __init__
-
-         base_recs = filter_sort(base_pred).withColumnRenamed("pred", "base_pred")
-
-         # if there are duplicates in recommendations,
-         # we will leave fewer than k recommendations after sort_udf
-         recommendations = get_top_k_recs(recommendations, k=max_k)
-         recommendations = filter_sort(recommendations)
-         recommendations = recommendations.join(base_recs, how="right", on=["user_idx"])
-
-         if ground_truth_users is not None:
-             recommendations = recommendations.join(
-                 ground_truth_users, on="user_idx", how="right"
-             )
-
-         return fill_na_with_empty_array(
-             recommendations, "pred", base_pred.schema["item_idx"].dataType
-         )
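For reference, the per-user computation that `_get_metric_value_by_user` performed is easy to restate outside Spark. A minimal sketch of the same logic on plain lists (illustrative names only; this helper is not part of either package version):

```python
# Minimal sketch of the removed per-user Unexpectedness logic
# (mirrors _get_metric_value_by_user above; plain Python, no Spark).
from typing import List


def unexpectedness_at_k(pred: List[int], base_pred: List[int], k: int) -> float:
    """Fraction of the top-k recommendations absent from the baseline top-k."""
    if len(pred) == 0:
        return 0.0
    return 1.0 - len(set(pred[:k]) & set(base_pred[:k])) / k


# Matches the doctest above: the baseline recommends [1, 2, 3], the model
# recommends [0, 0, 1]; only item 1 overlaps, so 1 - 1/3 ~= 0.67.
print(round(unexpectedness_at_k([0, 0, 1], [1, 2, 3], k=3), 2))  # 0.67
```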
replay/experimental/models/__init__.py (deleted)
@@ -1,10 +0,0 @@
- from replay.experimental.models.admm_slim import ADMMSLIM
- from replay.experimental.models.base_torch_rec import TorchRecommender
- from replay.experimental.models.ddpg import DDPG
- from replay.experimental.models.dt4rec.dt4rec import DT4Rec
- from replay.experimental.models.implicit_wrap import ImplicitWrap
- from replay.experimental.models.lightfm_wrap import LightFMWrap
- from replay.experimental.models.mult_vae import MultVAE
- from replay.experimental.models.neuromf import NeuroMF
- from replay.experimental.models.scala_als import ScalaALSWrap
- from replay.experimental.models.cql import CQL
replay/experimental/models/admm_slim.py (deleted)
@@ -1,216 +0,0 @@
- from typing import Any, Dict, Optional, Tuple
-
- import numba as nb
- import numpy as np
- import pandas as pd
- from scipy.sparse import coo_matrix, csr_matrix
-
- from replay.experimental.models.base_neighbour_rec import NeighbourRec
- from replay.experimental.utils.session_handler import State
- from replay.models.extensions.ann.index_builders.base_index_builder import IndexBuilder
- from replay.utils import SparkDataFrame
-
-
- # pylint: disable=too-many-arguments, too-many-locals
- @nb.njit(parallel=True)
- def _main_iteration(
-     inv_matrix,
-     p_x,
-     mat_b,
-     mat_c,
-     mat_gamma,
-     rho,
-     eps_abs,
-     eps_rel,
-     lambda_1,
-     items_count,
-     threshold,
-     multiplicator,
- ):  # pragma: no cover
-
-     # calculate mat_b
-     mat_b = p_x + np.dot(inv_matrix, rho * mat_c - mat_gamma)
-     vec_gamma = np.diag(mat_b) / np.diag(inv_matrix)
-     mat_b -= inv_matrix * vec_gamma
-
-     # calculate mat_c
-     prev_mat_c = mat_c
-     mat_c = mat_b + mat_gamma / rho
-     coef = lambda_1 / rho
-     mat_c = np.maximum(mat_c - coef, 0.0) - np.maximum(-mat_c - coef, 0.0)
-
-     # calculate mat_gamma
-     mat_gamma += rho * (mat_b - mat_c)
-
-     # calculate residuals
-     r_primal = np.linalg.norm(mat_b - mat_c)
-     r_dual = np.linalg.norm(-rho * (mat_c - prev_mat_c))
-     eps_primal = eps_abs * items_count + eps_rel * max(
-         np.linalg.norm(mat_b), np.linalg.norm(mat_c)
-     )
-     eps_dual = eps_abs * items_count + eps_rel * np.linalg.norm(mat_gamma)
-     if r_primal > threshold * r_dual:
-         rho *= multiplicator
-     elif threshold * r_primal < r_dual:
-         rho /= multiplicator
-
-     return (
-         mat_b,
-         mat_c,
-         mat_gamma,
-         rho,
-         r_primal,
-         r_dual,
-         eps_primal,
-         eps_dual,
-     )
-
-
- # pylint: disable=too-many-instance-attributes, too-many-ancestors
- class ADMMSLIM(NeighbourRec):
-     """`ADMM SLIM: Sparse Recommendations for Many Users
-     <http://www.cs.columbia.edu/~jebara/papers/wsdm20_ADMM.pdf>`_
-
-     This is a modification for the basic SLIM model.
-     Recommendations are improved with Alternating Direction Method of Multipliers.
-     """
-
-     def _get_ann_infer_params(self) -> Dict[str, Any]:
-         return {
-             "features_col": None,
-         }
-
-     rho: float
-     threshold: float = 5
-     multiplicator: float = 2
-     eps_abs: float = 1.0e-3
-     eps_rel: float = 1.0e-3
-     max_iteration: int = 100
-     _mat_c: np.ndarray
-     _mat_b: np.ndarray
-     _mat_gamma: np.ndarray
-     _search_space = {
-         "lambda_1": {"type": "loguniform", "args": [1e-9, 50]},
-         "lambda_2": {"type": "loguniform", "args": [1e-9, 5000]},
-     }
-
-     def __init__(
-         self,
-         lambda_1: float = 5,
-         lambda_2: float = 5000,
-         seed: Optional[int] = None,
-         index_builder: Optional[IndexBuilder] = None,
-     ):
-         """
-         :param lambda_1: l1 regularization term
-         :param lambda_2: l2 regularization term
-         :param seed: random seed
-         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
-             If not set, then ann will not be used.
-         """
-         if lambda_1 < 0 or lambda_2 <= 0:
-             raise ValueError("Invalid regularization parameters")
-         self.lambda_1 = lambda_1
-         self.lambda_2 = lambda_2
-         self.rho = lambda_2
-         self.seed = seed
-         if isinstance(index_builder, (IndexBuilder, type(None))):
-             self.index_builder = index_builder
-         elif isinstance(index_builder, dict):
-             self.init_builder_from_dict(index_builder)
-
-     @property
-     def _init_args(self):
-         return {
-             "lambda_1": self.lambda_1,
-             "lambda_2": self.lambda_2,
-             "seed": self.seed,
-         }
-
-     # pylint: disable=too-many-locals
-     def _fit(
-         self,
-         log: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-     ) -> None:
-         self.logger.debug("Fitting ADMM SLIM")
-         pandas_log = log.select("user_idx", "item_idx", "relevance").toPandas()
-         interactions_matrix = csr_matrix(
-             (
-                 pandas_log["relevance"],
-                 (pandas_log["user_idx"], pandas_log["item_idx"]),
-             ),
-             shape=(self._user_dim, self._item_dim),
-         )
-         self.logger.debug("Gram matrix")
-         xtx = (interactions_matrix.T @ interactions_matrix).toarray()
-         self.logger.debug("Inverse matrix")
-         inv_matrix = np.linalg.inv(
-             xtx + (self.lambda_2 + self.rho) * np.eye(self._item_dim)
-         )
-         self.logger.debug("Main calculations")
-         p_x = inv_matrix @ xtx
-         mat_b, mat_c, mat_gamma = self._init_matrix(self._item_dim)
-         r_primal = np.linalg.norm(mat_b - mat_c)
-         r_dual = np.linalg.norm(self.rho * mat_c)
-         eps_primal, eps_dual = 0.0, 0.0
-         iteration = 0
-         while (
-             r_primal > eps_primal or r_dual > eps_dual
-         ) and iteration < self.max_iteration:
-             iteration += 1
-             (
-                 mat_b,
-                 mat_c,
-                 mat_gamma,
-                 self.rho,
-                 r_primal,
-                 r_dual,
-                 eps_primal,
-                 eps_dual,
-             ) = _main_iteration(
-                 inv_matrix,
-                 p_x,
-                 mat_b,
-                 mat_c,
-                 mat_gamma,
-                 self.rho,
-                 self.eps_abs,
-                 self.eps_rel,
-                 self.lambda_1,
-                 self._item_dim,
-                 self.threshold,
-                 self.multiplicator,
-             )
-             result_message = (
-                 f"Iteration: {iteration}. primal gap: "
-                 f"{r_primal - eps_primal:.5}; dual gap: "
-                 f" {r_dual - eps_dual:.5}; rho: {self.rho}"
-             )
-             self.logger.debug(result_message)
-
-         mat_c_sparse = coo_matrix(mat_c)
-         mat_c_pd = pd.DataFrame(
-             {
-                 "item_idx_one": mat_c_sparse.row.astype(np.int32),
-                 "item_idx_two": mat_c_sparse.col.astype(np.int32),
-                 "similarity": mat_c_sparse.data,
-             }
-         )
-         self.similarity = State().session.createDataFrame(
-             mat_c_pd,
-             schema="item_idx_one int, item_idx_two int, similarity double",
-         )
-         self.similarity.cache().count()
-
-     def _init_matrix(
-         self, size: int
-     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-         """Matrix initialization"""
-         if self.seed is not None:
-             np.random.seed(self.seed)
-         mat_b = np.random.rand(size, size)  # type: ignore
-         mat_c = np.random.rand(size, size)  # type: ignore
-         mat_gamma = np.random.rand(size, size)  # type: ignore
-         return mat_b, mat_c, mat_gamma
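The heart of the removed solver is the C-update in `_main_iteration`: elementwise soft-thresholding, the proximal step for the l1 penalty, which is what makes the learned item-item matrix sparse. A standalone sketch of just that step with toy values (variable names mirror the code above but the snippet itself is illustrative, not package code):

```python
# Standalone sketch of the soft-thresholding C-update from _main_iteration:
# C = S_{lambda_1 / rho}(B + Gamma / rho), applied elementwise.
import numpy as np


def soft_threshold(mat: np.ndarray, coef: float) -> np.ndarray:
    """Shrink each entry toward zero by coef; entries within [-coef, coef] become 0."""
    return np.maximum(mat - coef, 0.0) - np.maximum(-mat - coef, 0.0)


mat_b = np.array([[0.9, -0.2], [0.05, -1.5]])
mat_gamma = np.zeros_like(mat_b)
rho, lambda_1 = 2.0, 1.0

mat_c = soft_threshold(mat_b + mat_gamma / rho, lambda_1 / rho)
print(mat_c)  # entries with |x| <= 0.5 are zeroed -> sparse item-item weights
```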
replay/experimental/models/base_neighbour_rec.py (deleted)
@@ -1,222 +0,0 @@
- # pylint: disable=too-many-lines
- """
- NeighbourRec - base class that requires log at prediction time.
- Part of set of abstract classes (from base_rec.py)
- """
-
- from abc import ABC
- from typing import Any, Dict, Iterable, Optional, Union
-
- from replay.experimental.models.base_rec import Recommender
- from replay.models.extensions.ann.ann_mixin import ANNMixin
- from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
-
- if PYSPARK_AVAILABLE:
-     from pyspark.sql import functions as sf
-     from pyspark.sql.column import Column
-
-
- class NeighbourRec(Recommender, ANNMixin, ABC):
-     """Base class that requires log at prediction time"""
-
-     similarity: Optional[SparkDataFrame]
-     can_predict_item_to_item: bool = True
-     can_predict_cold_users: bool = True
-     can_change_metric: bool = False
-     item_to_item_metrics = ["similarity"]
-     _similarity_metric = "similarity"
-
-     @property
-     def _dataframes(self):
-         return {"similarity": self.similarity}
-
-     def _clear_cache(self):
-         if hasattr(self, "similarity"):
-             self.similarity.unpersist()
-
-     # pylint: disable=missing-function-docstring
-     @property
-     def similarity_metric(self):
-         return self._similarity_metric
-
-     @similarity_metric.setter
-     def similarity_metric(self, value):
-         if not self.can_change_metric:
-             raise ValueError(
-                 "This class does not support changing similarity metrics"
-             )
-         if value not in self.item_to_item_metrics:
-             raise ValueError(
-                 f"Select one of the valid metrics for predict: "
-                 f"{self.item_to_item_metrics}"
-             )
-         self._similarity_metric = value
-
-     def _predict_pairs_inner(
-         self,
-         log: SparkDataFrame,
-         filter_df: SparkDataFrame,
-         condition: Column,
-         users: SparkDataFrame,
-     ) -> SparkDataFrame:
-         """
-         Get recommendations for all provided users
-         and filter results with ``filter_df`` by ``condition``.
-         This allows implementing both ``predict_pairs`` and the usual ``predict`` @k.
-
-         :param log: historical interactions, SparkDataFrame
-             ``[user_idx, item_idx, timestamp, relevance]``.
-         :param filter_df: SparkDataFrame used to filter items:
-             ``[item_idx_filter]`` or ``[user_idx_filter, item_idx_filter]``.
-         :param condition: condition used for inner join with ``filter_df``
-         :param users: users to calculate recommendations for
-         :return: SparkDataFrame ``[user_idx, item_idx, relevance]``
-         """
-         if log is None:
-             raise ValueError(
-                 "log is not provided, but it is required for prediction"
-             )
-
-         recs = (
-             log.join(users, how="inner", on="user_idx")
-             .join(
-                 self.similarity,
-                 how="inner",
-                 on=sf.col("item_idx") == sf.col("item_idx_one"),
-             )
-             .join(
-                 filter_df,
-                 how="inner",
-                 on=condition,
-             )
-             .groupby("user_idx", "item_idx_two")
-             .agg(sf.sum(self.similarity_metric).alias("relevance"))
-             .withColumnRenamed("item_idx_two", "item_idx")
-         )
-         return recs
-
-     # pylint: disable=too-many-arguments
-     def _predict(
-         self,
-         log: SparkDataFrame,
-         k: int,
-         users: SparkDataFrame,
-         items: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-         filter_seen_items: bool = True,
-     ) -> SparkDataFrame:
-
-         return self._predict_pairs_inner(
-             log=log,
-             filter_df=items.withColumnRenamed("item_idx", "item_idx_filter"),
-             condition=sf.col("item_idx_two") == sf.col("item_idx_filter"),
-             users=users,
-         )
-
-     def _predict_pairs(
-         self,
-         pairs: SparkDataFrame,
-         log: Optional[SparkDataFrame] = None,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-     ) -> SparkDataFrame:
-
-         if log is None:
-             raise ValueError(
-                 "log is not provided, but it is required for prediction"
-             )
-
-         return self._predict_pairs_inner(
-             log=log,
-             filter_df=(
-                 pairs.withColumnRenamed(
-                     "user_idx", "user_idx_filter"
-                 ).withColumnRenamed("item_idx", "item_idx_filter")
-             ),
-             condition=(sf.col("user_idx") == sf.col("user_idx_filter"))
-             & (sf.col("item_idx_two") == sf.col("item_idx_filter")),
-             users=pairs.select("user_idx").distinct(),
-         )
-
-     def get_nearest_items(
-         self,
-         items: Union[SparkDataFrame, Iterable],
-         k: int,
-         metric: Optional[str] = None,
-         candidates: Optional[Union[SparkDataFrame, Iterable]] = None,
-     ) -> SparkDataFrame:
-         """
-         Get k most similar items by the `metric` for each of the `items`.
-
-         :param items: spark dataframe or list of item ids to find neighbors
-         :param k: number of neighbors
-         :param metric: metric is not used to find neighbours in NeighbourRec,
-             the parameter is ignored
-         :param candidates: spark dataframe or list of items
-             to consider as similar, e.g. popular/new items. If None,
-             all items presented during model training are used.
-         :return: dataframe with the most similar items and distance,
-             where a bigger value means greater similarity.
-             spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
-         """
-
-         if metric is not None:
-             self.logger.debug(
-                 "Metric is not used to determine nearest items in %s model",
-                 str(self),
-             )
-
-         return self._get_nearest_items_wrap(
-             items=items,
-             k=k,
-             metric=metric,
-             candidates=candidates,
-         )
-
-     def _get_nearest_items(
-         self,
-         items: SparkDataFrame,
-         metric: Optional[str] = None,
-         candidates: Optional[SparkDataFrame] = None,
-     ) -> SparkDataFrame:
-
-         similarity_filtered = self.similarity.join(
-             items.withColumnRenamed("item_idx", "item_idx_one"),
-             on="item_idx_one",
-         )
-
-         if candidates is not None:
-             similarity_filtered = similarity_filtered.join(
-                 candidates.withColumnRenamed("item_idx", "item_idx_two"),
-                 on="item_idx_two",
-             )
-
-         return similarity_filtered.select(
-             "item_idx_one",
-             "item_idx_two",
-             "similarity" if metric is None else metric,
-         )
-
-     def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
-         self.index_builder.index_params.items_count = interactions.select(sf.max("item_idx")).first()[0] + 1
-         return {
-             "features_col": None,
-         }
-
-     def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:
-         similarity_df = self.similarity.select(
-             "similarity", "item_idx_one", "item_idx_two"
-         )
-         return similarity_df
-
-     def _get_vectors_to_infer_ann_inner(
-         self, interactions: SparkDataFrame, queries: SparkDataFrame
-     ) -> SparkDataFrame:
-
-         user_vectors = (
-             interactions.groupBy("user_idx").agg(
-                 sf.collect_list("item_idx").alias("vector_items"),
-                 sf.collect_list("relevance").alias("vector_relevances"))
-         )
-         return user_vectors
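The core scoring in `_predict_pairs_inner` is a join of the interaction log with the item-item similarity table followed by a per-(user, item) sum. A minimal pandas equivalent of that aggregation (illustrative only; column names follow the Spark code above):

```python
# Pandas sketch of the scoring performed by _predict_pairs_inner:
# relevance(user, j) = sum of similarity(i, j) over items i in the user's log.
import pandas as pd

log = pd.DataFrame({"user_idx": [1, 1, 2], "item_idx": [10, 11, 10]})
similarity = pd.DataFrame(
    {
        "item_idx_one": [10, 10, 11],
        "item_idx_two": [11, 12, 12],
        "similarity": [0.9, 0.4, 0.7],
    }
)

recs = (
    log.merge(similarity, left_on="item_idx", right_on="item_idx_one")
    .groupby(["user_idx", "item_idx_two"], as_index=False)["similarity"]
    .sum()
    .rename(columns={"item_idx_two": "item_idx", "similarity": "relevance"})
)
print(recs)
# user 1 gets item 12 with relevance 0.4 + 0.7 = 1.1 and item 11 with 0.9;
# user 2 gets items 11 (0.9) and 12 (0.4).
```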