replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in the supported public registries. It is provided for informational purposes only.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -61
- replay/experimental/metrics/base_metric.py +0 -661
- replay/experimental/metrics/coverage.py +0 -117
- replay/experimental/metrics/experiment.py +0 -200
- replay/experimental/metrics/hitrate.py +0 -27
- replay/experimental/metrics/map.py +0 -31
- replay/experimental/metrics/mrr.py +0 -19
- replay/experimental/metrics/ncis_precision.py +0 -32
- replay/experimental/metrics/ndcg.py +0 -50
- replay/experimental/metrics/precision.py +0 -23
- replay/experimental/metrics/recall.py +0 -26
- replay/experimental/metrics/rocauc.py +0 -50
- replay/experimental/metrics/surprisal.py +0 -102
- replay/experimental/metrics/unexpectedness.py +0 -74
- replay/experimental/models/__init__.py +0 -10
- replay/experimental/models/admm_slim.py +0 -216
- replay/experimental/models/base_neighbour_rec.py +0 -222
- replay/experimental/models/base_rec.py +0 -1361
- replay/experimental/models/base_torch_rec.py +0 -247
- replay/experimental/models/cql.py +0 -468
- replay/experimental/models/ddpg.py +0 -1007
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -193
- replay/experimental/models/dt4rec/gpt1.py +0 -411
- replay/experimental/models/dt4rec/trainer.py +0 -128
- replay/experimental/models/dt4rec/utils.py +0 -274
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
- replay/experimental/models/implicit_wrap.py +0 -138
- replay/experimental/models/lightfm_wrap.py +0 -327
- replay/experimental/models/mult_vae.py +0 -374
- replay/experimental/models/neuromf.py +0 -462
- replay/experimental/models/scala_als.py +0 -311
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -58
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -929
- replay/experimental/preprocessing/padder.py +0 -231
- replay/experimental/preprocessing/sequence_generator.py +0 -218
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
- replay/experimental/scenarios/two_stages/reranker.py +0 -116
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -213
- replay/experimental/utils/session_handler.py +0 -47
- replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
- replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/models/base_rec.py
CHANGED
@@ -1,4 +1,3 @@
-# pylint: disable=too-many-lines
 """
 Base abstract classes:
 - BaseRecommender - the simplest base class
@@ -19,8 +18,8 @@ from copy import deepcopy
 from os.path import join
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
 
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy.random import default_rng
 from optuna import create_study
 from optuna.samplers import TPESampler
@@ -33,8 +32,10 @@ from replay.utils.session_handler import State
 from replay.utils.spark_utils import SparkCollectToMasterWarning
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import
-
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )
 
     from replay.utils.spark_utils import (
         cache_temp_view,
@@ -53,7 +54,6 @@ if PYSPARK_AVAILABLE:
     )
 
 
-# pylint: disable=too-few-public-methods
 class IsSavable(ABC):
     """
     Common methods and attributes for saving and loading RePlay models
@@ -133,7 +133,7 @@ class RecommenderCommons:
         Create Spark SQL temporary view for df, cache it and add temp view name to self.cached_dfs.
         Temp view name is : "id_<python object id>_model_<RePlay model name>_<df_name>"
         """
-        full_name = f"id_{id(self)}_model_{
+        full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
         cache_temp_view(df, full_name)
 
         if self.cached_dfs is None:
@@ -146,22 +146,19 @@ class RecommenderCommons:
         Temp view to replace will be constructed as
         "id_<python object id>_model_<RePlay model name>_<df_name>"
         """
-        full_name = f"id_{id(self)}_model_{
+        full_name = f"id_{id(self)}_model_{self!s}_{df_name}"
         drop_temp_view(full_name)
         if self.cached_dfs is not None:
             self.cached_dfs.discard(full_name)
 
 
-# pylint: disable=too-many-instance-attributes
 class BaseRecommender(RecommenderCommons, IsSavable, ABC):
     """Base recommender"""
 
     model: Any
     can_predict_cold_queries: bool = False
     can_predict_cold_items: bool = False
