replay-rec: replay_rec-0.16.0rc0-py3-none-any.whl → replay_rec-0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/experimental/models/implicit_wrap.py (deleted)
@@ -1,138 +0,0 @@
-from os.path import join
-from typing import Optional
-
-from replay.data import get_schema
-from replay.experimental.models.base_rec import Recommender
-from replay.preprocessing import CSRConverter
-from replay.utils import PandasDataFrame, SparkDataFrame
-from replay.utils.spark_utils import load_pickled_from_parquet, save_picklable_to_parquet
-
-
-class ImplicitWrap(Recommender):
-    """Wrapper for `implicit
-    <https://github.com/benfred/implicit>`_
-
-    Example:
-
-    >>> import implicit
-    >>> model = implicit.als.AlternatingLeastSquares(factors=5)
-    >>> als = ImplicitWrap(model)
-
-    This way you can use implicit models as any other in replay
-    with conversions made under the hood.
-
-    >>> import pandas as pd
-    >>> from replay.utils.spark_utils import convert2spark
-    >>> df = pd.DataFrame({"user_idx": [1, 1, 2, 2], "item_idx": [1, 2, 2, 3], "relevance": [1, 1, 1, 1]})
-    >>> df = convert2spark(df)
-    >>> als.fit_predict(df, 1, users=[1])[["user_idx", "item_idx"]].toPandas()
-       user_idx  item_idx
-    0         1         3
-    """
-
-    def __init__(self, model):
-        """Provide initialized ``implicit`` model."""
-        self.model = model
-        self.logger.info(
-            "The model is a wrapper of a non-distributed model which may affect performance"
-        )
-
-    @property
-    def _init_args(self):
-        return {"model": None}
-
-    def _save_model(self, path: str):
-        save_picklable_to_parquet(self.model, join(path, "model"))
-
-    def _load_model(self, path: str):
-        self.model = load_pickled_from_parquet(join(path, "model"))
-
-    def _fit(
-        self,
-        log: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> None:
-        matrix = CSRConverter(
-            first_dim_column="user_idx",
-            second_dim_column="item_idx",
-            data_column="relevance"
-        ).transform(log)
-        self.model.fit(matrix)
-
-    @staticmethod
-    def _pd_func(model, items_to_use=None, user_item_data=None, filter_seen_items=False):
-        def predict_by_user_item(pandas_df):
-            user = int(pandas_df["user_idx"].iloc[0])
-            items = items_to_use if items_to_use else pandas_df.item_idx.to_list()
-
-            items_res, rel = model.recommend(
-                userid=user,
-                user_items=user_item_data[user] if filter_seen_items else None,
-                N=len(items),
-                filter_already_liked_items=filter_seen_items,
-                items=items,
-            )
-            return PandasDataFrame(
-                {
-                    "user_idx": [user] * len(items_res),
-                    "item_idx": items_res,
-                    "relevance": rel,
-                }
-            )
-
-        return predict_by_user_item
-
-    # pylint: disable=too-many-arguments
-    def _predict(
-        self,
-        log: SparkDataFrame,
-        k: int,
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-        filter_seen_items: bool = True,
-    ) -> SparkDataFrame:
-
-        items_to_use = items.distinct().toPandas().item_idx.tolist()
-        user_item_data = CSRConverter(
-            first_dim_column="user_idx",
-            second_dim_column="item_idx",
-            data_column="relevance"
-        ).transform(log)
-        model = self.model
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return (
-            users.select("user_idx")
-            .groupby("user_idx")
-            .applyInPandas(self._pd_func(
-                model=model,
-                items_to_use=items_to_use,
-                user_item_data=user_item_data,
-                filter_seen_items=filter_seen_items), rec_schema)
-        )
-
-    def _predict_pairs(
-        self,
-        pairs: SparkDataFrame,
-        log: Optional[SparkDataFrame] = None,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> SparkDataFrame:
-
-        model = self.model
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return pairs.groupby("user_idx").applyInPandas(
-            self._pd_func(model=model, filter_seen_items=False),
-            rec_schema)
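For reference, the wrapper deleted above delegated the heavy lifting to the `implicit` library: it converted the Spark interaction log to a SciPy CSR matrix (via `CSRConverter`) and called `model.recommend` per user. Below is a minimal standalone sketch of that pattern, not replay's API: it uses plain pandas/SciPy, toy data, and assumes implicit >= 0.5.

import implicit
import pandas as pd
from scipy.sparse import csr_matrix

# Toy interaction log with the same columns the removed wrapper expected.
log = pd.DataFrame({"user_idx": [1, 1, 2, 2], "item_idx": [1, 2, 2, 3], "relevance": [1, 1, 1, 1]})

# Rough equivalent of CSRConverter(...).transform(log): users as rows, items as columns.
user_items = csr_matrix(
    (log["relevance"].astype(float), (log["user_idx"], log["item_idx"])),
    shape=(log["user_idx"].max() + 1, log["item_idx"].max() + 1),
)

model = implicit.als.AlternatingLeastSquares(factors=5)
model.fit(user_items)

# Recommend 2 items for user 1, filtering already-seen items (implicit >= 0.5 API).
item_ids, scores = model.recommend(
    userid=1, user_items=user_items[1], N=2, filter_already_liked_items=True
)
print(item_ids, scores)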
replay/experimental/models/lightfm_wrap.py (deleted)
@@ -1,327 +0,0 @@
-import os
-from os.path import join
-from typing import Optional, Tuple
-
-import numpy as np
-from lightfm import LightFM
-from scipy.sparse import csr_matrix, diags, hstack
-from sklearn.preprocessing import MinMaxScaler
-
-from replay.data import get_schema
-from replay.experimental.models.base_rec import HybridRecommender
-from replay.experimental.utils.session_handler import State
-from replay.preprocessing import CSRConverter
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
-from replay.utils.spark_utils import check_numeric, load_pickled_from_parquet, save_picklable_to_parquet
-
-if PYSPARK_AVAILABLE:
-    import pyspark.sql.functions as sf
-
-
-# pylint: disable=too-many-locals, too-many-instance-attributes
-class LightFMWrap(HybridRecommender):
-    """Wrapper for LightFM."""
-
-    epochs: int = 10
-    _search_space = {
-        "loss": {
-            "type": "categorical",
-            "args": ["logistic", "bpr", "warp", "warp-kos"],
-        },
-        "no_components": {"type": "loguniform_int", "args": [8, 512]},
-    }
-    user_feat_scaler: Optional[MinMaxScaler] = None
-    item_feat_scaler: Optional[MinMaxScaler] = None
-
-    def __init__(
-        self,
-        no_components: int = 128,
-        loss: str = "warp",
-        random_state: Optional[int] = None,
-    ):  # pylint: disable=too-many-arguments
-        np.random.seed(42)
-        self.no_components = no_components
-        self.loss = loss
-        self.random_state = random_state
-        cpu_count = os.cpu_count()
-        self.num_threads = cpu_count if cpu_count is not None else 1
-
-    @property
-    def _init_args(self):
-        return {
-            "no_components": self.no_components,
-            "loss": self.loss,
-            "random_state": self.random_state,
-        }
-
-    def _save_model(self, path: str):
-        save_picklable_to_parquet(self.model, join(path, "model"))
-        save_picklable_to_parquet(self.user_feat_scaler, join(path, "user_feat_scaler"))
-        save_picklable_to_parquet(self.item_feat_scaler, join(path, "item_feat_scaler"))
-
-    def _load_model(self, path: str):
-        self.model = load_pickled_from_parquet(join(path, "model"))
-        self.user_feat_scaler = load_pickled_from_parquet(join(path, "user_feat_scaler"))
-        self.item_feat_scaler = load_pickled_from_parquet(join(path, "item_feat_scaler"))
-
-    def _feature_table_to_csr(
-        self,
-        log_ids_list: SparkDataFrame,
-        feature_table: Optional[SparkDataFrame] = None,
-    ) -> Optional[csr_matrix]:
-        """
-        Transform features to sparse matrix
-        Matrix consists of two parts:
-        1) Left one is a ohe-hot encoding of user and item ids.
-        Matrix size is: number of users or items * number of user or items in fit.
-        Cold users and items are represented with empty strings
-        2) Right one is a numerical features, passed with feature_table.
-        MinMaxScaler is applied per column, and then value is divided by the row sum.
-
-        :param feature_table: dataframe with ``user_idx`` or ``item_idx``,
-            other columns are features.
-        :param log_ids_list: dataframe with ``user_idx`` or ``item_idx``,
-            containing unique ids from log.
-        :returns: feature matrix
-        """
-
-        if feature_table is None:
-            return None
-
-        check_numeric(feature_table)
-        log_ids_list = log_ids_list.distinct()
-        entity = "item" if "item_idx" in feature_table.columns else "user"
-        idx_col_name = f"{entity}_idx"
-
-        # filter features by log
-        feature_table = feature_table.join(
-            log_ids_list, on=idx_col_name, how="inner"
-        )
-
-        fit_dim = getattr(self, f"_{entity}_dim")
-        matrix_height = max(
-            fit_dim,
-            log_ids_list.select(sf.max(idx_col_name)).collect()[0][0] + 1,
-        )
-        if not feature_table.rdd.isEmpty():
-            matrix_height = max(
-                matrix_height,
-                feature_table.select(sf.max(idx_col_name)).collect()[0][0] + 1,
-            )
-
-        features_np = (
-            feature_table.select(
-                idx_col_name,
-                # first column contains id, next contain features
-                *(
-                    sorted(
-                        list(
-                            set(feature_table.columns).difference(
                                {idx_col_name}
-                            )
-                        )
-                    )
-                ),
-            )
-            .toPandas()
-            .to_numpy()
-        )
-        entities_ids = features_np[:, 0]
-        features_np = features_np[:, 1:]
-        number_of_features = features_np.shape[1]
-
-        all_ids_list = log_ids_list.toPandas().to_numpy().ravel()
-        entities_seen_in_fit = all_ids_list[all_ids_list < fit_dim]
-
-        entity_id_features = csr_matrix(
-            (
-                [1.0] * entities_seen_in_fit.shape[0],
-                (entities_seen_in_fit, entities_seen_in_fit),
-            ),
-            shape=(matrix_height, fit_dim),
-        )
-
-        scaler_name = f"{entity}_feat_scaler"
-        if getattr(self, scaler_name) is None:
-            if not features_np.size:
-                raise ValueError(f"features for {entity}s from log are absent")
-            setattr(self, scaler_name, MinMaxScaler().fit(features_np))
-
-        if features_np.size:
-            features_np = getattr(self, scaler_name).transform(features_np)
-            sparse_features = csr_matrix(
-                (
-                    features_np.ravel(),
-                    (
-                        np.repeat(entities_ids, number_of_features),
-                        np.tile(
-                            np.arange(number_of_features),
-                            entities_ids.shape[0],
-                        ),
-                    ),
-                ),
-                shape=(matrix_height, number_of_features),
-            )
-
-        else:
-            sparse_features = csr_matrix((matrix_height, number_of_features))
-
-        concat_features = hstack([entity_id_features, sparse_features])
-        concat_features_sum = concat_features.sum(axis=1).A.ravel()
-        mask = concat_features_sum != 0.0
-        concat_features_sum[mask] = 1.0 / concat_features_sum[mask]
-        return diags(concat_features_sum, format="csr") @ concat_features
-
-    def _fit(
-        self,
-        log: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> None:
-        self.user_feat_scaler = None
-        self.item_feat_scaler = None
-
-        interactions_matrix = CSRConverter(
-            first_dim_column="user_idx",
-            second_dim_column="item_idx",
-            data_column="relevance",
-            row_count=self._user_dim,
-            column_count=self._item_dim
-        ).transform(log)
-        csr_item_features = self._feature_table_to_csr(
-            log.select("item_idx").distinct(), item_features
-        )
-        csr_user_features = self._feature_table_to_csr(
-            log.select("user_idx").distinct(), user_features
-        )
-
-        if user_features is not None:
-            self.can_predict_cold_users = True
-        if item_features is not None:
-            self.can_predict_cold_items = True
-
-        self.model = LightFM(
-            loss=self.loss,
-            no_components=self.no_components,
-            random_state=self.random_state,
-        ).fit(
-            interactions=interactions_matrix,
-            epochs=self.epochs,
-            num_threads=self.num_threads,
-            item_features=csr_item_features,
-            user_features=csr_user_features,
-        )
-
-    def _predict_selected_pairs(
-        self,
-        pairs: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ):
-        def predict_by_user(pandas_df: PandasDataFrame) -> PandasDataFrame:
-            pandas_df["relevance"] = model.predict(
-                user_ids=pandas_df["user_idx"].to_numpy(),
-                item_ids=pandas_df["item_idx"].to_numpy(),
-                item_features=csr_item_features,
-                user_features=csr_user_features,
-            )
-            return pandas_df
-
-        model = self.model
-
-        if self.can_predict_cold_users and user_features is None:
-            raise ValueError("User features are missing for predict")
-        if self.can_predict_cold_items and item_features is None:
-            raise ValueError("Item features are missing for predict")
-
-        csr_item_features = self._feature_table_to_csr(
-            pairs.select("item_idx").distinct(), item_features
-        )
-        csr_user_features = self._feature_table_to_csr(
-            pairs.select("user_idx").distinct(), user_features
-        )
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return pairs.groupby("user_idx").applyInPandas(
-            predict_by_user, rec_schema
-        )
-
-    # pylint: disable=too-many-arguments
-    def _predict(
-        self,
-        log: SparkDataFrame,
-        k: int,
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-        filter_seen_items: bool = True,
-    ) -> SparkDataFrame:
-        return self._predict_selected_pairs(
-            users.crossJoin(items), user_features, item_features
-        )
-
-    def _predict_pairs(
-        self,
-        pairs: SparkDataFrame,
-        log: Optional[SparkDataFrame] = None,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> SparkDataFrame:
-        return self._predict_selected_pairs(
-            pairs, user_features, item_features
-        )
-
-    def _get_features(
-        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
-    ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
-        """
-        Get features from LightFM.
-        LightFM has methods get_item_representations/get_user_representations,
-        which accept object matrix and return features.
-
-        :param ids: id item_idx/user_idx to get features for
-        :param features: features for item_idx/user_idx
-        :return: spark-dataframe with biases and vectors for users/items and vector size
-        """
-        entity = "item" if "item_idx" in ids.columns else "user"
-        ids_list = ids.toPandas()[f"{entity}_idx"]
-
-        # models without features use sparse matrix
-        if features is None:
-            matrix_width = getattr(self, f"fit_{entity}s").count()
-            warm_ids = ids_list[ids_list < matrix_width]
-            sparse_features = csr_matrix(
-                (
-                    [1] * warm_ids.shape[0],
-                    (warm_ids, warm_ids),
-                ),
-                shape=(ids_list.max() + 1, matrix_width),
-            )
-        else:
-            sparse_features = self._feature_table_to_csr(ids, features)
-
-        biases, vectors = getattr(self.model, f"get_{entity}_representations")(
-            sparse_features
-        )
-
-        embed_list = list(
-            zip(
-                ids_list,
-                biases[ids_list].tolist(),
-                vectors[ids_list].tolist(),
-            )
-        )
-        lightfm_factors = State().session.createDataFrame(
-            embed_list,
-            schema=[
-                f"{entity}_idx",
-                f"{entity}_bias",
-                f"{entity}_factors",
-            ],
-        )
-        return lightfm_factors, self.model.no_components
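The removed `_feature_table_to_csr` docstring above describes the feature matrix as a one-hot block of ids horizontally stacked with MinMax-scaled numeric features, each row then rescaled so that its entries sum to one. The following is a small NumPy/SciPy sketch of just that construction; the ids and feature values are toy data, not replay's API.

import numpy as np
from scipy.sparse import csr_matrix, diags, hstack
from sklearn.preprocessing import MinMaxScaler

n_entities = 3                      # users or items seen during fit
ids = np.array([0, 1, 2])           # entity ids present in the log
features = np.array([[10.0, 1.0],   # two numeric features per entity
                     [20.0, 3.0],
                     [30.0, 5.0]])

# Left block: one-hot encoding of the ids.
id_block = csr_matrix((np.ones(len(ids)), (ids, ids)), shape=(n_entities, n_entities))

# Right block: per-column MinMax scaling of the numeric features.
feat_block = csr_matrix(MinMaxScaler().fit_transform(features))

# Concatenate and normalise each row by its sum, as the wrapper did.
concat = hstack([id_block, feat_block]).tocsr()
row_sums = np.asarray(concat.sum(axis=1)).ravel()
row_sums[row_sums != 0] = 1.0 / row_sums[row_sums != 0]
normalised = diags(row_sums, format="csr") @ concat
print(normalised.toarray())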