replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/preprocessing/label_encoder.py
CHANGED

@@ -6,9 +6,11 @@ Contains classes for encoding categorical data
 ``LabelEncoder`` to apply multiple LabelEncodingRule to dataframe.
 """
 import abc
-import …
+import warnings
 from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union

+import polars as pl
+
 from replay.utils import (
     PYSPARK_AVAILABLE,
     DataFrameLike,
@@ -20,13 +22,16 @@ from replay.utils import (

 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
+    from pyspark.sql.types import LongType, StructType
     from pyspark.storagelevel import StorageLevel
-    from pyspark.sql.types import StructType, LongType

-HandleUnknownStrategies = Literal["error", "use_default_value"]
+HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]
+
+
+class LabelEncoderTransformWarning(Warning):
+    """Label encoder transform warning."""


-# pylint: disable=missing-function-docstring
 class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover
     """
     Interface of the label encoding rule
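The headline change in this file is the new ``drop`` strategy and the ``LabelEncoderTransformWarning`` that accompanies it. A minimal sketch of the intended usage, assuming the usual public re-exports from ``replay.preprocessing`` and the standard fit/transform interface (pandas shown; the same behaviour is implemented for Spark and Polars further down):

    import pandas as pd

    from replay.preprocessing import LabelEncoder, LabelEncodingRule  # assumed re-export path

    train = pd.DataFrame({"item_id": ["cat", "dog", "mouse"]})
    test = pd.DataFrame({"item_id": ["dog", "zebra"]})  # "zebra" was never seen during fit

    encoder = LabelEncoder([LabelEncodingRule("item_id", handle_unknown="drop")])
    encoder.fit(train)
    encoded = encoder.transform(test)
    # With handle_unknown="drop" the "zebra" row is removed, instead of raising
    # (strategy "error") or being filled (strategy "use_default_value").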
@@ -70,7 +75,6 @@ class BaseLabelEncodingRule(abc.ABC):  # pragma: no cover
         raise NotImplementedError()


-# pylint: disable=too-many-instance-attributes
 class LabelEncodingRule(BaseLabelEncodingRule):
     """
     Implementation of the encoding rule for categorical variables of PySpark and Pandas Data Frames.
@@ -79,7 +83,7 @@ class LabelEncodingRule(BaseLabelEncodingRule):
     """

     _ENCODED_COLUMN_SUFFIX: str = "_encoded"
-    _HANDLE_UNKNOWN_STRATEGIES = ("error", "use_default_value")
+    _HANDLE_UNKNOWN_STRATEGIES = ("error", "use_default_value", "drop")
     _TRANSFORM_PERFORMANCE_THRESHOLD_FOR_PANDAS = 100_000

     def __init__(
@@ -99,6 +103,7 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             When set to ``error`` an error will be raised in case an unknown label is present during transform.
             When set to ``use_default_value``, the encoded value of unknown label will be set
             to the value given for the parameter default_value.
+            When set to ``drop``, the unknown labels will be dropped.
             Default: ``error``.
         :param default_value: Default value that will fill the unknown labels after transform.
             When the parameter handle_unknown is set to ``use_default_value``,
@@ -110,7 +115,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             Default: ``None``.
         """
         if handle_unknown not in self._HANDLE_UNKNOWN_STRATEGIES:
-            raise ValueError(f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}.")
+            msg = f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}."
+            raise ValueError(msg)
         self._handle_unknown = handle_unknown
         if (
             self._handle_unknown == "use_default_value"
@@ -118,7 +124,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             and not isinstance(default_value, int)
             and default_value != "last"
         ):
-            raise ValueError("Default value should be None, int or 'last'")
+            msg = "Default value should be None, int or 'last'"
+            raise ValueError(msg)

         self._default_value = default_value
         self._col = column
@@ -135,12 +142,14 @@ class LabelEncodingRule(BaseLabelEncodingRule):

     def get_mapping(self) -> Mapping:
         if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)
         return self._mapping

     def get_inverse_mapping(self) -> Mapping:
         if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)
         return self._inverse_mapping

     def _make_inverse_mapping(self) -> Mapping:
@@ -159,17 +168,14 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             unique_col_values.rdd.zipWithIndex()
             .toDF(
                 StructType()
-                .add("_1",
-                     StructType()
-                     .add(self._col, df.schema[self._col].dataType, True),
-                     True)
+                .add("_1", StructType().add(self._col, df.schema[self._col].dataType, True), True)
                 .add("_2", LongType(), True)
             )
             .select(sf.col(f"_1.{self._col}").alias(self._col), sf.col("_2").alias(self._target_col))
             .persist(StorageLevel.MEMORY_ONLY)
         )

-        self._mapping = mapping_on_spark.rdd.collectAsMap()
+        self._mapping = mapping_on_spark.rdd.collectAsMap()
         mapping_on_spark.unpersist()
         unique_col_values.unpersist()
@@ -198,17 +204,18 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         elif isinstance(df, PolarsDataFrame):
             self._fit_polars(df)
         else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
         self._inverse_mapping = self._make_inverse_mapping()
         self._inverse_mapping_list = self._make_inverse_mapping_list()
-        if self._handle_unknown == "use_default_value":
-            if self._default_value in self._inverse_mapping:
-                raise ValueError(
-                    "The used value for default_value "
-                    f"{self._default_value} is one of the "
-                    "values already used for encoding the "
-                    "seen labels."
-                )
+        if self._handle_unknown == "use_default_value" and self._default_value in self._inverse_mapping:
+            msg = (
+                "The used value for default_value "
+                f"{self._default_value} is one of the "
+                "values already used for encoding the "
+                "seen labels."
+            )
+            raise ValueError(msg)
         self._is_fitted = True
         return self
@@ -226,18 +233,15 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             new_unique_values.rdd.zipWithIndex()
             .toDF(
                 StructType()
-                .add("_1",
-                     StructType()
-                     .add(self._col, df.schema[self._col].dataType),
-                     True)
+                .add("_1", StructType().add(self._col, df.schema[self._col].dataType), True)
                 .add("_2", LongType(), True)
             )
             .select(sf.col(f"_1.{self._col}").alias(self._col), sf.col("_2").alias(self._target_col))
             .withColumn(self._target_col, sf.col(self._target_col) + max_value)
-            .rdd.collectAsMap()
+            .rdd.collectAsMap()
         )
-        self._mapping.update(new_data)
-        self._inverse_mapping.update({v: k for k, v in new_data.items()})
+        self._mapping.update(new_data)
+        self._inverse_mapping.update({v: k for k, v in new_data.items()})
         self._inverse_mapping_list.extend(new_data.keys())
         new_unique_values.unpersist()
@@ -245,9 +249,10 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         assert self._mapping is not None

         new_unique_values = set(df[self._col].tolist()) - set(self._mapping)
-        new_data: dict = {value: max(self._mapping.values()) + i for i, value in enumerate(new_unique_values, start=1)}
-        self._mapping.update(new_data)
-        self._inverse_mapping.update({v: k for k, v in new_data.items()})
+        last_mapping_value = max(self._mapping.values())
+        new_data: dict = {value: last_mapping_value + i for i, value in enumerate(new_unique_values, start=1)}
+        self._mapping.update(new_data)
+        self._inverse_mapping.update({v: k for k, v in new_data.items()})
         self._inverse_mapping_list.extend(new_data.keys())

     def _partial_fit_polars(self, df: PolarsDataFrame) -> None:
@@ -255,8 +260,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):

         new_unique_values = set(df.select(self._col).unique().to_series().to_list()) - set(self._mapping)
         new_data: dict = {value: max(self._mapping.values()) + i for i, value in enumerate(new_unique_values, start=1)}
-        self._mapping.update(new_data)
-        self._inverse_mapping.update({v: k for k, v in new_data.items()})
+        self._mapping.update(new_data)
+        self._inverse_mapping.update({v: k for k, v in new_data.items()})
         self._inverse_mapping_list.extend(new_data.keys())

     def partial_fit(self, df: DataFrameLike) -> "LabelEncodingRule":
@@ -276,7 +281,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         elif isinstance(df, PolarsDataFrame):
             self._partial_fit_polars(df)
         else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)

         self._is_fitted = True
         return self
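``partial_fit`` extends an already fitted mapping in place: previously seen labels keep their codes, and new labels get codes continuing after the current maximum. A small sketch of the call pattern (pandas shown; the import path is the assumed public re-export, and the exact codes depend on the backend's ordering):

    import pandas as pd

    from replay.preprocessing import LabelEncodingRule  # assumed re-export path

    rule = LabelEncodingRule("item_id")
    rule.fit(pd.DataFrame({"item_id": ["a", "b"]}))          # e.g. a -> 0, b -> 1
    rule.partial_fit(pd.DataFrame({"item_id": ["b", "c"]}))  # "b" keeps its code, "c" gets the next free one
    print(rule.get_mapping())                                # e.g. {"a": 0, "b": 1, "c": 2}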
@@ -299,14 +305,24 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         joined_df.loc[unknown_mask, self._target_col] = -1
         is_unknown_label |= unknown_mask.sum() > 0

-        if is_unknown_label:
+        if is_unknown_label:
             unknown_mask = joined_df[self._target_col] == -1
-            if self._handle_unknown == "error":
+            if self._handle_unknown == "drop":
+                joined_df.drop(joined_df[unknown_mask].index, inplace=True)
+                if joined_df.empty:
+                    warnings.warn(
+                        f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                        "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                        LabelEncoderTransformWarning,
+                    )
+            elif self._handle_unknown == "error":
                 unknown_unique_labels = joined_df[self._col][unknown_mask].unique().tolist()
                 msg = f"Found unknown labels {unknown_unique_labels} in column {self._col} during transform"
                 raise ValueError(msg)
-            if default_value != -1:
-                joined_df[self._target_col] = joined_df[self._target_col].replace({-1: default_value})
+            else:
+                if default_value != -1:
+                    joined_df[self._target_col] = joined_df[self._target_col].astype("int")
+                    joined_df[self._target_col] = joined_df[self._target_col].replace({-1: default_value})

         result_df = joined_df.drop(self._col, axis=1).rename(columns={self._target_col: self._col})
         return result_df
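When every value in the transformed column is unknown, ``drop`` yields an empty result and the rule emits ``LabelEncoderTransformWarning`` instead of failing. A sketch of how calling code might check for that case, assuming the warning class is importable from the module shown in this diff:

    import warnings

    import pandas as pd

    from replay.preprocessing import LabelEncoder, LabelEncodingRule  # assumed re-export path
    from replay.preprocessing.label_encoder import LabelEncoderTransformWarning

    encoder = LabelEncoder([LabelEncodingRule("item_id", handle_unknown="drop")])
    encoder.fit(pd.DataFrame({"item_id": ["a", "b"]}))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = encoder.transform(pd.DataFrame({"item_id": ["x", "y"]}))  # all values unknown

    assert result.empty
    assert any(issubclass(w.category, LabelEncoderTransformWarning) for w in caught)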
@@ -318,17 +334,24 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         transformed_df = df.join(mapping_on_spark, on=self._col, how="left").withColumn(
             "unknown_mask", sf.isnull(self._target_col)
         )
-        unknown_label_count = transformed_df.select(sf.sum(sf.col("unknown_mask").cast("long"))).first()[
-            0
-        ]  # type: ignore
+        unknown_label_count = transformed_df.select(sf.sum(sf.col("unknown_mask").cast("long"))).first()[0]
         if unknown_label_count > 0:
-            if self._handle_unknown == "error":
+            if self._handle_unknown == "drop":
+                transformed_df = transformed_df.filter("unknown_mask == False")
+                if transformed_df.rdd.isEmpty():
+                    warnings.warn(
+                        f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                        "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                        LabelEncoderTransformWarning,
+                    )
+            elif self._handle_unknown == "error":
                 collected_list = transformed_df.filter("unknown_mask == True").select(self._col).distinct().collect()
                 unique_labels = [row[self._col] for row in collected_list]
                 msg = f"Found unknown labels {unique_labels} in column {self._col} during transform"
                 raise ValueError(msg)
-            if default_value:
-                transformed_df = transformed_df.fillna({self._target_col: default_value})
+            else:
+                if default_value:
+                    transformed_df = transformed_df.fillna({self._target_col: default_value})

         result_df = transformed_df.drop(self._col, "unknown_mask").withColumnRenamed(self._target_col, self._col)
         return result_df
@@ -338,20 +361,27 @@ class LabelEncodingRule(BaseLabelEncodingRule):
             [list(self.get_mapping().keys()), list(self.get_mapping().values())],
             schema=[self._col, self._target_col],
         )
-        mapping_on_polars = mapping_on_polars.with_columns(
-            pl.col(self._col).cast(df.get_column(self._col).dtype)
-        )
+        mapping_on_polars = mapping_on_polars.with_columns(pl.col(self._col).cast(df.get_column(self._col).dtype))
         transformed_df = df.join(mapping_on_polars, on=self._col, how="left").with_columns(
             pl.col(self._target_col).is_null().alias("unknown_mask")
         )
         unknown_df = transformed_df.filter(pl.col("unknown_mask"))
         if not unknown_df.is_empty():
-            if self._handle_unknown == "error":
+            if self._handle_unknown == "drop":
+                transformed_df = transformed_df.filter(pl.col("unknown_mask") == "false")
+                if transformed_df.is_empty():
+                    warnings.warn(
+                        f"You are trying to transform dataframe with all values are unknown for {self._col}, "
+                        "with `handle_unknown_strategy=drop` leads to empty dataframe",
+                        LabelEncoderTransformWarning,
+                    )
+            elif self._handle_unknown == "error":
                 unique_labels = unknown_df.select(self._col).unique().to_series().to_list()
                 msg = f"Found unknown labels {unique_labels} in column {self._col} during transform"
                 raise ValueError(msg)
-            if default_value:
-                transformed_df = transformed_df.with_columns(pl.col(self._target_col).fill_null(default_value))
+            else:
+                if default_value:
+                    transformed_df = transformed_df.with_columns(pl.col(self._target_col).fill_null(default_value))

         result_df = transformed_df.drop([self._col, "unknown_mask"]).rename({self._target_col: self._col})
         return result_df
@@ -364,18 +394,20 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         :returns: transformed dataframe.
         """
         if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)

         default_value = len(self._mapping) if self._default_value == "last" else self._default_value

         if isinstance(df, PandasDataFrame):
-            transformed_df = self._transform_pandas(df, default_value)
+            transformed_df = self._transform_pandas(df, default_value)
         elif isinstance(df, SparkDataFrame):
-            transformed_df = self._transform_spark(df, default_value)
+            transformed_df = self._transform_spark(df, default_value)
         elif isinstance(df, PolarsDataFrame):
-            transformed_df = self._transform_polars(df, default_value)
+            transformed_df = self._transform_polars(df, default_value)
         else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
         return transformed_df

     def _inverse_transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
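The dispatch above also shows how ``default_value="last"`` resolves at transform time: unknown labels are encoded as ``len(mapping)``, i.e. the next code after all seen labels. A hedged sketch under the same assumed import path:

    import pandas as pd

    from replay.preprocessing import LabelEncoder, LabelEncodingRule  # assumed re-export path

    rule = LabelEncodingRule("item_id", handle_unknown="use_default_value", default_value="last")
    encoder = LabelEncoder([rule])
    encoder.fit(pd.DataFrame({"item_id": ["a", "b", "c"]}))  # three seen labels

    encoded = encoder.transform(pd.DataFrame({"item_id": ["a", "zzz"]}))
    # "zzz" is unknown, so it receives len(mapping) == 3; known labels keep their codes.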
@@ -414,7 +446,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         :returns: initial dataframe.
         """
         if self._mapping is None:
-            raise RuntimeError("Label encoder is not fitted")
+            msg = "Label encoder is not fitted"
+            raise RuntimeError(msg)

         if isinstance(df, PandasDataFrame):
             transformed_df = self._inverse_transform_pandas(df)
@@ -423,7 +456,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         elif isinstance(df, PolarsDataFrame):
             transformed_df = self._inverse_transform_polars(df)
         else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(df)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+            raise NotImplementedError(msg)
         return transformed_df

     def set_default_value(self, default_value: Optional[Union[int, str]]) -> None:
@@ -434,7 +468,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         :param default_value: default value.
         """
         if default_value is not None and not isinstance(default_value, int) and default_value != "last":
-            raise ValueError("Default value should be None, int or 'last'")
+            msg = "Default value should be None, int or 'last'"
+            raise ValueError(msg)
         self._default_value = default_value

     def set_handle_unknown(self, handle_unknown: HandleUnknownStrategies) -> None:
@@ -444,7 +479,8 @@ class LabelEncodingRule(BaseLabelEncodingRule):
         :param handle_unknown: handle unknown strategy.
         """
         if handle_unknown not in self._HANDLE_UNKNOWN_STRATEGIES:
-            raise ValueError(f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}.")
+            msg = f"handle_unknown should be either 'error' or 'use_default_value', got {handle_unknown}."
+            raise ValueError(msg)
         self._handle_unknown = handle_unknown

@@ -582,11 +618,12 @@ class LabelEncoder:
             If ``str`` value, should be \"last\" only, then fill by n_classes number.
             Default ``None``.
         """
-        columns = [i.column for i in self.rules]
+        columns = [i.column for i in self.rules]
         for column, handle_unknown in handle_unknown_rules.items():
             if column not in columns:
-                raise ValueError(f"Column {column} not found.")
-            rule = list(filter(lambda x: x.column == column, self.rules))
+                msg = f"Column {column} not found."
+                raise ValueError(msg)
+            rule = list(filter(lambda x: x.column == column, self.rules))
             rule[0].set_handle_unknown(handle_unknown)

     def set_default_values(self, default_value_rules: Dict[str, Optional[Union[int, str]]]) -> None:
@@ -605,9 +642,10 @@ class LabelEncoder:
             to the value given for the parameter default_value.
             Default: ``error``.
         """
-        columns = [i.column for i in self.rules]
+        columns = [i.column for i in self.rules]
         for column, default_value in default_value_rules.items():
             if column not in columns:
-                raise ValueError(f"Column {column} not found.")
-            rule = list(filter(lambda x: x.column == column, self.rules))
+                msg = f"Column {column} not found."
+                raise ValueError(msg)
+            rule = list(filter(lambda x: x.column == column, self.rules))
             rule[0].set_default_value(default_value)
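``LabelEncoder.set_default_values`` (like the matching per-column handle-unknown setter above it) reconfigures already constructed rules by column name and raises ``ValueError`` for a column no rule covers. A sketch, assuming the same re-export path as before:

    from replay.preprocessing import LabelEncoder, LabelEncodingRule  # assumed re-export path

    encoder = LabelEncoder(
        [
            LabelEncodingRule("user_id", handle_unknown="use_default_value"),
            LabelEncodingRule("item_id", handle_unknown="use_default_value"),
        ]
    )
    # Per-column defaults: unknown users map to "last" (n_classes), unknown items to -1.
    encoder.set_default_values({"user_id": "last", "item_id": -1})
    # encoder.set_default_values({"genre": 0})  # would raise ValueError: Column genre not found.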
replay/preprocessing/sessionizer.py
CHANGED

@@ -10,7 +10,6 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.window import Window


-# pylint: disable=too-many-instance-attributes, too-few-public-methods
 class Sessionizer:
     """
     Create and filter sessions from given interactions.

@@ -51,7 +50,6 @@ class Sessionizer:
     <BLANKLINE>
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         user_column: str = "user_id",

@@ -191,7 +189,6 @@ class Sessionizer:
                 Window.partitionBy(self.user_column).orderBy(sf.col(self.time_column), sf.col("timestamp_diff").desc())
             ),
         )
-        # data_with_sum_timediff.cache()

         grouped_users = data_with_sum_timediff.groupBy(self.user_column).count()
         grouped_users_with_cumsum = grouped_users.withColumn(

@@ -212,11 +209,9 @@ class Sessionizer:
             )
         )

-        # data_with_sum_timediff.unpersist()
         return result

     def _filter_sessions(self, interactions: DataFrameLike) -> DataFrameLike:
-        # interactions.cache()
         if isinstance(interactions, SparkDataFrame):
             return self._filter_sessions_spark(interactions)

@@ -254,8 +249,6 @@ class Sessionizer:
             entries_counter.select(self.session_column), self.session_column, how="right"
         )

-        # filtered_interactions.cache()
-
         nunique = filtered_interactions.groupby(self.user_column).agg(
             sf.expr("count(distinct session_id)").alias("nunique")
         )

@@ -284,9 +277,6 @@ class Sessionizer:
         result = self._filter_sessions(result)
         columns_order += [self.session_column]

-        if isinstance(result, SparkDataFrame):
-            result = result.select(*columns_order)
-        else:
-            result = result[columns_order]
+        result = result.select(*columns_order) if isinstance(result, SparkDataFrame) else result[columns_order]

         return result
replay/scenarios/fallback.py
CHANGED

@@ -1,16 +1,14 @@
-# pylint: disable=protected-access
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 from replay.data import Dataset
-from replay.preprocessing.filters import MinCountFilter
 from replay.metrics import NDCG, Metric
 from replay.models import PopRec
 from replay.models.base_rec import BaseRecommender
+from replay.preprocessing.filters import MinCountFilter
 from replay.utils import SparkDataFrame
 from replay.utils.spark_utils import fallback, get_unique_entities


-# pylint: disable=too-many-instance-attributes
 class Fallback(BaseRecommender):
     """Fill missing recommendations using fallback model.
     Behaves like a recommender and have the same interface."""

@@ -33,16 +31,15 @@ class Fallback(BaseRecommender):
         self.threshold = threshold
         self.hot_queries = None
         self.main_model = main_model
-        # pylint: disable=invalid-name
         self.fb_model = fallback_model

-    # …
+    # TODO: add save/load for scenarios
     @property
     def _init_args(self):
         return {"threshold": self.threshold}

     def __str__(self):
-        return f"Fallback_{str(self.main_model)}_{str(self.fb_model)}"
+        return f"Fallback_{self.main_model!s}_{self.fb_model!s}"

     def fit(
         self,

@@ -67,7 +64,6 @@ class Fallback(BaseRecommender):
         self._fit_wrap(hot_dataset)
         self.fb_model._fit_wrap(dataset)

-    # pylint: disable=too-many-arguments
     def predict(
         self,
         dataset: Dataset,

@@ -125,7 +121,6 @@ class Fallback(BaseRecommender):
         pred = fallback(hot_pred, cold_pred, k)
         return pred

-    # pylint: disable=too-many-arguments, too-many-locals
     def optimize(
         self,
         train_dataset: Dataset,
replay/splitters/base_splitter.py
CHANGED

@@ -1,4 +1,6 @@
+import json
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Optional, Tuple

 import polars as pl

@@ -7,19 +9,20 @@ from replay.utils import (
     PYSPARK_AVAILABLE,
     DataFrameLike,
     PandasDataFrame,
-    SparkDataFrame,
     PolarsDataFrame,
+    SparkDataFrame,
 )

 if PYSPARK_AVAILABLE:
-    from pyspark.sql import …
-    …
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )


 SplitterReturnType = Tuple[DataFrameLike, DataFrameLike]


-# pylint: disable=too-few-public-methods, too-many-instance-attributes
 class Splitter(ABC):
     """Base class"""

@@ -33,7 +36,6 @@ class Splitter(ABC):
         "session_id_processing_strategy",
     ]

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         drop_cold_items: bool = False,

@@ -68,17 +70,43 @@ class Splitter(ABC):
     def _init_args(self):
         return {name: getattr(self, name) for name in self._init_arg_names}

+    def save(self, path: str) -> None:
+        """
+        Method for saving splitter in `.replay` directory.
+        """
+        base_path = Path(path).with_suffix(".replay").resolve()
+        base_path.mkdir(parents=True, exist_ok=True)
+
+        splitter_dict = {}
+        splitter_dict["init_args"] = self._init_args
+        splitter_dict["_class_name"] = str(self)
+
+        with open(base_path / "init_args.json", "w+") as file:
+            json.dump(splitter_dict, file)
+
+    @classmethod
+    def load(cls, path: str) -> "Splitter":
+        """
+        Method for loading splitter from `.replay` directory.
+        """
+        base_path = Path(path).with_suffix(".replay").resolve()
+        with open(base_path / "init_args.json", "r") as file:
+            splitter_dict = json.loads(file.read())
+        splitter = cls(**splitter_dict["init_args"])
+
+        return splitter
+
     def __str__(self):
         return type(self).__name__

-    # pylint: disable=too-many-arguments
     def _drop_cold_items_and_users(
         self,
         train: DataFrameLike,
         test: DataFrameLike,
     ) -> DataFrameLike:
         if isinstance(train, type(test)) is False:
-            raise TypeError("Train and test dataframes must have consistent types")
+            msg = "Train and test dataframes must have consistent types"
+            raise TypeError(msg)

         if isinstance(test, SparkDataFrame):
             return self._drop_cold_items_and_users_from_spark(train, test)
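The new ``save``/``load`` pair persists only the constructor arguments (plus the class name) as ``init_args.json`` inside a ``<path>.replay`` directory. Note that ``load`` instantiates ``cls`` directly, so it should be called on the concrete splitter class rather than on the ``Splitter`` ABC. A sketch, with hypothetical constructor arguments for ``TimeSplitter``:

    from replay.splitters import TimeSplitter  # assumed re-export path

    splitter = TimeSplitter(time_threshold=0.8)  # hypothetical arguments, for illustration only
    splitter.save("checkpoints/my_splitter")     # writes checkpoints/my_splitter.replay/init_args.json

    restored = TimeSplitter.load("checkpoints/my_splitter")  # rebuilt from the stored init args
    assert str(restored) == "TimeSplitter"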
@@ -105,7 +133,6 @@ class Splitter(ABC):
         train: SparkDataFrame,
         test: SparkDataFrame,
     ) -> SparkDataFrame:
-
         if self.drop_cold_items:
             train_tmp = train.select(sf.col(self.item_column).alias("item")).distinct()
             test = test.join(train_tmp, train_tmp["item"] == test[self.item_column]).drop("item")

@@ -121,7 +148,6 @@ class Splitter(ABC):
         train: PolarsDataFrame,
         test: PolarsDataFrame,
     ) -> PolarsDataFrame:
-
         if self.drop_cold_items:
             train_tmp = train.select(self.item_column).unique()
             test = test.join(train_tmp, on=self.item_column)

@@ -164,9 +190,9 @@ class Splitter(ABC):
     def _recalculate_with_session_id_column_pandas(self, data: PandasDataFrame) -> PandasDataFrame:
         agg_function_name = "first" if self.session_id_processing_strategy == "train" else "last"
         res = data.copy()
-        res["is_test"] = res.groupby(
-            …
-        )
+        res["is_test"] = res.groupby([self.query_column, self.session_id_column])["is_test"].transform(
+            agg_function_name
+        )

         return res

@@ -176,7 +202,7 @@ class Splitter(ABC):
             "is_test",
             agg_function("is_test").over(
                 Window.orderBy(self.timestamp_column)
-                .partitionBy(self.query_column, self.session_id_column)
+                .partitionBy(self.query_column, self.session_id_column)
                 .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
             ),
         )

@@ -186,7 +212,9 @@ class Splitter(ABC):
     def _recalculate_with_session_id_column_polars(self, data: PolarsDataFrame) -> PolarsDataFrame:
         agg_function = pl.Expr.first if self.session_id_processing_strategy == "train" else pl.Expr.last
         res = data.with_columns(
-            agg_function(pl.col("is_test").sort_by(self.timestamp_column)).over(…)
-        )
+            agg_function(pl.col("is_test").sort_by(self.timestamp_column)).over(
+                [self.query_column, self.session_id_column]
+            )
+        )

         return res