-    _search_space: Optional[
-        Dict[str, Union[str, Sequence[Union[str, int, float]]]]
-    ] = None
+    _search_space: Optional[Dict[str, Union[str, Sequence[Union[str, int, float]]]]] = None
     _objective = MainObjective
     study = None
    criterion = None
@@ -172,7 +169,6 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
     _query_dim_size: int
     _item_dim_size: int
 
-    # pylint: disable=too-many-arguments, too-many-locals, no-member
     def optimize(
         self,
         train_dataset: Dataset,
@@ -211,21 +207,14 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         )
 
         if self._search_space is None:
-            self.logger.warning(
-                "%s has no hyper parameters to optimize", str(self)
-            )
+            self.logger.warning("%s has no hyper parameters to optimize", str(self))
             return None
 
         if self.study is None or new_study:
-            self.study = create_study(
-                direction="maximize", sampler=TPESampler()
-            )
+            self.study = create_study(direction="maximize", sampler=TPESampler())
 
         search_space = self._prepare_param_borders(param_borders)
-        if (
-            self._init_params_in_search_space(search_space)
-            and not self._params_tried()
-        ):
+        if self._init_params_in_search_space(search_space) and not self._params_tried():
             self.study.enqueue_trial(self._init_args)
 
         split_data = self._prepare_split_data(train_dataset, test_dataset)
@@ -244,7 +233,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 
     def _init_params_in_search_space(self, search_space):
         """Check if model params are inside search space"""
-        params = self._init_args
+        params = self._init_args
         outside_search_space = {}
         for param, value in params.items():
             if param not in search_space:
@@ -252,12 +241,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
             borders = search_space[param]["args"]
             param_type = search_space[param]["type"]
 
-            extra_category =
-
-            )
-            param_out_of_bounds = param_type != "categorical" and (
-                value < borders[0] or value > borders[1]
-            )
+            extra_category = param_type == "categorical" and value not in borders
+            param_out_of_bounds = param_type != "categorical" and (value < borders[0] or value > borders[1])
             if extra_category or param_out_of_bounds:
                 outside_search_space[param] = {
                     "borders": borders,
@@ -299,11 +284,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         # If used didn't specify some params to be tested optuna still needs to suggest them
         # This part makes sure this suggestion will be constant
         args = self._init_args
-        missing_borders = {
-            param: args[param]
-            for param in search_space
-            if param not in param_borders
-        }
+        missing_borders = {param: args[param] for param in search_space if param not in param_borders}
         for param, value in missing_borders.items():
             if search_space[param]["type"] == "categorical":
                 search_space[param]["args"] = [value]
@@ -315,21 +296,14 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
     def _check_borders(self, param, borders):
         """Raise value error if param borders are not valid"""
         if param not in self._search_space:
-
-
-            )
+            msg = f"Hyper parameter {param} is not defined for {self!s}"
+            raise ValueError(msg)
         if not isinstance(borders, list):
-
-
-
-
-
-            raise ValueError(
-                f"""
-                Hyper parameter {param} is numerical
-                but bounds are not in ([lower, upper]) format
-                """
-            )
+            msg = f"Parameter {param} borders are not a list"
+            raise ValueError()
+        if self._search_space[param]["type"] != "categorical" and len(borders) != 2:
+            msg = f"Hyper parameter {param} is numerical but bounds are not in ([lower, upper]) format"
+            raise ValueError(msg)
 
     def _prepare_split_data(
         self,
@@ -373,16 +347,12 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         item_features = None
         if dataset.query_features is not None:
             query_features = dataset.query_features.join(
-                dataset.interactions.select(
-                    dataset.feature_schema.query_id_column
-                ).distinct(),
+                dataset.interactions.select(dataset.feature_schema.query_id_column).distinct(),
                 on=dataset.feature_schema.query_id_column,
             )
         if dataset.item_features is not None:
             item_features = dataset.item_features.join(
-                dataset.interactions.select(
-                    dataset.feature_schema.item_id_column
-                ).distinct(),
+                dataset.interactions.select(dataset.feature_schema.item_id_column).distinct(),
                 on=dataset.feature_schema.item_id_column,
             )
 
@@ -431,12 +401,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         self.fit_items = sf.broadcast(items)
         self._num_queries = self.fit_queries.count()
         self._num_items = self.fit_items.count()
-        self._query_dim_size = (
-
-        )
-        self._item_dim_size = (
-            self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
-        )
+        self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).collect()[0][0] + 1
+        self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
         self._fit(dataset)
 
     @abstractmethod
