replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/experimental/scenarios/obp_wrapper/replay_offline.py
@@ -1,271 +0,0 @@
- import numpy as np
- import pandas as pd
- import logging
-
- from dataclasses import dataclass
-
- from obp.policy.base import BaseOfflinePolicyLearner
- from pyspark.sql import DataFrame
-
- from replay.utils.spark_utils import convert2spark
- from replay.experimental.scenarios.obp_wrapper.obp_optuna_objective import OBPObjective
- from replay.experimental.scenarios.obp_wrapper.utils import split_bandit_feedback
- from replay.models.base_rec import BaseRecommender
- from replay.data import Dataset, FeatureHint, FeatureInfo, FeatureSchema, FeatureType
-
- from optuna import create_study
- from optuna.samplers import TPESampler
-
- from typing import (
-     Any,
-     Dict,
-     List,
-     Optional,
- )
-
-
- def obp2df(action: np.ndarray,
-            reward: np.ndarray,
-            timestamp: np.ndarray) -> Optional[pd.DataFrame]:
-     """
-     Converts OBP log to the pandas DataFrame
-     """
-
-     n_interactions = len(action)
-
-     df = pd.DataFrame({"user_idx": np.arange(n_interactions),
-                        "item_idx": action,
-                        "rating": reward,
-                        "timestamp": timestamp,
-                        })
-
-     return df
-
-
- def context2df(context: np.ndarray,
-                idx_col: np.ndarray,
-                idx_col_name: str) -> Optional[pd.DataFrame]:
-     """
-     Converts OBP log to the pandas DataFrame
-     """
-
-     df1 = pd.DataFrame({
-         idx_col_name + "_idx": idx_col
-     })
-     cols = [str(i) + "_" + idx_col_name for i in range(context.shape[1])]
-     df2 = pd.DataFrame(context, columns=cols)
-
-     return df1.join(df2)
-
-
- @dataclass
- class OBPOfflinePolicyLearner(BaseOfflinePolicyLearner):
-     """
-     Off-policy learner which wraps OBP data representation into replay format.
-
-     :param n_actions: Number of actions.
-
-     :param len_list: Length of a list of actions in a recommendation/ranking inferface,
-         slate size. When Open Bandit Dataset is used, 3 should be set.
-
-     :param replay_model: Any model from replay library with fit, predict functions.
-
-     :param dataset: Dataset of interactions (user_id, item_id, rating).
-         Constructing inside the fit method. Used for predict of replay_model.
-     """
-
-     replay_model: Optional[BaseRecommender] = None
-     log: Optional[DataFrame] = None
-     max_usr_id: int = 0
-     item_features: DataFrame = None
-     _study = None
-     _logger: Optional[logging.Logger] = None
-     _objective = OBPObjective
-
-     def __post_init__(self) -> None:
-         """Initialize Class."""
-         self.feature_schema = FeatureSchema(
-             [
-                 FeatureInfo(
-                     column="user_idx",
-                     feature_type=FeatureType.CATEGORICAL,
-                     feature_hint=FeatureHint.QUERY_ID,
-                 ),
-                 FeatureInfo(
-                     column="item_idx",
-                     feature_type=FeatureType.CATEGORICAL,
-                     feature_hint=FeatureHint.ITEM_ID,
-                 ),
-                 FeatureInfo(
-                     column="rating",
-                     feature_type=FeatureType.NUMERICAL,
-                     feature_hint=FeatureHint.RATING,
-                 ),
-                 FeatureInfo(
-                     column="timestamp",
-                     feature_type=FeatureType.NUMERICAL,
-                     feature_hint=FeatureHint.TIMESTAMP,
-                 ),
-             ]
-         )
-
-     @property
-     def logger(self) -> logging.Logger:
-         """
-         :return: get library logger
-         """
-         if self._logger is None:
-             self._logger = logging.getLogger("replay")
-         return self._logger
-
-     # pylint: disable=too-many-arguments, arguments-differ
-     def fit(self,
-             action: np.ndarray,
-             reward: np.ndarray,
-             timestamp: np.ndarray,
-             context: np.ndarray = None,
-             action_context: np.ndarray = None) -> None:
-         """
-         Fits an offline bandit policy on the given logged bandit data.
-         This `fit` method wraps bandit data and calls `fit` method for the replay_model.
-
-         :param action: Actions sampled by the logging/behavior policy
-             for each data in logged bandit data, i.e., :math:`a_i`.
-
-         :param reward: Rewards observed for each data in logged bandit data, i.e., :math:`r_i`.
-
-         :param timestamp: Moment of time when user interacted with corresponding item.
-
-         :param context: Context vectors observed for each data, i.e., :math:`x_i`.
-
-         :param action_context: Context vectors observed for each action.
-         """
-
-         log = convert2spark(obp2df(action, reward, timestamp))
-         self.log = log
-
-         user_features = None
-         self.max_usr_id = reward.shape[0]
-
-         if context is not None:
-             user_features = convert2spark(context2df(context,
-                                                      np.arange(context.shape[0]),
-                                                      "user"))
-
-         if action_context is not None:
-             self.item_features = convert2spark(context2df(action_context,
-                                                           np.arange(self.n_actions),
-                                                           "item"))
-
-         dataset = Dataset(feature_schema=self.feature_schema,
-                           interactions=log,
-                           query_features=user_features,
-                           item_features=self.item_features)
-         self.replay_model._fit_wrap(dataset)
-
-     # pylint: disable=arguments-renamed
-     def predict(self,
-                 n_rounds: int = 1,
-                 context: np.ndarray = None) -> np.ndarray:
-         '''Predict best actions for new data.
-         Action set predicted by this `predict` method can contain duplicate items.
-         If a non-repetitive action set is needed, please use the `sample_action` method.
-
-         :context: Context vectors for new data.
-
-         :return: Action choices made by a classifier, which can contain duplicate items.
-             If a non-repetitive action set is needed, please use the `sample_action` method.
-         '''
-
-         user_features = None
-         if context is not None:
-             user_features = convert2spark(context2df(context,
-                                                      np.arange(self.max_usr_id,
-                                                                self.max_usr_id + n_rounds),
-                                                      'user'))
-
-         users = convert2spark(pd.DataFrame({"user_idx": np.arange(self.max_usr_id,
-                                                                   self.max_usr_id + n_rounds)}))
-         items = convert2spark(pd.DataFrame({"item_idx": np.arange(self.n_actions)}))
-
-         self.max_usr_id += n_rounds
-
-         dataset = Dataset(feature_schema=self.feature_schema,
-                           interactions=self.log,
-                           query_features=user_features,
-                           item_features=self.item_features,
-                           check_consistency=False)
-
-         action_dist = self.replay_model._predict_proba(dataset,
-                                                        self.len_list,
-                                                        users,
-                                                        items,
-                                                        filter_seen_items=False)
-
-         return action_dist
-
-     # pylint: disable=too-many-arguments, too-many-locals, no-member
-     def optimize(
-         self,
-         bandit_feedback: Dict[str, np.ndarray],
-         val_size: float = 0.3,
-         param_borders: Optional[Dict[str, List[Any]]] = None,
-         criterion: str = "ipw",
-         budget: int = 10,
-         new_study: bool = True,
-     ) -> Optional[Dict[str, Any]]:
-         '''Optimize model parameters using optuna.
-         Optimization is carried out over the IPW/DR/DM scores(IPW by default).
-
-         :param bandit_feedback: Bandit log data with fields
-             ``[action, reward, context, action_context,
-             n_rounds, n_actions, position, pscore]`` as in OpenBanditPipeline.
-
-         :param val_size: Size of validation subset.
-
-         :param param_borders: Dictionary of parameter names with pair of borders
-             for the parameters optimization algorithm.
-
-         :param criterion: Score for optimization. Available are `ipw`, `dr` and `dm`.
-
-         :param budget: Number of trials for the optimization algorithm.
-
-         :param new_study: Flag to create new study or not for optuna.
-
-         :return: Dictionary of parameter names with optimal value of corresponding parameter.
-         '''
-
-         bandit_feedback_train,\
-             bandit_feedback_val = split_bandit_feedback(bandit_feedback, val_size)
-
-         if self.replay_model._search_space is None:
-             self.logger.warning(
-                 "%s has no hyper parameters to optimize", str(self)
-             )
-             return None
-
-         if self._study is None or new_study:
-             self._study = create_study(
-                 direction="maximize", sampler=TPESampler()
-             )
-
-         search_space = self.replay_model._prepare_param_borders(param_borders)
-         if (
-             self.replay_model._init_params_in_search_space(search_space)
-             and not self.replay_model._params_tried()
-         ):
-             self._study.enqueue_trial(self.replay_model._init_args)
-
-         objective = self._objective(
-             search_space=search_space,
-             bandit_feedback_train=bandit_feedback_train,
-             bandit_feedback_val=bandit_feedback_val,
-             learner=self,
-             criterion=criterion,
-             k=self.len_list
-         )
-
-         self._study.optimize(objective, budget)
-         best_params = self._study.best_params
-         self.replay_model.set_params(**best_params)
-         return best_params
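For orientation, here is a minimal usage sketch of the removed OBPOfflinePolicyLearner. It is a hypothetical example, not code from the package: it assumes obp's SyntheticBanditDataset as a feedback source and replay's UCB as the wrapped model, follows the docstrings above, and only runs against the 0.16.0rc0 wheel, where the module still exists.

    # Hypothetical sketch of the removed wrapper in use (replay 0.16.0rc0 + obp).
    import numpy as np
    from obp.dataset import SyntheticBanditDataset
    from replay.experimental.scenarios.obp_wrapper.replay_offline import OBPOfflinePolicyLearner
    from replay.models import UCB  # any replay model with fit/predict would do

    # Generate synthetic logged bandit feedback in OBP's standard dict format.
    dataset = SyntheticBanditDataset(n_actions=10, reward_type="binary")
    feedback = dataset.obtain_batch_bandit_feedback(n_rounds=1000)

    learner = OBPOfflinePolicyLearner(
        n_actions=feedback["n_actions"],
        len_list=3,  # slate size; 3 for Open Bandit Dataset per the docstring
        replay_model=UCB(),
    )
    learner.fit(
        action=feedback["action"],
        reward=feedback["reward"],
        timestamp=np.arange(feedback["n_rounds"]),
        context=feedback["context"],
        action_context=feedback["action_context"],
    )
    # Distribution over actions for 5 new rounds (may contain duplicates).
    action_dist = learner.predict(n_rounds=5, context=feedback["context"][:5])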
replay/experimental/scenarios/obp_wrapper/utils.py
@@ -1,88 +0,0 @@
- from sklearn.linear_model import LogisticRegression
- from obp.ope import RegressionModel
- import numpy as np
- from typing import Dict, List, Tuple
-
-
- def get_est_rewards_by_reg(n_actions, len_list, bandit_feedback_train, bandit_feedback_test):
-     """
-     Fit Logistic Regression to rewards from `bandit_feedback`.
-     """
-     regression_model = RegressionModel(
-         n_actions=n_actions,
-         len_list=len_list,
-         action_context=bandit_feedback_train["action_context"],
-         base_model=LogisticRegression(max_iter=1000, random_state=12345),
-     )
-
-     regression_model.fit(
-         context=bandit_feedback_train["context"],
-         action=bandit_feedback_train["action"],
-         reward=bandit_feedback_train["reward"],
-         position=bandit_feedback_train["position"],
-         pscore=bandit_feedback_train["pscore"],
-     )
-
-     estimated_rewards_by_reg_model = regression_model.predict(
-         context=bandit_feedback_test["context"],
-     )
-
-     return estimated_rewards_by_reg_model
-
-
- def bandit_subset(borders: List[int],
-                   bandit_feedback: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
-     """
-     This function returns subset of a `bandit_feedback`
-     with borders specified in `borders`.
-
-     :param bandit_feedback: Bandit log data with fields
-         ``[action, reward, context, action_context,
-         n_rounds, n_actions, position, pscore]``
-         as in OpenBanditPipeline.
-     :param borders: List with two values ``[left, right]``
-     :return: Returns subset of a `bandit_feedback` for each key with
-         indexes from `left`(including) to `right`(excluding).
-     """
-     assert len(borders) == 2
-
-     left, right = borders
-
-     assert left < right
-
-     position = None if bandit_feedback["position"] is None\
-         else bandit_feedback["position"][left:right]
-
-     return {
-         "n_rounds": right - left,
-         "n_actions": bandit_feedback["n_actions"],
-         "action": bandit_feedback["action"][left:right],
-         "position": position,
-         "reward": bandit_feedback["reward"][left:right],
-         "pscore": bandit_feedback["pscore"][left:right],
-         "context": bandit_feedback["context"][left:right],
-         "action_context": bandit_feedback["action_context"][left:right]
-     }
-
-
- def split_bandit_feedback(bandit_feedback: Dict[str, np.ndarray],
-                           val_size: int = 0.3) -> Tuple[Dict[str, np.ndarray],
-                                                         Dict[str, np.ndarray]]:
-     '''
-     Split `bandit_feedback` into two subsets.
-     :param bandit_feedback: Bandit log data with fields
-         ``[action, reward, context, action_context,
-         n_rounds, n_actions, position, pscore]``
-         as in OpenBanditPipeline.
-     :param val_size: Number in range ``[0, 1]`` corresponding to the proportion of
-         train/val split.
-     :return: `bandit_feedback_train` and `bandit_feedback_val` split.
-     '''
-
-     n_rounds = bandit_feedback["n_rounds"]
-     n_rounds_train = int(n_rounds * (1.0 - val_size))
-
-     bandit_feedback_train = bandit_subset([0, n_rounds_train], bandit_feedback)
-     bandit_feedback_val = bandit_subset([n_rounds_train, n_rounds], bandit_feedback)
-
-     return bandit_feedback_train, bandit_feedback_val
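A quick sanity check of the split arithmetic in split_bandit_feedback (a hedged sketch; `feedback` stands for any OBP-style feedback dict carrying the fields listed in the docstrings above, such as the one from the previous sketch):

    # The split is chronological: the first (1 - val_size) share of rounds
    # goes to train, the remainder to validation.
    train, val = split_bandit_feedback(feedback, val_size=0.3)

    # bandit_subset slices every per-round field, so sizes stay consistent.
    assert train["n_rounds"] + val["n_rounds"] == feedback["n_rounds"]
    assert train["action"].shape[0] == train["n_rounds"]
    assert train["n_actions"] == feedback["n_actions"]  # action space unchanged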
replay/experimental/scenarios/two_stages/reranker.py
@@ -1,116 +0,0 @@
- import logging
- from abc import abstractmethod
- from typing import Dict, Optional
-
- from lightautoml.automl.presets.tabular_presets import TabularAutoML
- from lightautoml.tasks import Task
- from pyspark.sql import DataFrame
-
- from replay.utils.spark_utils import convert2spark, get_top_k_recs
-
-
- class ReRanker:
-     """
-     Base class for models which re-rank recommendations produced by other models.
-     May be used as a part of two-stages recommendation pipeline.
-     """
-
-     _logger: Optional[logging.Logger] = None
-
-     @property
-     def logger(self) -> logging.Logger:
-         """
-         :returns: get library logger
-         """
-         if self._logger is None:
-             self._logger = logging.getLogger("replay")
-         return self._logger
-
-     @abstractmethod
-     def fit(self, data: DataFrame, fit_params: Optional[Dict] = None) -> None:
-         """
-         Fit the model which re-rank user-item pairs generated outside the models.
-
-         :param data: spark dataframe with obligatory ``[user_idx, item_idx, target]``
-             columns and features' columns
-         :param fit_params: dict of parameters to pass to model.fit()
-         """
-
-     @abstractmethod
-     def predict(self, data, k) -> DataFrame:
-         """
-         Re-rank data with the model and get top-k recommendations for each user.
-
-         :param data: spark dataframe with obligatory ``[user_idx, item_idx]``
-             columns and features' columns
-         :param k: number of recommendations for each user
-         """
-
-
- class LamaWrap(ReRanker):
-     """
-     LightAutoML TabularPipeline binary classification model wrapper for recommendations re-ranking.
-     Read more: https://github.com/sberbank-ai-lab/LightAutoML
-     """
-
-     def __init__(
-         self,
-         params: Optional[Dict] = None,
-         config_path: Optional[str] = None,
-     ):
-         """
-         Initialize LightAutoML TabularPipeline with passed params/configuration file.
-
-         :param params: dict of model parameters
-         :param config_path: path to configuration file
-         """
-         self.model = TabularAutoML(
-             task=Task("binary"),
-             config_path=config_path,
-             **(params if params is not None else {}),
-         )
-
-     def fit(self, data: DataFrame, fit_params: Optional[Dict] = None) -> None:
-         """
-         Fit the LightAutoML TabularPipeline model with binary classification task.
-         Data should include negative and positive user-item pairs.
-
-         :param data: spark dataframe with obligatory ``[user_idx, item_idx, target]``
-             columns and features' columns. `Target` column should consist of zeros and ones
-             as the model is a binary classification model.
-         :param fit_params: dict of parameters to pass to model.fit()
-             See LightAutoML TabularPipeline fit_predict parameters.
-         """
-
-         params = {"roles": {"target": "target"}, "verbose": 1}
-         params.update({} if fit_params is None else fit_params)
-         data = data.drop("user_idx", "item_idx")
-         data_pd = data.toPandas()
-         self.model.fit_predict(data_pd, **params)
-
-     def predict(self, data: DataFrame, k: int) -> DataFrame:
-         """
-         Re-rank data with the model and get top-k recommendations for each user.
-
-         :param data: spark dataframe with obligatory ``[user_idx, item_idx]``
-             columns and features' columns
-         :param k: number of recommendations for each user
-         :return: spark dataframe with top-k recommendations for each user
-             the dataframe columns are ``[user_idx, item_idx, relevance]``
-         """
-         data_pd = data.toPandas()
-         candidates_ids = data_pd[["user_idx", "item_idx"]]
-         data_pd.drop(columns=["user_idx", "item_idx"], inplace=True)
-         self.logger.info("Starting re-ranking")
-         candidates_pred = self.model.predict(data_pd)
-         candidates_ids.loc[:, "relevance"] = candidates_pred.data[:, 0]
-         self.logger.info(
-             "%s candidates rated for %s users",
-             candidates_ids.shape[0],
-             candidates_ids["user_idx"].nunique(),
-         )
-
-         self.logger.info("top-k")
-         return get_top_k_recs(
-             recs=convert2spark(candidates_ids), k=k,
-         )
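Finally, a hedged sketch of how the removed LamaWrap slotted into a two-stage pipeline. `train_candidates` and `test_candidates` are hypothetical Spark DataFrames laid out as the docstrings require; the import path exists only in the 0.16.0rc0 wheel.

    # Hypothetical two-stage re-ranking flow with the removed LamaWrap.
    from replay.experimental.scenarios.two_stages.reranker import LamaWrap

    # train_candidates: Spark df with user_idx, item_idx, target + feature columns
    # test_candidates:  Spark df with user_idx, item_idx + feature columns
    reranker = LamaWrap(params={"timeout": 600})  # kwargs forwarded to TabularAutoML
    reranker.fit(train_candidates)                # trains the binary classifier
    top_k = reranker.predict(test_candidates, k=10)
    # top_k: Spark df with columns [user_idx, item_idx, relevance]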