replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
- /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
- /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
- /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
- /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
- /replay/{experimental → data}/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,839 +0,0 @@
-"""
-Contains classes for data preparation and categorical features transformation.
-``DataPreparator`` is used to transform DataFrames to a library format.
-``Indexed`` is used to convert user and item ids to numeric format.
-``CatFeaturesTransformer`` transforms categorical features with one-hot encoding.
-``ToNumericFeatureTransformer`` leaves only numerical features
-by one-hot encoding of some features and deleting the others.
-"""
-
-import json
-import logging
-import string
-from os.path import join
-from typing import Optional
-
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, MissingImport, SparkDataFrame
-from replay.utils.session_handler import State
-
-if PYSPARK_AVAILABLE:
-    from pyspark.ml import Estimator, Transformer
-    from pyspark.ml.feature import IndexToString, StringIndexer, StringIndexerModel
-    from pyspark.ml.util import DefaultParamsWriter, MLReadable, MLReader, MLWritable, MLWriter
-    from pyspark.sql import functions as sf
-    from pyspark.sql.types import DoubleType, IntegerType, NumericType, StructField, StructType
-
-    from replay.utils.spark_utils import convert2spark, process_timestamp_column
-
-
-LOG_COLUMNS = ["user_id", "item_id", "timestamp", "relevance"]
-
-
-if PYSPARK_AVAILABLE:
-
-    class Indexer:
-        """
-        This class is used to convert arbitrary id to numerical idx and back.
-        """
-
-        user_indexer: StringIndexerModel
-        item_indexer: StringIndexerModel
-        inv_user_indexer: IndexToString
-        inv_item_indexer: IndexToString
-        user_type: None
-        item_type: None
-        suffix = "inner"
-
-        def __init__(self, user_col="user_id", item_col="item_id"):
-            """
-            Provide column names for indexer to use
-            """
-            self.user_col = user_col
-            self.item_col = item_col
-
-        @property
-        def _init_args(self):
-            return {
-                "user_col": self.user_col,
-                "item_col": self.item_col,
-            }
-
-        def fit(
-            self,
-            users: SparkDataFrame,
-            items: SparkDataFrame,
-        ) -> None:
-            """
-            Creates indexers to map raw id to numerical idx so that spark can handle them.
-            :param users: SparkDataFrame containing user column
-            :param items: SparkDataFrame containing item column
-            :return:
-            """
-            users = users.select(self.user_col).withColumnRenamed(self.user_col, f"{self.user_col}_{self.suffix}")
-            items = items.select(self.item_col).withColumnRenamed(self.item_col, f"{self.item_col}_{self.suffix}")
-
-            self.user_type = users.schema[f"{self.user_col}_{self.suffix}"].dataType
-            self.item_type = items.schema[f"{self.item_col}_{self.suffix}"].dataType
-
-            self.user_indexer = StringIndexer(inputCol=f"{self.user_col}_{self.suffix}", outputCol="user_idx").fit(
-                users
-            )
-            self.item_indexer = StringIndexer(inputCol=f"{self.item_col}_{self.suffix}", outputCol="item_idx").fit(
-                items
-            )
-            self.inv_user_indexer = IndexToString(
-                inputCol=f"{self.user_col}_{self.suffix}",
-                outputCol=self.user_col,
-                labels=self.user_indexer.labels,
-            )
-            self.inv_item_indexer = IndexToString(
-                inputCol=f"{self.item_col}_{self.suffix}",
-                outputCol=self.item_col,
-                labels=self.item_indexer.labels,
-            )
-
-        def transform(self, df: SparkDataFrame) -> Optional[SparkDataFrame]:
-            """
-            Convert raw ``user_col`` and ``item_col`` to numerical ``user_idx`` and ``item_idx``
-
-            :param df: dataframe with raw indexes
-            :return: dataframe with converted indexes
-            """
-            if self.item_col in df.columns:
-                remaining_cols = df.drop(self.item_col).columns
-                df = df.withColumnRenamed(self.item_col, f"{self.item_col}_{self.suffix}")
-                self._reindex(df, "item")
-                df = self.item_indexer.transform(df).select(
-                    sf.col("item_idx").cast("int").alias("item_idx"),
-                    *remaining_cols,
-                )
-            if self.user_col in df.columns:
-                remaining_cols = df.drop(self.user_col).columns
-                df = df.withColumnRenamed(self.user_col, f"{self.user_col}_{self.suffix}")
-                self._reindex(df, "user")
-                df = self.user_indexer.transform(df).select(
-                    sf.col("user_idx").cast("int").alias("user_idx"),
-                    *remaining_cols,
-                )
-            return df
-
-        def inverse_transform(self, df: SparkDataFrame) -> SparkDataFrame:
-            """
-            Convert SparkDataFrame to the initial indexes.
-
-            :param df: SparkDataFrame with numerical ``user_idx/item_idx`` columns
-            :return: SparkDataFrame with original user/item columns
-            """
-            res = df
-            if "item_idx" in df.columns:
-                remaining_cols = res.drop("item_idx").columns
-                res = self.inv_item_indexer.transform(
-                    res.withColumnRenamed("item_idx", f"{self.item_col}_{self.suffix}")
-                ).select(
-                    sf.col(self.item_col).cast(self.item_type).alias(self.item_col),
-                    *remaining_cols,
-                )
-            if "user_idx" in df.columns:
-                remaining_cols = res.drop("user_idx").columns
-                res = self.inv_user_indexer.transform(
-                    res.withColumnRenamed("user_idx", f"{self.user_col}_{self.suffix}")
-                ).select(
-                    sf.col(self.user_col).cast(self.user_type).alias(self.user_col),
-                    *remaining_cols,
-                )
-            return res
-
-        def _reindex(self, df: SparkDataFrame, entity: str):
-            """
-            Update indexer with new entries.
-
-            :param df: SparkDataFrame with users/items
-            :param entity: user or item
-            """
-            indexer = getattr(self, f"{entity}_indexer")
-            inv_indexer = getattr(self, f"inv_{entity}_indexer")
-            new_objects = set(
-                map(
-                    str,
-                    df.select(indexer.getInputCol()).distinct().toPandas()[indexer.getInputCol()],
-                )
-            ).difference(indexer.labels)
-            if new_objects:
-                new_labels = indexer.labels + list(new_objects)
-                setattr(
-                    self,
-                    f"{entity}_indexer",
-                    indexer.from_labels(
-                        new_labels,
-                        inputCol=indexer.getInputCol(),
-                        outputCol=indexer.getOutputCol(),
-                        handleInvalid="error",
-                    ),
-                )
-                inv_indexer.setLabels(new_labels)
-
-    # We need to inherit it from DefaultParamsWriter to make it being saved correctly within Pipeline
-    class JoinIndexerMLWriter(DefaultParamsWriter):
-        """Implements saving the JoinIndexerTransformer instance to disk.
-        Used when saving a trained pipeline.
-        Implements MLWriter.saveImpl(path) method.
-        """
-
-        def __init__(self, instance):
-            super().__init__(instance)
-            self.instance = instance
-
-        def saveImpl(self, path: str) -> None:  # noqa: N802
-            """Save implementation"""
-            super().saveImpl(path)
-
-            spark = State().session
-
-            init_args = self.instance._init_args
-            sc = spark.sparkContext
-            df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
-            df.coalesce(1).write.mode("overwrite").json(join(path, "init_args.json"))
-
-            self.instance.user_col_2_index_map.write.mode("overwrite").save(join(path, "user_col_2_index_map.parquet"))
-            self.instance.item_col_2_index_map.write.mode("overwrite").save(join(path, "item_col_2_index_map.parquet"))
-
-    class JoinIndexerMLReader(MLReader):
-        """Implements reading the JoinIndexerTransformer instance from disk.
-        Used when loading a trained pipeline.
-        """
-
-        def load(self, path):
-            """Load the ML instance from the input path."""
-            spark = State().session
-            args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
-            user_col_2_index_map = spark.read.parquet(join(path, "user_col_2_index_map.parquet"))
-            item_col_2_index_map = spark.read.parquet(join(path, "item_col_2_index_map.parquet"))
-
-            indexer = JoinBasedIndexerTransformer(
-                user_col=args["user_col"],
-                user_type=args["user_type"],
-                user_col_2_index_map=user_col_2_index_map,
-                item_col=args["item_col"],
-                item_type=args["item_type"],
-                item_col_2_index_map=item_col_2_index_map,
-            )
-
-            return indexer
-
-    class JoinBasedIndexerTransformer(Transformer, MLWritable, MLReadable):
-        """
-        JoinBasedIndexer, that index user column and item column in input dataframe
-        """
-
-        def __init__(
-            self,
-            user_col: str,
-            item_col: str,
-            user_type: str,
-            item_type: str,
-            user_col_2_index_map: SparkDataFrame,
-            item_col_2_index_map: SparkDataFrame,
-            update_map_on_transform: bool = True,
-            force_broadcast_on_mapping_joins: bool = True,
-        ):
-            super().__init__()
-            self.user_col = user_col
-            self.item_col = item_col
-            self.user_type = user_type
-            self.item_type = item_type
-            self.user_col_2_index_map = user_col_2_index_map
-            self.item_col_2_index_map = item_col_2_index_map
-            self.update_map_on_transform = update_map_on_transform
-            self.force_broadcast_on_mapping_joins = force_broadcast_on_mapping_joins
-
-        @property
-        def _init_args(self):
-            return {
-                "user_col": self.user_col,
-                "item_col": self.item_col,
-                "user_type": self.user_type,
-                "item_type": self.item_type,
-                "update_map_on_transform": self.update_map_on_transform,
-                "force_broadcast_on_mapping_joins": self.force_broadcast_on_mapping_joins,
-            }
-
-        def set_update_map_on_transform(self, value: bool):
-            """Sets 'update_map_on_transform' flag"""
-            self.update_map_on_transform = value
-
-        def set_force_broadcast_on_mapping_joins(self, value: bool):
-            """Sets 'force_broadcast_on_mapping_joins' flag"""
-            self.force_broadcast_on_mapping_joins = value
-
-        def _get_item_mapping(self) -> SparkDataFrame:
-            if self.force_broadcast_on_mapping_joins:
-                mapping = sf.broadcast(self.item_col_2_index_map)
-            else:
-                mapping = self.item_col_2_index_map
-            return mapping
-
-        def _get_user_mapping(self) -> SparkDataFrame:
-            if self.force_broadcast_on_mapping_joins:
-                mapping = sf.broadcast(self.user_col_2_index_map)
-            else:
-                mapping = self.user_col_2_index_map
-            return mapping
-
-        def write(self) -> MLWriter:
-            """Returns MLWriter instance that can save the Transformer instance."""
-            return JoinIndexerMLWriter(self)
-
-        @classmethod
-        def read(cls):
-            """Returns an MLReader instance for this class."""
-            return JoinIndexerMLReader()
-
-        def _update_maps(self, df: SparkDataFrame):
-            new_items = (
-                df.join(self._get_item_mapping(), on=self.item_col, how="left_anti").select(self.item_col).distinct()
-            )
-            prev_item_count = self.item_col_2_index_map.count()
-            new_items_map = JoinBasedIndexerEstimator.get_map(new_items, self.item_col, "item_idx").select(
-                self.item_col, (sf.col("item_idx") + prev_item_count).alias("item_idx")
-            )
-            self.item_col_2_index_map = self.item_col_2_index_map.union(new_items_map)
-
-            new_users = (
-                df.join(self._get_user_mapping(), on=self.user_col, how="left_anti").select(self.user_col).distinct()
-            )
-            prev_user_count = self.user_col_2_index_map.count()
-            new_users_map = JoinBasedIndexerEstimator.get_map(new_users, self.user_col, "user_idx").select(
-                self.user_col, (sf.col("user_idx") + prev_user_count).alias("user_idx")
-            )
-            self.user_col_2_index_map = self.user_col_2_index_map.union(new_users_map)
-
-        def _transform(self, dataset: SparkDataFrame) -> SparkDataFrame:
-            if self.update_map_on_transform:
-                self._update_maps(dataset)
-
-            if self.item_col in dataset.columns:
-                remaining_cols = dataset.drop(self.item_col).columns
-                dataset = dataset.join(self._get_item_mapping(), on=self.item_col, how="left").select(
-                    sf.col("item_idx").cast("int").alias("item_idx"),
-                    *remaining_cols,
-                )
-            if self.user_col in dataset.columns:
-                remaining_cols = dataset.drop(self.user_col).columns
-                dataset = dataset.join(self._get_user_mapping(), on=self.user_col, how="left").select(
-                    sf.col("user_idx").cast("int").alias("user_idx"),
-                    *remaining_cols,
-                )
-            return dataset
-
-        def inverse_transform(self, df: SparkDataFrame) -> SparkDataFrame:
-            """
-            Convert SparkDataFrame to the initial indexes.
-
-            :param df: SparkDataFrame with numerical ``user_idx/item_idx`` columns
-            :return: SparkDataFrame with original user/item columns
-            """
-            if "item_idx" in df.columns:
-                remaining_cols = df.drop("item_idx").columns
-                df = df.join(self._get_item_mapping(), on="item_idx", how="left").select(
-                    self.item_col,
-                    *remaining_cols,
-                )
-            if "user_idx" in df.columns:
-                remaining_cols = df.drop("user_idx").columns
-                df = df.join(self._get_user_mapping(), on="user_idx", how="left").select(
-                    self.user_col,
-                    *remaining_cols,
-                )
-            return df
-
-    class JoinBasedIndexerEstimator(Estimator):
-        """
-        Estimator that produces JoinBasedIndexerTransformer
-        """
-
-        def __init__(self, user_col="user_id", item_col="item_id"):
-            """
-            Provide column names for indexer to use
-            """
-            self.user_col = user_col
-            self.item_col = item_col
-            self.user_col_2_index_map = None
-            self.item_col_2_index_map = None
-            self.user_type = None
-            self.item_type = None
-
-        @staticmethod
-        def get_map(df: SparkDataFrame, col_name: str, idx_col_name: str) -> SparkDataFrame:
-            """Creates indexes [0, .., k] for values from `col_name` column.
-
-            :param df: input dataframe
-            :param col_name: column name from `df` that need to index
-            :param idx_col_name: column name with indexes
-            :return: SparkDataFrame with map "col_name" -> "idx_col_name"
-            """
-            uid_rdd = df.select(col_name).distinct().rdd.map(lambda x: x[col_name]).zipWithIndex()
-
-            return uid_rdd.toDF(
-                StructType(
-                    [
-                        df.schema[col_name],
-                        StructField(idx_col_name, IntegerType(), False),
-                    ]
-                )
-            )
-
-        def _fit(self, dataset: SparkDataFrame) -> Transformer:
-            """
-            Creates indexers to map raw id to numerical idx so that spark can handle them.
-            :param df: SparkDataFrame containing user column and item column
-            :return:
-            """
-
-            self.user_col_2_index_map = self.get_map(dataset, self.user_col, "user_idx")
-            self.item_col_2_index_map = self.get_map(dataset, self.item_col, "item_idx")
-
-            self.user_type = dataset.schema[self.user_col].dataType
-            self.item_type = dataset.schema[self.item_col].dataType
-
-            return JoinBasedIndexerTransformer(
-                user_col=self.user_col,
-                user_type=str(self.user_type),
-                item_col=self.item_col,
-                item_type=str(self.item_type),
-                user_col_2_index_map=self.user_col_2_index_map,
-                item_col_2_index_map=self.item_col_2_index_map,
-            )
-
-    class DataPreparator:
-        """Transforms data to a library format:
-        - read as a spark dataframe/ convert pandas dataframe to spark
-        - check for nulls
-        - create relevance/timestamp columns if absent
-        - convert dates to TimestampType
-
-        Examples:
-
-        Loading log DataFrame
-
-        >>> import pandas as pd
-        >>> from replay.experimental.preprocessing.data_preparator import DataPreparator
-        >>>
-        >>> log = pd.DataFrame({"user": [2, 2, 2, 1],
-        ...                     "item_id": [1, 2, 3, 3],
-        ...                     "rel": [5, 5, 5, 5]}
-        ...                    )
-        >>> dp = DataPreparator()
-        >>> correct_log = dp.transform(data=log,
-        ...                            columns_mapping={"user_id": "user",
-        ...                                             "item_id": "item_id",
-        ...                                             "relevance": "rel"}
-        ...                            )
-        >>> correct_log.show(2)
-        +-------+-------+---------+-------------------+
-        |user_id|item_id|relevance| timestamp|
-        +-------+-------+---------+-------------------+
-        | 2| 1| 5.0|2099-01-01 00:00:00|
-        | 2| 2| 5.0|2099-01-01 00:00:00|
-        +-------+-------+---------+-------------------+
-        only showing top 2 rows
-        <BLANKLINE>
-
-
-        Loading user features
-
-        >>> import pandas as pd
-        >>> from replay.experimental.preprocessing.data_preparator import DataPreparator
-        >>>
-        >>> log = pd.DataFrame({"user": ["user1", "user1", "user2"],
-        ...                     "f0": ["feature1","feature2","feature1"],
-        ...                     "f1": ["left","left","center"],
-        ...                     "ts": ["2019-01-01","2019-01-01","2019-01-01"]}
-        ...                    )
-        >>> dp = DataPreparator()
-        >>> correct_log = dp.transform(data=log,
-        ...                            columns_mapping={"user_id": "user"},
-        ...                            )
-        >>> correct_log.show(3)
-        +-------+--------+------+----------+
-        |user_id| f0| f1| ts|
-        +-------+--------+------+----------+
-        | user1|feature1| left|2019-01-01|
-        | user1|feature2| left|2019-01-01|
-        | user2|feature1|center|2019-01-01|
-        +-------+--------+------+----------+
-        <BLANKLINE>
-
-        """
-
-        _logger: Optional[logging.Logger] = None
-
-        @property
-        def logger(self) -> logging.Logger:
-            """
-            :returns: get library logger
-            """
-            if self._logger is None:
-                self._logger = logging.getLogger("replay")
-            return self._logger
-
-        @staticmethod
-        def read_as_spark_df(
-            data: Optional[DataFrameLike] = None,
-            path: Optional[str] = None,
-            format_type: Optional[str] = None,
-            **kwargs,
-        ) -> SparkDataFrame:
-            """
-            Read spark dataframe from file of transform pandas dataframe.
-
-            :param data: DataFrame to process (``pass`` or ``data`` should be defined)
-            :param path: path to data (``pass`` or ``data`` should be defined)
-            :param format_type: file type, one of ``[csv , parquet , json , table]``
-            :param kwargs: extra arguments passed to
-                ``spark.read.<format>(path, **reader_kwargs)``
-            :return: spark DataFrame
-            """
-            if data is not None:
-                dataframe = convert2spark(data)
-            elif path and format_type:
-                spark = State().session
-                if format_type == "csv":
-                    dataframe = spark.read.csv(path, inferSchema=True, **kwargs)
-                elif format_type == "parquet":
-                    dataframe = spark.read.parquet(path)
-                elif format_type == "json":
-                    dataframe = spark.read.json(path, **kwargs)
-                elif format_type == "table":
-                    dataframe = spark.read.table(path)
-                else:
-                    msg = f"Invalid value of format_type='{format_type}'"
-                    raise ValueError(msg)
-            else:
-                msg = "Either data or path parameters must not be None"
-                raise ValueError(msg)
-            return dataframe
-
-        def check_df(self, dataframe: SparkDataFrame, columns_mapping: dict[str, str]) -> None:
-            """
-            Check:
-            - if dataframe is not empty,
-            - if columns from ``columns_mapping`` are present in dataframe
-            - warn about nulls in columns from ``columns_mapping``
-            - warn about absent of ``timestamp/relevance`` columns for interactions log
-            - warn about wrong relevance DataType
-
-            :param dataframe: spark DataFrame to process
-            :param columns_mapping: dictionary mapping "key: column name in input DataFrame".
-                Possible keys: ``[user_id, user_id, timestamp, relevance]``
-                ``columns_mapping`` values specifies the nature of the DataFrame:
-                - if both ``[user_id, item_id]`` are present,
-                    then the dataframe is a log of interactions.
-                    Specify ``timestamp, relevance`` columns in mapping if available.
-                - if ether ``user_id`` or ``item_id`` is present,
-                    then the dataframe is a dataframe of user/item features
-            """
-            if not dataframe.head(1):
-                msg = "DataFrame is empty"
-                raise ValueError(msg)
-
-            for value in columns_mapping.values():
-                if value not in dataframe.columns:
-                    msg = f"Column `{value}` stated in mapping is absent in dataframe"
-                    raise ValueError(msg)
-
-            for column in columns_mapping.values():
-                if dataframe.where(sf.col(column).isNull()).count() > 0:
-                    self.logger.info(
-                        "Column `%s` has NULL values. Handle NULL values before "
-                        "the next data preprocessing/model training steps",
-                        column,
-                    )
-
-            if "user_id" in columns_mapping and "item_id" in columns_mapping:
-                absent_cols = set(LOG_COLUMNS).difference(columns_mapping.keys())
-                if len(absent_cols) > 0:
-                    self.logger.info(
-                        "Columns %s are absent, but may be required for models training. "
-                        "Add them with DataPreparator().generate_absent_log_cols",
-                        list(absent_cols),
-                    )
-                if "relevance" in columns_mapping and not isinstance(
-                    dataframe.schema[columns_mapping["relevance"]].dataType,
-                    NumericType,
-                ):
-                    self.logger.info(
-                        "Relevance column `%s` should be numeric, but it is %s",
-                        columns_mapping["relevance"],
-                        dataframe.schema[columns_mapping["relevance"]].dataType,
-                    )
-
-        @staticmethod
-        def add_absent_log_cols(
-            dataframe: SparkDataFrame,
-            columns_mapping: dict[str, str],
-            default_relevance: float = 1.0,
-            default_ts: str = "2099-01-01",
-        ):
-            """
-            Add ``relevance`` and ``timestamp`` columns with default values if
-            ``relevance`` or ``timestamp`` is absent among mapping keys.
-
-            :param dataframe: interactions log to process
-            :param columns_mapping: dictionary mapping "key: column name in input DataFrame".
-                Possible keys: ``[user_id, user_id, timestamp, relevance]``
-            :param default_relevance: default value for generated `relevance` column
-            :param default_ts: str, default value for generated `timestamp` column
-            :return: spark DataFrame with generated ``timestamp`` and ``relevance`` columns
-                if absent in original dataframe
-            """
-            absent_cols = set(LOG_COLUMNS).difference(columns_mapping.keys())
-            if "relevance" in absent_cols:
-                dataframe = dataframe.withColumn("relevance", sf.lit(default_relevance).cast(DoubleType()))
-            if "timestamp" in absent_cols:
-                dataframe = dataframe.withColumn("timestamp", sf.to_timestamp(sf.lit(default_ts)))
-            return dataframe
-
-        @staticmethod
-        def _rename(df: SparkDataFrame, mapping: dict) -> Optional[SparkDataFrame]:
-            """
-            rename dataframe columns based on mapping
-            """
-            if df is None or mapping is None:
-                return df
-            for out_col, in_col in mapping.items():
-                if in_col in df.columns:
-                    df = df.withColumnRenamed(in_col, out_col)
-            return df
-
-        def transform(
-            self,
-            columns_mapping: dict[str, str],
-            data: Optional[DataFrameLike] = None,
-            path: Optional[str] = None,
-            format_type: Optional[str] = None,
-            date_format: Optional[str] = None,
-            reader_kwargs: Optional[dict] = None,
-        ) -> SparkDataFrame:
-            """
-            Transforms log, user or item features into a Spark DataFrame
-            ``[user_id, user_id, timestamp, relevance]``,
-            ``[user_id, *features]``, or ``[item_id, *features]``.
-            Input is either file of ``format_type``
-            at ``path``, or ``pandas.DataFrame`` or ``spark.DataFrame``.
-            Transform performs:
-            - dataframe reading/convert to spark DataFrame format
-            - check dataframe (nulls, columns_mapping)
-            - rename columns from mapping to standard names (user_id, user_id, timestamp, relevance)
-            - for interactions log: create absent columns,
-                convert ``timestamp`` column to TimestampType and ``relevance`` to DoubleType
-
-            :param columns_mapping: dictionary mapping "key: column name in input DataFrame".
-                Possible keys: ``[user_id, user_id, timestamp, relevance]``
-                ``columns_mapping`` values specifies the nature of the DataFrame:
-                - if both ``[user_id, item_id]`` are present,
-                    then the dataframe is a log of interactions.
-                    Specify ``timestamp, relevance`` columns in mapping if present.
-                - if ether ``user_id`` or ``item_id`` is present,
-                    then the dataframe is a dataframe of user/item features
-
-            :param data: DataFrame to process
-            :param path: path to data
-            :param format_type: file type, one of ``[csv , parquet , json , table]``
-            :param date_format: format for the ``timestamp`` column
-            :param reader_kwargs: extra arguments passed to
-                ``spark.read.<format>(path, **reader_kwargs)``
-            :return: processed DataFrame
-            """
-            is_log = False
-            if "user_id" in columns_mapping and "item_id" in columns_mapping:
-                self.logger.info(
-                    "Columns with ids of users or items are present in mapping. "
-                    "The dataframe will be treated as an interactions log."
-                )
-                is_log = True
-            elif "user_id" not in columns_mapping and "item_id" not in columns_mapping:
-                msg = "Mapping either for user ids or for item ids is not stated in `columns_mapping`"
-                raise ValueError(msg)
-            else:
-                self.logger.info(
-                    "Column with ids of users or items is absent in mapping. "
-                    "The dataframe will be treated as a users'/items' features dataframe."
-                )
-            reader_kwargs = {} if reader_kwargs is None else reader_kwargs
-            dataframe = self.read_as_spark_df(data=data, path=path, format_type=format_type, **reader_kwargs)
-            self.check_df(dataframe, columns_mapping=columns_mapping)
-            dataframe = self._rename(df=dataframe, mapping=columns_mapping)
-            if is_log:
-                dataframe = self.add_absent_log_cols(dataframe=dataframe, columns_mapping=columns_mapping)
-                dataframe = dataframe.withColumn("relevance", sf.col("relevance").cast(DoubleType()))
-                dataframe = process_timestamp_column(
-                    dataframe=dataframe,
-                    column_name="timestamp",
-                    date_format=date_format,
-                )
-
-            return dataframe
-
-    class CatFeaturesTransformer:
-        """Transform categorical features in ``cat_cols_list``
-        with one-hot encoding and remove original columns."""
-
-        def __init__(
-            self,
-            cat_cols_list: list,
-            alias: str = "ohe",
-        ):
-            """
-            :param cat_cols_list: list of categorical columns
-            :param alias: prefix for one-hot encoding columns
-            """
-            self.cat_cols_list = cat_cols_list
-            self.expressions_list = []
-            self.alias = alias
-
-        def fit(self, spark_df: Optional[SparkDataFrame]) -> None:
-            """
-            Save categories for each column
-            :param spark_df: Spark DataFrame with features
-            """
-            if spark_df is None:
-                return
-
-            cat_feat_values_dict = {
-                name: (spark_df.select(sf.collect_set(sf.col(name))).first()[0]) for name in self.cat_cols_list
-            }
-            self.expressions_list = [
-                sf.when(sf.col(col_name) == cur_name, 1)
-                .otherwise(0)
-                .alias(
-                    f"""{self.alias}_{col_name}_{str(cur_name).translate(
-                    str.maketrans(
-                        "", "", string.punctuation + string.whitespace
-                    )
-                )[:30]}"""
-                )
-                for col_name, col_values in cat_feat_values_dict.items()
-                for cur_name in col_values
-            ]
-
-        def transform(self, spark_df: Optional[SparkDataFrame]):
-            """
-            Transform categorical columns.
-            If there are any new categories that were not present at fit stage, they will be ignored.
-            :param spark_df: feature DataFrame
-            :return: transformed DataFrame
-            """
-            if spark_df is None:
-                return None
-            return spark_df.select(*spark_df.columns, *self.expressions_list).drop(*self.cat_cols_list)

