replay-rec: replay_rec-0.16.0rc0-py3-none-any.whl → replay_rec-0.17.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -61
- replay/experimental/metrics/base_metric.py +0 -661
- replay/experimental/metrics/coverage.py +0 -117
- replay/experimental/metrics/experiment.py +0 -200
- replay/experimental/metrics/hitrate.py +0 -27
- replay/experimental/metrics/map.py +0 -31
- replay/experimental/metrics/mrr.py +0 -19
- replay/experimental/metrics/ncis_precision.py +0 -32
- replay/experimental/metrics/ndcg.py +0 -50
- replay/experimental/metrics/precision.py +0 -23
- replay/experimental/metrics/recall.py +0 -26
- replay/experimental/metrics/rocauc.py +0 -50
- replay/experimental/metrics/surprisal.py +0 -102
- replay/experimental/metrics/unexpectedness.py +0 -74
- replay/experimental/models/__init__.py +0 -10
- replay/experimental/models/admm_slim.py +0 -216
- replay/experimental/models/base_neighbour_rec.py +0 -222
- replay/experimental/models/base_rec.py +0 -1361
- replay/experimental/models/base_torch_rec.py +0 -247
- replay/experimental/models/cql.py +0 -468
- replay/experimental/models/ddpg.py +0 -1007
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -193
- replay/experimental/models/dt4rec/gpt1.py +0 -411
- replay/experimental/models/dt4rec/trainer.py +0 -128
- replay/experimental/models/dt4rec/utils.py +0 -274
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
- replay/experimental/models/implicit_wrap.py +0 -138
- replay/experimental/models/lightfm_wrap.py +0 -327
- replay/experimental/models/mult_vae.py +0 -374
- replay/experimental/models/neuromf.py +0 -462
- replay/experimental/models/scala_als.py +0 -311
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -58
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -929
- replay/experimental/preprocessing/padder.py +0 -231
- replay/experimental/preprocessing/sequence_generator.py +0 -218
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
- replay/experimental/scenarios/two_stages/reranker.py +0 -116
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -213
- replay/experimental/utils/session_handler.py +0 -47
- replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
- replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
--- replay/experimental/models/implicit_wrap.py
+++ /dev/null
@@ -1,138 +0,0 @@
-from os.path import join
-from typing import Optional
-
-from replay.data import get_schema
-from replay.experimental.models.base_rec import Recommender
-from replay.preprocessing import CSRConverter
-from replay.utils import PandasDataFrame, SparkDataFrame
-from replay.utils.spark_utils import load_pickled_from_parquet, save_picklable_to_parquet
-
-
-class ImplicitWrap(Recommender):
-    """Wrapper for `implicit
-    <https://github.com/benfred/implicit>`_
-
-    Example:
-
-    >>> import implicit
-    >>> model = implicit.als.AlternatingLeastSquares(factors=5)
-    >>> als = ImplicitWrap(model)
-
-    This way you can use implicit models as any other in replay
-    with conversions made under the hood.
-
-    >>> import pandas as pd
-    >>> from replay.utils.spark_utils import convert2spark
-    >>> df = pd.DataFrame({"user_idx": [1, 1, 2, 2], "item_idx": [1, 2, 2, 3], "relevance": [1, 1, 1, 1]})
-    >>> df = convert2spark(df)
-    >>> als.fit_predict(df, 1, users=[1])[["user_idx", "item_idx"]].toPandas()
-       user_idx  item_idx
-    0         1         3
-    """
-
-    def __init__(self, model):
-        """Provide initialized ``implicit`` model."""
-        self.model = model
-        self.logger.info(
-            "The model is a wrapper of a non-distributed model which may affect performance"
-        )
-
-    @property
-    def _init_args(self):
-        return {"model": None}
-
-    def _save_model(self, path: str):
-        save_picklable_to_parquet(self.model, join(path, "model"))
-
-    def _load_model(self, path: str):
-        self.model = load_pickled_from_parquet(join(path, "model"))
-
-    def _fit(
-        self,
-        log: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> None:
-        matrix = CSRConverter(
-            first_dim_column="user_idx",
-            second_dim_column="item_idx",
-            data_column="relevance"
-        ).transform(log)
-        self.model.fit(matrix)
-
-    @staticmethod
-    def _pd_func(model, items_to_use=None, user_item_data=None, filter_seen_items=False):
-        def predict_by_user_item(pandas_df):
-            user = int(pandas_df["user_idx"].iloc[0])
-            items = items_to_use if items_to_use else pandas_df.item_idx.to_list()
-
-            items_res, rel = model.recommend(
-                userid=user,
-                user_items=user_item_data[user] if filter_seen_items else None,
-                N=len(items),
-                filter_already_liked_items=filter_seen_items,
-                items=items,
-            )
-            return PandasDataFrame(
-                {
-                    "user_idx": [user] * len(items_res),
-                    "item_idx": items_res,
-                    "relevance": rel,
-                }
-            )
-
-        return predict_by_user_item
-
-    # pylint: disable=too-many-arguments
-    def _predict(
-        self,
-        log: SparkDataFrame,
-        k: int,
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-        filter_seen_items: bool = True,
-    ) -> SparkDataFrame:
-
-        items_to_use = items.distinct().toPandas().item_idx.tolist()
-        user_item_data = CSRConverter(
-            first_dim_column="user_idx",
-            second_dim_column="item_idx",
-            data_column="relevance"
-        ).transform(log)
-        model = self.model
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return (
-            users.select("user_idx")
-            .groupby("user_idx")
-            .applyInPandas(self._pd_func(
-                model=model,
-                items_to_use=items_to_use,
-                user_item_data=user_item_data,
-                filter_seen_items=filter_seen_items), rec_schema)
-        )
-
-    def _predict_pairs(
-        self,
-        pairs: SparkDataFrame,
-        log: Optional[SparkDataFrame] = None,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> SparkDataFrame:
-
-        model = self.model
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return pairs.groupby("user_idx").applyInPandas(
-            self._pd_func(model=model, filter_seen_items=False),
-            rec_schema)
--- replay/experimental/models/lightfm_wrap.py
+++ /dev/null
@@ -1,327 +0,0 @@
-import os
-from os.path import join
-from typing import Optional, Tuple
-
-import numpy as np
-from lightfm import LightFM
-from scipy.sparse import csr_matrix, diags, hstack
-from sklearn.preprocessing import MinMaxScaler
-
-from replay.data import get_schema
-from replay.experimental.models.base_rec import HybridRecommender
-from replay.experimental.utils.session_handler import State
-from replay.preprocessing import CSRConverter
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
-from replay.utils.spark_utils import check_numeric, load_pickled_from_parquet, save_picklable_to_parquet
-
-if PYSPARK_AVAILABLE:
-    import pyspark.sql.functions as sf
-
-
-# pylint: disable=too-many-locals, too-many-instance-attributes
-class LightFMWrap(HybridRecommender):
-    """Wrapper for LightFM."""
-
-    epochs: int = 10
-    _search_space = {
-        "loss": {
-            "type": "categorical",
-            "args": ["logistic", "bpr", "warp", "warp-kos"],
-        },
-        "no_components": {"type": "loguniform_int", "args": [8, 512]},
-    }
-    user_feat_scaler: Optional[MinMaxScaler] = None
-    item_feat_scaler: Optional[MinMaxScaler] = None
-
-    def __init__(
-        self,
-        no_components: int = 128,
-        loss: str = "warp",
-        random_state: Optional[int] = None,
-    ):  # pylint: disable=too-many-arguments
-        np.random.seed(42)
-        self.no_components = no_components
-        self.loss = loss
-        self.random_state = random_state
-        cpu_count = os.cpu_count()
-        self.num_threads = cpu_count if cpu_count is not None else 1
-
-    @property
-    def _init_args(self):
-        return {
-            "no_components": self.no_components,
-            "loss": self.loss,
-            "random_state": self.random_state,
-        }
-
-    def _save_model(self, path: str):
-        save_picklable_to_parquet(self.model, join(path, "model"))
-        save_picklable_to_parquet(self.user_feat_scaler, join(path, "user_feat_scaler"))
-        save_picklable_to_parquet(self.item_feat_scaler, join(path, "item_feat_scaler"))
-
-    def _load_model(self, path: str):
-        self.model = load_pickled_from_parquet(join(path, "model"))
-        self.user_feat_scaler = load_pickled_from_parquet(join(path, "user_feat_scaler"))
-        self.item_feat_scaler = load_pickled_from_parquet(join(path, "item_feat_scaler"))
-
-    def _feature_table_to_csr(
-        self,
-        log_ids_list: SparkDataFrame,
-        feature_table: Optional[SparkDataFrame] = None,
-    ) -> Optional[csr_matrix]:
-        """
-        Transform features to sparse matrix
-        Matrix consists of two parts:
-        1) Left one is a ohe-hot encoding of user and item ids.
-            Matrix size is: number of users or items * number of user or items in fit.
-            Cold users and items are represented with empty strings
-        2) Right one is a numerical features, passed with feature_table.
-            MinMaxScaler is applied per column, and then value is divided by the row sum.
-
-        :param feature_table: dataframe with ``user_idx`` or ``item_idx``,
-            other columns are features.
-        :param log_ids_list: dataframe with ``user_idx`` or ``item_idx``,
-            containing unique ids from log.
-        :returns: feature matrix
-        """
-
-        if feature_table is None:
-            return None
-
-        check_numeric(feature_table)
-        log_ids_list = log_ids_list.distinct()
-        entity = "item" if "item_idx" in feature_table.columns else "user"
-        idx_col_name = f"{entity}_idx"
-
-        # filter features by log
-        feature_table = feature_table.join(
-            log_ids_list, on=idx_col_name, how="inner"
-        )
-
-        fit_dim = getattr(self, f"_{entity}_dim")
-        matrix_height = max(
-            fit_dim,
-            log_ids_list.select(sf.max(idx_col_name)).collect()[0][0] + 1,
-        )
-        if not feature_table.rdd.isEmpty():
-            matrix_height = max(
-                matrix_height,
-                feature_table.select(sf.max(idx_col_name)).collect()[0][0] + 1,
-            )
-
-        features_np = (
-            feature_table.select(
-                idx_col_name,
-                # first column contains id, next contain features
-                *(
-                    sorted(
-                        list(
-                            set(feature_table.columns).difference(
-                                {idx_col_name}
-                            )
-                        )
-                    )
-                ),
-            )
-            .toPandas()
-            .to_numpy()
-        )
-        entities_ids = features_np[:, 0]
-        features_np = features_np[:, 1:]
-        number_of_features = features_np.shape[1]
-
-        all_ids_list = log_ids_list.toPandas().to_numpy().ravel()
-        entities_seen_in_fit = all_ids_list[all_ids_list < fit_dim]
-
-        entity_id_features = csr_matrix(
-            (
-                [1.0] * entities_seen_in_fit.shape[0],
-                (entities_seen_in_fit, entities_seen_in_fit),
-            ),
-            shape=(matrix_height, fit_dim),
-        )
-
-        scaler_name = f"{entity}_feat_scaler"
-        if getattr(self, scaler_name) is None:
-            if not features_np.size:
-                raise ValueError(f"features for {entity}s from log are absent")
-            setattr(self, scaler_name, MinMaxScaler().fit(features_np))
-
-        if features_np.size:
-            features_np = getattr(self, scaler_name).transform(features_np)
-            sparse_features = csr_matrix(
-                (
-                    features_np.ravel(),
-                    (
-                        np.repeat(entities_ids, number_of_features),
-                        np.tile(
-                            np.arange(number_of_features),
-                            entities_ids.shape[0],
-                        ),
-                    ),
-                ),
-                shape=(matrix_height, number_of_features),
-            )
-
-        else:
-            sparse_features = csr_matrix((matrix_height, number_of_features))
-
-        concat_features = hstack([entity_id_features, sparse_features])
-        concat_features_sum = concat_features.sum(axis=1).A.ravel()
-        mask = concat_features_sum != 0.0
-        concat_features_sum[mask] = 1.0 / concat_features_sum[mask]
-        return diags(concat_features_sum, format="csr") @ concat_features
-
-    def _fit(
-        self,
-        log: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> None:
-        self.user_feat_scaler = None
-        self.item_feat_scaler = None
-
-        interactions_matrix = CSRConverter(
-            first_dim_column="user_idx",
-            second_dim_column="item_idx",
-            data_column="relevance",
-            row_count=self._user_dim,
-            column_count=self._item_dim
-        ).transform(log)
-        csr_item_features = self._feature_table_to_csr(
-            log.select("item_idx").distinct(), item_features
-        )
-        csr_user_features = self._feature_table_to_csr(
-            log.select("user_idx").distinct(), user_features
-        )
-
-        if user_features is not None:
-            self.can_predict_cold_users = True
-        if item_features is not None:
-            self.can_predict_cold_items = True
-
-        self.model = LightFM(
-            loss=self.loss,
-            no_components=self.no_components,
-            random_state=self.random_state,
-        ).fit(
-            interactions=interactions_matrix,
-            epochs=self.epochs,
-            num_threads=self.num_threads,
-            item_features=csr_item_features,
-            user_features=csr_user_features,
-        )
-
-    def _predict_selected_pairs(
-        self,
-        pairs: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ):
-        def predict_by_user(pandas_df: PandasDataFrame) -> PandasDataFrame:
-            pandas_df["relevance"] = model.predict(
-                user_ids=pandas_df["user_idx"].to_numpy(),
-                item_ids=pandas_df["item_idx"].to_numpy(),
-                item_features=csr_item_features,
-                user_features=csr_user_features,
-            )
-            return pandas_df
-
-        model = self.model
-
-        if self.can_predict_cold_users and user_features is None:
-            raise ValueError("User features are missing for predict")
-        if self.can_predict_cold_items and item_features is None:
-            raise ValueError("Item features are missing for predict")
-
-        csr_item_features = self._feature_table_to_csr(
-            pairs.select("item_idx").distinct(), item_features
-        )
-        csr_user_features = self._feature_table_to_csr(
-            pairs.select("user_idx").distinct(), user_features
-        )
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return pairs.groupby("user_idx").applyInPandas(
-            predict_by_user, rec_schema
-        )
-
-    # pylint: disable=too-many-arguments
-    def _predict(
-        self,
-        log: SparkDataFrame,
-        k: int,
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-        filter_seen_items: bool = True,
-    ) -> SparkDataFrame:
-        return self._predict_selected_pairs(
-            users.crossJoin(items), user_features, item_features
-        )
-
-    def _predict_pairs(
-        self,
-        pairs: SparkDataFrame,
-        log: Optional[SparkDataFrame] = None,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> SparkDataFrame:
-        return self._predict_selected_pairs(
-            pairs, user_features, item_features
-        )
-
-    def _get_features(
-        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
-    ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
-        """
-        Get features from LightFM.
-        LightFM has methods get_item_representations/get_user_representations,
-        which accept object matrix and return features.
-
-        :param ids: id item_idx/user_idx to get features for
-        :param features: features for item_idx/user_idx
-        :return: spark-dataframe with biases and vectors for users/items and vector size
-        """
-        entity = "item" if "item_idx" in ids.columns else "user"
-        ids_list = ids.toPandas()[f"{entity}_idx"]
-
-        # models without features use sparse matrix
-        if features is None:
-            matrix_width = getattr(self, f"fit_{entity}s").count()
-            warm_ids = ids_list[ids_list < matrix_width]
-            sparse_features = csr_matrix(
-                (
-                    [1] * warm_ids.shape[0],
-                    (warm_ids, warm_ids),
-                ),
-                shape=(ids_list.max() + 1, matrix_width),
-            )
-        else:
-            sparse_features = self._feature_table_to_csr(ids, features)
-
-        biases, vectors = getattr(self.model, f"get_{entity}_representations")(
-            sparse_features
-        )
-
-        embed_list = list(
-            zip(
-                ids_list,
-                biases[ids_list].tolist(),
-                vectors[ids_list].tolist(),
-            )
-        )
-        lightfm_factors = State().session.createDataFrame(
-            embed_list,
-            schema=[
-                f"{entity}_idx",
-                f"{entity}_bias",
-                f"{entity}_factors",
-            ],
-        )
-        return lightfm_factors, self.model.no_components