@@ -452,18 +418,14 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         :return:
         """
 
-    def _filter_seen(
-        self, recs: SparkDataFrame, interactions: SparkDataFrame, k: int, queries: SparkDataFrame
-    ):
+    def _filter_seen(self, recs: SparkDataFrame, interactions: SparkDataFrame, k: int, queries: SparkDataFrame):
         """
         Filter seen items (presented in interactions) out of the queries' recommendations.
         For each query return from `k` to `k + number of seen by query` recommendations.
         """
         queries_interactions = interactions.join(queries, on=self.query_column)
         self._cache_model_temp_view(queries_interactions, "filter_seen_queries_interactions")
-        num_seen = queries_interactions.groupBy(self.query_column).agg(
-            sf.count(self.item_column).alias("seen_count")
-        )
+        num_seen = queries_interactions.groupBy(self.query_column).agg(sf.count(self.item_column).alias("seen_count"))
         self._cache_model_temp_view(num_seen, "filter_seen_num_seen")
 
         # count maximal number of items seen by queries
@@ -474,11 +436,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         # crop recommendations to first k + max_seen items for each query
         recs = recs.withColumn(
             "temp_rank",
-            sf.row_number().over(
-                Window.partitionBy(self.query_column).orderBy(
-                    sf.col(self.rating_column).desc()
-                )
-            ),
+            sf.row_number().over(Window.partitionBy(self.query_column).orderBy(sf.col(self.rating_column).desc())),
         ).filter(sf.col("temp_rank") <= sf.lit(max_seen + k))
 
         # leave k + number of items seen by query recommendations in recs
@@ -494,8 +452,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
             queries_interactions.withColumnRenamed(self.item_column, "item")
             .withColumnRenamed(self.query_column, "query")
             .select("query", "item"),
-            on=(sf.col(self.query_column) == sf.col("query"))
-            & (sf.col(self.item_column) == sf.col("item")),
+            on=(sf.col(self.query_column) == sf.col("query")) & (sf.col(self.item_column) == sf.col("item")),
             how="anti",
         ).drop("query", "item")
 
@@ -556,7 +513,6 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         )
         return dataset, queries, items
 
-    # pylint: disable=too-many-arguments
     def _predict_wrap(
         self,
         dataset: Optional[Dataset],
@@ -589,9 +545,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         :return: cached recommendation dataframe with columns ``[user_idx, item_idx, rating]``
             or None if `file_path` is provided
         """