-    class ToNumericFeatureTransformer:
-        """Transform user/item features to numeric types:
-        - numeric features stays as is
-        - categorical features:
-            if threshold is defined:
-            - all non-numeric columns with less unique values than threshold are one-hot encoded
-            - remaining columns are dropped
-            else all non-numeric columns are one-hot encoded
-        """
-
-        cat_feat_transformer: Optional[CatFeaturesTransformer]
-        cols_to_ohe: Optional[list]
-        cols_to_del: Optional[list]
-        all_columns: Optional[list]
-
-        def __init__(self, threshold: Optional[int] = 100):
-            self.threshold = threshold
-            self.fitted = False
-
-        def fit(self, features: Optional[SparkDataFrame]) -> None:
-            """
-            Determine categorical columns for one-hot encoding.
-            Non categorical columns with more values than threshold will be deleted.
-            Saves categories for each column.
-            :param features: input DataFrame
-            """
-            self.cat_feat_transformer = None
-            self.cols_to_del = []
-            self.fitted = True
-
-            if features is None:
-                self.all_columns = None
-                return
-
-            self.all_columns = sorted(features.columns)
-
-            spark_df_non_numeric_cols = [
-                col
-                for col in features.columns
-                if (not isinstance(features.schema[col].dataType, NumericType))
-                and (col not in {"user_idx", "item_idx"})
-            ]
-
-            # numeric only
-            if len(spark_df_non_numeric_cols) == 0:
-                self.cols_to_ohe = []
-                return
-
-            if self.threshold is None:
-                self.cols_to_ohe = spark_df_non_numeric_cols
-            else:
-                counts_pd = (
-                    features.agg(*[sf.approx_count_distinct(sf.col(c)).alias(c) for c in spark_df_non_numeric_cols])
-                    .toPandas()
-                    .T
-                )
-                self.cols_to_ohe = (counts_pd[counts_pd[0] <= self.threshold]).index.values
-
-            self.cols_to_del = [col for col in spark_df_non_numeric_cols if col not in set(self.cols_to_ohe)]
-
-            if self.cols_to_del:
-                State().logger.warning(
-                    "%s columns contain more that threshold unique values and will be deleted",
-                    self.cols_to_del,
-                )
-
-            if len(self.cols_to_ohe) > 0:
-                self.cat_feat_transformer = CatFeaturesTransformer(cat_cols_list=self.cols_to_ohe)
-                self.cat_feat_transformer.fit(features.drop(*self.cols_to_del))
-
-        def transform(self, spark_df: Optional[SparkDataFrame]) -> Optional[SparkDataFrame]:
-            """
-            Transform categorical features.
-            Use one hot encoding for columns with the amount of unique values smaller
-            than threshold and delete other columns.
-            :param spark_df: input DataFrame
-            :return: processed DataFrame
-            """
-            if not self.fitted:
-                msg = "Call fit before running transform"
-                raise AttributeError(msg)
-
-            if spark_df is None or self.all_columns is None:
-                return None
-
-            if self.cat_feat_transformer is None:
-                return spark_df.drop(*self.cols_to_del)
-
-            if sorted(spark_df.columns) != self.all_columns:
-                msg = (
-                    f"Columns from fit do not match "
-                    f"columns in transform. "
-                    f"Fit columns: {self.all_columns},"
-                    f"Transform columns: {spark_df.columns}"
-                )
-                raise ValueError(msg)
-
-            return self.cat_feat_transformer.transform(spark_df.drop(*self.cols_to_del))
-
-        def fit_transform(self, spark_df: SparkDataFrame) -> SparkDataFrame:
-            """
-            :param spark_df: input DataFrame
-            :return: output DataFrame
-            """
-            self.fit(spark_df)
-            return self.transform(spark_df)
-
-else:
-    Indexer = MissingImport
-    DataPreparator = MissingImport