-        dataset, queries, items = self._filter_interactions_queries_items_dataframes(
-            dataset, k, queries, items
-        )
+        dataset, queries, items = self._filter_interactions_queries_items_dataframes(dataset, k, queries, items)
 
         recs = self._predict(
             dataset,
@@ -630,21 +584,16 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         if can_predict_cold:
             return main_df, interactions_df
 
-        num_new, main_df = filter_cold(
-            main_df, fit_entities, col_name=column
-        )
+        num_new, main_df = filter_cold(main_df, fit_entities, col_name=column)
         if num_new > 0:
             self.logger.info(
                 "%s model can't predict cold %ss, they will be ignored",
                 self,
                 entity,
             )
-        _, interactions_df = filter_cold(
-            interactions_df, fit_entities, col_name=column
-        )
+        _, interactions_df = filter_cold(interactions_df, fit_entities, col_name=column)
         return main_df, interactions_df
 
-    # pylint: disable=too-many-arguments
     @abstractmethod
     def _predict(
         self,
@@ -673,12 +622,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         """
 
     def _predict_proba(
-        self,
-        dataset : Dataset,
-        k: int,
-        queries: SparkDataFrame,
-        items: SparkDataFrame,
-        filter_seen_items: bool = True
+        self, dataset: Dataset, k: int, queries: SparkDataFrame, items: SparkDataFrame, filter_seen_items: bool = True
     ) -> np.ndarray:
         """
         Inner method where model actually predicts.
@@ -706,11 +650,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         n_users = queries.select("user_idx").count()
         n_items = items.select("item_idx").count()
 
-        recs = self._predict(dataset,
-                             k,
-                             queries,
-                             items,
-                             filter_seen_items)
+        recs = self._predict(dataset, k, queries, items, filter_seen_items)
 
         recs = get_top_k_recs(recs, k=k, query_column=self.query_column, rating_column=self.rating_column).select(
             self.query_column, self.item_column, self.rating_column
@@ -718,17 +658,20 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
 
         cols = [f"k{i}" for i in range(k)]
 
-        recs_items =
-
-
+        recs_items = (
+            recs.groupBy("user_idx")
+            .agg(sf.collect_list("item_idx").alias("item_idx"))
+            .select([sf.col("item_idx")[i].alias(cols[i]) for i in range(k)])
         )
 
         action_dist = np.zeros(shape=(n_users, n_items, k))
 
         for i in range(k):
-            action_dist[
-
-
+            action_dist[
+                np.arange(n_users),
+                recs_items.select(cols[i]).toPandas()[cols[i]].to_numpy(),
+                np.ones(n_users, dtype=int) * i,
+            ] += 1
 
         return action_dist
 
@@ -765,10 +708,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         setattr(
             self,
             dim_size,
-            fit_entities
-            .agg({column: "max"})
-            .collect()[0][0]
-            + 1,
+            fit_entities.agg({column: "max"}).collect()[0][0] + 1,
         )
         return getattr(self, dim_size)
 
@@ -829,13 +769,11 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         """
         if dataset is not None:
             interactions, query_features, item_features, pairs = [
-                convert2spark(df)
-                for df in [dataset.interactions, dataset.query_features, dataset.item_features, pairs]
+                convert2spark(df) for df in [dataset.interactions, dataset.query_features, dataset.item_features, pairs]
            ]
-            if set(pairs.columns) !=
-
-
-            )
+            if set(pairs.columns) != {self.item_column, self.query_column}:
+                msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
+                raise ValueError(msg)
             pairs, interactions = self._filter_cold_for_predict(pairs, interactions, "query")
             pairs, interactions = self._filter_cold_for_predict(pairs, interactions, "item")
 
@@ -908,13 +846,13 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
     ) -> Optional[Tuple[SparkDataFrame, int]]:
         if self.query_column not in ids.columns and self.item_column not in ids.columns:
-
+            msg = f"{self.query_column} or {self.item_column} missing"
+            raise ValueError(msg)
         vectors, rank = self._get_features(ids, features)
         return vectors, rank
 
-    # pylint: disable=unused-argument
     def _get_features(
-        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
+        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
     ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
         """
         Get embeddings from model
@@ -961,39 +899,26 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
             k=k,
         )
 
-        nearest_items = nearest_items.withColumnRenamed(
-
-        )
-        nearest_items = nearest_items.withColumnRenamed(
-            "item_idx_one", self.item_column
-        )
+        nearest_items = nearest_items.withColumnRenamed("item_idx_two", "neighbour_item_idx")
+        nearest_items = nearest_items.withColumnRenamed("item_idx_one", self.item_column)
        return nearest_items
 
     def _get_nearest_items(
         self,
-        items: SparkDataFrame,
-        metric: Optional[str] = None,
-        candidates: Optional[SparkDataFrame] = None,
+        items: SparkDataFrame,  # noqa: ARG002
+        metric: Optional[str] = None,  # noqa: ARG002
+        candidates: Optional[SparkDataFrame] = None,  # noqa: ARG002
     ) -> Optional[SparkDataFrame]:
-
-
-        )
+        msg = f"item-to-item prediction is not implemented for {self}"
+        raise NotImplementedError(msg)
 
     def _params_tried(self):
         """check if current parameters were already evaluated"""
         if self.study is None:
             return False
 
-        params = {
-
-            for name, value in self._init_args.items()
-            if name in self._search_space
-        }
-        for trial in self.study.trials:
-            if params == trial.params:
-                return True
-
-        return False
+        params = {name: value for name, value in self._init_args.items() if name in self._search_space}
+        return any(params == trial.params for trial in self.study.trials)
 
     def _save_model(self, path: str, additional_params: Optional[dict] = None):
         saved_params = {
@@ -1004,10 +929,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
         }
         if additional_params is not None:
             saved_params.update(additional_params)
-        save_picklable_to_parquet(
-            saved_params,
-            join(path, "params.dump")
-        )
+        save_picklable_to_parquet(saved_params, join(path, "params.dump"))
 
     def _load_model(self, path: str):
         loaded_params = load_pickled_from_parquet(join(path, "params.dump"))
@@ -1053,10 +975,8 @@ class ItemVectorModel(BaseRecommender):
             spark-dataframe with columns ``[item_idx, neighbour_item_idx, similarity]``
         """
         if metric not in self.item_to_item_metrics:
-
-
-                f"{self.item_to_item_metrics}"
-            )
+            msg = f"Select one of the valid distance metrics: {self.item_to_item_metrics}"
+            raise ValueError(msg)
 
         return self._get_nearest_items_wrap(
             items=items,
@@ -1098,9 +1018,9 @@ class ItemVectorModel(BaseRecommender):
             )
         )
 
-        right_part = items_vectors.withColumnRenamed(
-
-        )
+        right_part = items_vectors.withColumnRenamed(self.item_column, "item_idx_two").withColumnRenamed(
+            "item_vector", "item_vector_two"
+        )
 
         if candidates is not None:
             right_part = right_part.join(
@@ -1108,25 +1028,18 @@ class ItemVectorModel(BaseRecommender):
                 on="item_idx_two",
             )
 
-        joined_factors = left_part.join(
-            right_part, on=sf.col("item_idx_one") != sf.col("item_idx_two")
-        )
+        joined_factors = left_part.join(right_part, on=sf.col("item_idx_one") != sf.col("item_idx_two"))
 
         joined_factors = joined_factors.withColumn(
             metric,
-            dist_function(
-                sf.col("item_vector_one"), sf.col("item_vector_two")
-            ),
+            dist_function(sf.col("item_vector_one"), sf.col("item_vector_two")),
         )
 
-        similarity_matrix = joined_factors.select(
-            "item_idx_one", "item_idx_two", metric
-        )
+        similarity_matrix = joined_factors.select("item_idx_one", "item_idx_two", metric)
 
         return similarity_matrix
 
 
-# pylint: disable=abstract-method
 class HybridRecommender(BaseRecommender, ABC):
     """Base class for models that can use extra features"""
 
@@ -1143,7 +1056,6 @@ class HybridRecommender(BaseRecommender, ABC):
         """
         self._fit_wrap(dataset=dataset)
 
-    # pylint: disable=too-many-arguments
     def predict(
         self,
         dataset: Dataset,
@@ -1260,7 +1172,6 @@ class HybridRecommender(BaseRecommender, ABC):
         return self._get_features_wrap(ids, features)
 
 
-# pylint: disable=abstract-method
 class Recommender(BaseRecommender, ABC):
     """Usual recommender class for models without features."""
 
@@ -1274,7 +1185,6 @@ class Recommender(BaseRecommender, ABC):
         """
         self._fit_wrap(dataset=dataset)
 
-    # pylint: disable=too-many-arguments
     def predict(
         self,
         dataset: Dataset,
@@ -1340,7 +1250,6 @@ class Recommender(BaseRecommender, ABC):
             k=k,
         )
 
-    # pylint: disable=too-many-arguments
     def fit_predict(
         self,
         dataset: Dataset,
@@ -1406,7 +1315,6 @@ class QueryRecommender(BaseRecommender, ABC):
         """
         self._fit_wrap(dataset=dataset)
 
-    # pylint: disable=too-many-arguments
     def predict(
         self,
         dataset: Dataset,
@@ -1436,7 +1344,8 @@ class QueryRecommender(BaseRecommender, ABC):
             or None if `file_path` is provided
         """
         if not dataset or not dataset.query_features:
-
+            msg = "Query features are missing for predict"
+            raise ValueError(msg)
 
         return self._predict_wrap(
             dataset=dataset,
@@ -1469,7 +1378,8 @@ class QueryRecommender(BaseRecommender, ABC):
             or None if `file_path` is provided
         """
         if not dataset or not dataset.query_features:
-
+            msg = "Query features are missing for predict"
+            raise ValueError(msg)
 
         return self._predict_pairs_wrap(
             pairs=pairs,
@@ -1496,15 +1406,14 @@ class NonPersonalizedRecommender(Recommender, ABC):
         if 0 < cold_weight <= 1:
             self.cold_weight = cold_weight
         else:
-
-
-            )
+            msg = "`cold_weight` value should be in interval (0, 1]"
+            raise ValueError(msg)
 
     @property
     def _dataframes(self):
         return {"item_popularity": self.item_popularity}
 
-    def _save_model(self, path: str, additional_params: Optional[dict] = None):
+    def _save_model(self, path: str, additional_params: Optional[dict] = None):  # noqa: ARG002
         super()._save_model(path, additional_params={"fill": self.fill})
 
     def _clear_cache(self):
@@ -1517,10 +1426,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         Calculating a fill value a the minimal rating
         calculated during model training multiplied by weight.
         """
-        return (
-            item_popularity.select(sf.min(rating_column)).collect()[0][0]
-            * weight
-        )
+        return item_popularity.select(sf.min(rating_column)).collect()[0][0] * weight
 
     @staticmethod
     def _check_rating(dataset: Dataset):
@@ -1529,7 +1435,8 @@ class NonPersonalizedRecommender(Recommender, ABC):
             (sf.col(rating_column) != 1) & (sf.col(rating_column) != 0)
         )
         if vals.count() > 0:
-
+            msg = "Rating values in interactions must be 0 or 1"
+            raise ValueError(msg)
 
     def _get_selected_item_popularity(self, items: SparkDataFrame) -> SparkDataFrame:
         """
@@ -1561,7 +1468,6 @@ class NonPersonalizedRecommender(Recommender, ABC):
 
         return max_hist_len
 
-    # pylint: disable=too-many-arguments
     def _predict_without_sampling(
         self,
         dataset: Dataset,
@@ -1577,11 +1483,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         selected_item_popularity = self._get_selected_item_popularity(items)
         selected_item_popularity = selected_item_popularity.withColumn(
             "rank",
-            sf.row_number().over(
-                Window.orderBy(
-                    sf.col(self.rating_column).desc(), sf.col(self.item_column).desc()
-                )
-            ),
+            sf.row_number().over(Window.orderBy(sf.col(self.rating_column).desc(), sf.col(self.item_column).desc())),
         )
 
         if filter_seen_items and dataset is not None:
@@ -1594,15 +1496,10 @@ class NonPersonalizedRecommender(Recommender, ABC):
             queries = queries.fillna(0, "num_items")
             # 'selected_item_popularity' truncation by k + max_seen
            max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
-            selected_item_popularity = selected_item_popularity
-
-            return queries.join(
-                selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left"
-            )
+            selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
+            return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")
 
-        return queries.crossJoin(
-            selected_item_popularity.filter(sf.col("rank") <= k)
-        ).drop("rank")
+        return queries.crossJoin(selected_item_popularity.filter(sf.col("rank") <= k)).drop("rank")
 
     def get_items_pd(self, items: SparkDataFrame) -> pd.DataFrame:
         """
@@ -1612,26 +1509,22 @@ class NonPersonalizedRecommender(Recommender, ABC):
         selected_item_popularity = self._get_selected_item_popularity(items)
         selected_item_popularity = selected_item_popularity.withColumn(
             self.rating_column,
-            sf.when(sf.col(self.rating_column) == sf.lit(0.0), 0.1**6).otherwise(
-                sf.col(self.rating_column)
-            ),
+            sf.when(sf.col(self.rating_column) == sf.lit(0.0), 0.1**6).otherwise(sf.col(self.rating_column)),
         )
 
         warnings.warn(
             "Prediction with sampling performs spark to pandas convertion to master node, "
             "this may lead to OOM exception for large item catalogue.",
-            SparkCollectToMasterWarning
+            SparkCollectToMasterWarning,
        )
 
         items_pd = selected_item_popularity.withColumn(
             "probability",
-            sf.col(self.rating_column)
-            / selected_item_popularity.select(sf.sum(self.rating_column)).first()[0],
+            sf.col(self.rating_column) / selected_item_popularity.select(sf.sum(self.rating_column)).first()[0],
         ).toPandas()
 
         return items_pd
 
-    # pylint: disable=too-many-locals
     def _predict_with_sampling(
         self,
         dataset: Dataset,
@@ -1667,10 +1560,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
             query_idx = pandas_df[query_column][0]
             cnt = pandas_df["cnt"][0]
 
-            if seed is not None:
-                local_rng = default_rng(seed + query_idx)
-            else:
-                local_rng = default_rng()
+            local_rng = default_rng(seed + query_idx) if seed is not None else default_rng()
 
             items_positions = local_rng.choice(
                 np.arange(items_pd.shape[0]),
@@ -1716,7 +1606,6 @@ class NonPersonalizedRecommender(Recommender, ABC):
 
         return recs.groupby(self.query_column).applyInPandas(grouped_map, rec_schema)
 
-    # pylint: disable=too-many-arguments
     def _predict(
         self,
         dataset: Dataset,
@@ -1725,7 +1614,6 @@ class NonPersonalizedRecommender(Recommender, ABC):
         items: SparkDataFrame,
         filter_seen_items: bool = True,
     ) -> SparkDataFrame:
-
         if self.sample:
             return self._predict_with_sampling(
                 dataset=dataset,
@@ -1735,14 +1623,12 @@ class NonPersonalizedRecommender(Recommender, ABC):
                 filter_seen_items=filter_seen_items,
             )
         else:
-            return self._predict_without_sampling(
-                dataset, k, queries, items, filter_seen_items
-            )
+            return self._predict_without_sampling(dataset, k, queries, items, filter_seen_items)
 
     def _predict_pairs(
         self,
         pairs: SparkDataFrame,
-        dataset: Optional[Dataset] = None,
+        dataset: Optional[Dataset] = None,  # noqa: ARG002
     ) -> SparkDataFrame:
         return (
             pairs.join(
@@ -1755,12 +1641,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
         )
 
     def _predict_proba(
-        self,
-        dataset : Dataset,
-        k: int,
-        queries: SparkDataFrame,
-        items: SparkDataFrame,
-        filter_seen_items: bool = True
+        self, dataset: Dataset, k: int, queries: SparkDataFrame, items: SparkDataFrame, filter_seen_items: bool = True
     ) -> np.ndarray:
         """
         Inner method where model actually predicts.
@@ -1799,8 +1680,4 @@ class NonPersonalizedRecommender(Recommender, ABC):
 
         return np.tile(items_pd, (n_users, k)).reshape(n_users, k, n_items).transpose((0, 2, 1))
 
-        return super()._predict_proba(dataset,
-                                      k,
-                                      queries,
-                                      items,
-                                      filter_seen_items)
+        return super()._predict_proba(dataset, k, queries, items, filter_seen_items)