replay-rec 0.18.0__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +27 -1
- replay/data/dataset_utils/dataset_label_encoder.py +6 -3
- replay/data/nn/schema.py +37 -16
- replay/data/nn/sequence_tokenizer.py +313 -165
- replay/data/nn/torch_sequential_dataset.py +17 -8
- replay/data/nn/utils.py +14 -7
- replay/data/schema.py +10 -6
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +1 -0
- replay/models/base_rec.py +18 -21
- replay/models/lin_ucb.py +407 -0
- replay/models/nn/sequential/bert4rec/dataset.py +17 -4
- replay/models/nn/sequential/bert4rec/lightning.py +121 -54
- replay/models/nn/sequential/bert4rec/model.py +21 -0
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
- replay/models/nn/sequential/compiled/__init__.py +5 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
- replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
- replay/models/nn/sequential/sasrec/dataset.py +17 -1
- replay/models/nn/sequential/sasrec/lightning.py +126 -50
- replay/models/nn/sequential/sasrec/model.py +3 -4
- replay/preprocessing/__init__.py +7 -1
- replay/preprocessing/discretizer.py +719 -0
- replay/preprocessing/label_encoder.py +384 -52
- replay/splitters/cold_user_random_splitter.py +1 -1
- replay/utils/__init__.py +1 -0
- replay/utils/common.py +7 -8
- replay/utils/session_handler.py +3 -4
- replay/utils/spark_utils.py +15 -1
- replay/utils/types.py +8 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
replay/preprocessing/discretizer.py
@@ -0,0 +1,719 @@
import abc
import json
import os
import warnings
from pathlib import Path
from typing import Dict, List, Literal, Sequence

import numpy as np
import polars as pl
from sklearn.preprocessing import KBinsDiscretizer

from replay.utils import (
    PYSPARK_AVAILABLE,
    DataFrameLike,
    PandasDataFrame,
    PolarsDataFrame,
    SparkDataFrame,
)

if PYSPARK_AVAILABLE:  # pragma: no cover
    from pyspark.ml.feature import Bucketizer, QuantileDiscretizer
    from pyspark.sql.functions import isnan

HandleInvalidStrategies = Literal["error", "skip", "keep"]


class BaseDiscretizingRule(abc.ABC):  # pragma: no cover
    """
    Interface of the discretizing rule
    """

    @property
    @abc.abstractmethod
    def column(self) -> str:
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def n_bins(self) -> int:
        raise NotImplementedError()

    @abc.abstractmethod
    def fit(self, df: DataFrameLike) -> "BaseDiscretizingRule":
        raise NotImplementedError()

    @abc.abstractmethod
    def partial_fit(self, df: DataFrameLike) -> "BaseDiscretizingRule":
        raise NotImplementedError()

    @abc.abstractmethod
    def transform(self, df: DataFrameLike) -> DataFrameLike:
        raise NotImplementedError()

    @abc.abstractmethod
    def set_handle_invalid(self, handle_invalid: HandleInvalidStrategies) -> None:
        raise NotImplementedError()

    def fit_transform(self, df: DataFrameLike) -> DataFrameLike:
        return self.fit(df).transform(df)

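The abstract class above fixes the contract that every rule must satisfy: column, n_bins, fit, partial_fit, transform and set_handle_invalid. A hypothetical minimal rule, not part of the package and shown only to illustrate that contract (pandas-only, equal-width bins, no NaN handling), might look like this:

# Hypothetical illustration of the BaseDiscretizingRule contract; not part of replay.
class EqualWidthDiscretizingRule(BaseDiscretizingRule):
    def __init__(self, column: str, n_bins: int) -> None:
        self._col = column
        self._n_bins = n_bins
        self._edges = None

    @property
    def column(self) -> str:
        return self._col

    @property
    def n_bins(self) -> int:
        return self._n_bins

    def fit(self, df: PandasDataFrame) -> "EqualWidthDiscretizingRule":
        lo, hi = df[self._col].min(), df[self._col].max()
        self._edges = np.linspace(lo, hi, self._n_bins + 1)[1:-1]  # inner cut points
        return self

    def partial_fit(self, df: PandasDataFrame) -> "EqualWidthDiscretizingRule":
        return self.fit(df)  # this toy rule simply refits from scratch

    def transform(self, df: PandasDataFrame) -> PandasDataFrame:
        out = df.copy()
        out[self._col] = np.digitize(df[self._col].values, self._edges)  # bin ids 0..n_bins-1
        return out

    def set_handle_invalid(self, handle_invalid: HandleInvalidStrategies) -> None:
        pass  # the toy rule ignores invalid values
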
class GreedyDiscretizingRule(BaseDiscretizingRule):
    """
    Implementation of a discretizing rule for a column of a PySpark, Polars or Pandas DataFrame.
    Discretizes column values according to the greedy binning strategy used by LightGBM:
    https://github.com/microsoft/LightGBM/blob/master/src/io/bin.cpp#L78
    It is recommended to use it together with the Discretizer.
    """

    _HANDLE_INVALID_STRATEGIES = ("error", "skip", "keep")

    def __init__(
        self,
        column: str,
        n_bins: int,
        min_data_in_bin: int = 1,
        handle_invalid: HandleInvalidStrategies = "keep",
    ) -> None:
        """
        :param column: Name of the column to discretize.
        :param n_bins: Number of intervals into which the data will be binned.
        :param min_data_in_bin: Minimum number of samples in one bin.
        :param handle_invalid: Strategy for handling NaN values in the data.
            If ``skip`` - filter out rows with invalid values.
            If ``error`` - throw an error.
            If ``keep`` - keep invalid values in a special additional bucket with number = n_bins.
            Default: ``keep``.
        """
        self._n_bins = n_bins
        self._col = column
        self._min_data_in_bin = min_data_in_bin
        self._bins = None
        self._is_fitted = False

        if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
            msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
            raise ValueError(msg)
        self._handle_invalid = handle_invalid

    @property
    def column(self) -> str:
        return self._col

    @property
    def n_bins(self) -> int:
        return self._n_bins
    def _greedy_bin_find(
        self,
        distinct_values: np.ndarray,
        counts: np.ndarray,
        num_distinct_values: int,
        max_bin: int,
        total_cnt: int,
        min_data_in_bin: int,
    ) -> List[float]:
        """
        Computes the upper bounds of the bins.

        :param distinct_values: Array of unique values.
        :param counts: Number of samples corresponding to each unique value.
        :param num_distinct_values: Number of unique values.
        :param max_bin: Maximum number of bins.
        :param total_cnt: Total number of samples.
        :param min_data_in_bin: Minimum number of samples in one bin.
        """
        bin_upper_bound = []
        assert max_bin > 0

        if total_cnt < max_bin * min_data_in_bin:
            warn_msg = (
                f"Expected at least {max_bin * min_data_in_bin} samples (n_bins*min_data_in_bin = "
                f"{self._n_bins}*{min_data_in_bin}). Got {total_cnt}. "
                "The resulting number of bins will be smaller."
            )
            warnings.warn(warn_msg)
        if num_distinct_values <= max_bin:
            cur_cnt_inbin = 0
            for i in range(num_distinct_values - 1):
                cur_cnt_inbin += counts[i]
                if cur_cnt_inbin >= min_data_in_bin:
                    bin_upper_bound.append((distinct_values[i] + distinct_values[i + 1]) / 2.0)
                    cur_cnt_inbin = 0

            cur_cnt_inbin += counts[num_distinct_values - 1]
            bin_upper_bound.append(float("Inf"))

        else:
            if min_data_in_bin > 0:
                max_bin = min(max_bin, total_cnt // min_data_in_bin)
                max_bin = max(max_bin, 1)
            mean_bin_size = total_cnt / max_bin
            rest_bin_cnt = max_bin
            rest_sample_cnt = total_cnt

            is_big_count_value = counts >= mean_bin_size
            rest_bin_cnt -= np.sum(is_big_count_value)
            rest_sample_cnt -= np.sum(counts[is_big_count_value])

            mean_bin_size = rest_sample_cnt / rest_bin_cnt
            upper_bounds = [float("Inf")] * max_bin
            lower_bounds = [float("Inf")] * max_bin

            bin_cnt = 0
            lower_bounds[bin_cnt] = distinct_values[0]
            cur_cnt_inbin = 0

            for i in range(num_distinct_values - 1):  # pragma: no cover
                if not is_big_count_value[i]:
                    rest_sample_cnt -= counts[i]

                cur_cnt_inbin += counts[i]

                if (
                    is_big_count_value[i]
                    or cur_cnt_inbin >= mean_bin_size
                    or is_big_count_value[i + 1]
                    and cur_cnt_inbin >= max(1.0, mean_bin_size * 0.5)
                ):
                    upper_bounds[bin_cnt] = distinct_values[i]
                    bin_cnt += 1
                    lower_bounds[bin_cnt] = distinct_values[i + 1]
                    if bin_cnt >= max_bin - 1:
                        break
                    cur_cnt_inbin = 0
                    if not is_big_count_value[i]:
                        rest_bin_cnt -= 1
                        mean_bin_size = rest_sample_cnt / rest_bin_cnt

            bin_upper_bound = [(upper_bounds[i] + lower_bounds[i + 1]) / 2.0 for i in range(bin_cnt - 1)]
            bin_upper_bound.append(float("Inf"))
        return bin_upper_bound
    def _fit_spark(self, df: SparkDataFrame) -> None:
        warn_msg = "DataFrame will be partially converted to the Pandas type during internal calculations in 'fit'"
        warnings.warn(warn_msg)
        value_counts = df.groupBy(self._col).count().orderBy(self._col).toPandas()
        bins = [-float("inf")]
        bins += self._greedy_bin_find(
            value_counts[self._col].values,
            value_counts["count"].values,
            value_counts.shape[0],
            self._n_bins + 1,
            df.count(),
            self._min_data_in_bin,
        )
        self._bins = bins

    def _fit_pandas(self, df: PandasDataFrame) -> None:
        vc = df[self._col].value_counts().sort_index()
        bins = self._greedy_bin_find(
            vc.index.values, vc.values, len(vc), self._n_bins + 1, vc.sum(), self._min_data_in_bin
        )
        self._bins = [-np.inf, *bins]

    def _fit_polars(self, df: PolarsDataFrame) -> None:
        warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'fit'"
        warnings.warn(warn_msg)
        self._fit_pandas(df.to_pandas())

    def fit(self, df: DataFrameLike) -> "GreedyDiscretizingRule":
        """
        Fits the discretizing rule to the input dataframe.

        :param df: input dataframe.
        :returns: fitted DiscretizingRule.
        """
        if self._is_fitted:
            return self

        df = self._validate_input(df)

        if isinstance(df, PandasDataFrame):
            self._fit_pandas(df)
        elif isinstance(df, SparkDataFrame):
            self._fit_spark(df)
        else:
            self._fit_polars(df)

        self._is_fitted = True
        return self

    def partial_fit(self, df: DataFrameLike) -> "GreedyDiscretizingRule":
        """
        Fits new data to an already fitted DiscretizingRule.

        :param df: input dataframe.
        :returns: fitted DiscretizingRule.
        """
        if not self._is_fitted:
            return self.fit(df)

        msg = f"{self.__class__.__name__} is not implemented for partial_fit yet."
        raise NotImplementedError(msg)
    def _transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
        binned_column = np.digitize(df[self._col].values, self._bins)
        binned_column -= 1
        df_transformed = df.copy()
        df_transformed.loc[:, self._col] = binned_column
        return df_transformed

    def _transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
        target_col = self._col + "_discretized"
        bucketizer = Bucketizer(
            splits=self._bins, inputCol=self._col, outputCol=target_col, handleInvalid=self._handle_invalid
        )
        return bucketizer.transform(df).drop(self._col).withColumnRenamed(target_col, self._col)

    def _transform_polars(self, df: PolarsDataFrame) -> PolarsDataFrame:
        warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'transform'"
        warnings.warn(warn_msg)
        return pl.from_pandas(self._transform_pandas(df.to_pandas()))

    def transform(self, df: DataFrameLike) -> DataFrameLike:
        """
        Transforms the input dataframe with the fitted DiscretizingRule.

        :param df: input dataframe.
        :returns: transformed dataframe.
        """
        if not self._is_fitted:
            msg = "Discretizer is not fitted"
            raise RuntimeError(msg)

        df = self._validate_input(df)

        if isinstance(df, PandasDataFrame):
            transformed_df = self._transform_pandas(df)
        elif isinstance(df, SparkDataFrame):
            transformed_df = self._transform_spark(df)
        else:
            transformed_df = self._transform_polars(df)
        return transformed_df

    def set_handle_invalid(self, handle_invalid: HandleInvalidStrategies) -> None:
        """
        Sets the strategy to handle invalid values.

        :param handle_invalid: handle invalid strategy.
        """
        if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
            msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
            raise ValueError(msg)
        self._handle_invalid = handle_invalid
    def _validate_input(self, df: DataFrameLike) -> DataFrameLike:
        if isinstance(df, PandasDataFrame):
            df_val = df.copy()
            if (self._handle_invalid == "error") and (df_val[self._col].isna().sum() > 0):
                msg = (
                    "Data contains NaN. 'handle_invalid' param equals 'error'. "
                    "Set 'keep' or 'skip' for processing NaN."
                )
                raise ValueError(msg)
            if self._handle_invalid == "skip":
                df_val = df_val.dropna(subset=[self._col], axis=0)
            return df_val

        elif isinstance(df, SparkDataFrame):
            if (self._handle_invalid == "error") and (df.filter(isnan(df[self._col])).count() > 0):
                msg = (
                    "Data contains NaN. 'handle_invalid' param equals 'error'. "
                    "Set 'keep' or 'skip' for processing NaN."
                )
                raise ValueError(msg)
            return df

        elif isinstance(df, PolarsDataFrame):
            if (self._handle_invalid == "error") and (df[self._col].is_null().sum() > 0):
                msg = (
                    "Data contains NaN. 'handle_invalid' param equals 'error'. "
                    "Set 'keep' or 'skip' for processing NaN."
                )
                raise ValueError(msg)
            if self._handle_invalid == "skip":
                df = df.clone().fill_nan(None).drop_nulls(subset=[self._col])
            return df

        else:
            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
            raise NotImplementedError(msg)

    def save(
        self,
        path: str,
    ) -> None:
        discretizer_rule_dict = {}
        discretizer_rule_dict["_class_name"] = self.__class__.__name__
        discretizer_rule_dict["init_args"] = {
            "n_bins": self._n_bins,
            "column": self._col,
            "min_data_in_bin": self._min_data_in_bin,
            "handle_invalid": self._handle_invalid,
        }
        discretizer_rule_dict["fitted_args"] = {
            "bins": self._bins,
            "is_fitted": self._is_fitted,
        }

        base_path = Path(path).with_suffix(".replay").resolve()

        if os.path.exists(base_path):  # pragma: no cover
            msg = "There is already a DiscretizingRule object saved at the given path. The file will be overwritten."
            warnings.warn(msg)
        else:  # pragma: no cover
            base_path.mkdir(parents=True, exist_ok=True)

        with open(base_path / "init_args.json", "w+") as file:
            json.dump(discretizer_rule_dict, file)

    @classmethod
    def load(cls, path: str) -> "GreedyDiscretizingRule":
        base_path = Path(path).with_suffix(".replay").resolve()
        with open(base_path / "init_args.json", "r") as file:
            discretizer_rule_dict = json.loads(file.read())

        discretizer_rule = cls(**discretizer_rule_dict["init_args"])
        discretizer_rule._bins = discretizer_rule_dict["fitted_args"]["bins"]
        discretizer_rule._is_fitted = discretizer_rule_dict["fitted_args"]["is_fitted"]
        return discretizer_rule

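A minimal usage sketch for GreedyDiscretizingRule on a Pandas DataFrame; the column name and data are made up, and the import assumes the rule is re-exported from replay.preprocessing (otherwise import it from replay.preprocessing.discretizer):

import pandas as pd

from replay.preprocessing import GreedyDiscretizingRule  # assumed re-export

df = pd.DataFrame({"price": [1.0, 2.5, 3.0, 7.5, 10.0, 12.0, 15.0, 20.0]})

rule = GreedyDiscretizingRule(column="price", n_bins=4, min_data_in_bin=1)
binned = rule.fit_transform(df)  # "price" now holds bin ids in [0, n_bins - 1]

rule.save("./greedy_price_rule")  # writes greedy_price_rule.replay/init_args.json
restored = GreedyDiscretizingRule.load("./greedy_price_rule")
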
class QuantileDiscretizingRule(BaseDiscretizingRule):
    """
    Implementation of a discretizing rule for a column of a PySpark, Polars or Pandas DataFrame.
    Discretizes column values according to the quantile strategy: the data is distributed
    into buckets of approximately the same size.
    It is recommended to use it together with the Discretizer.
    """

    _HANDLE_INVALID_STRATEGIES = ("error", "skip", "keep")

    def __init__(
        self,
        column: str,
        n_bins: int,
        handle_invalid: HandleInvalidStrategies = "keep",
    ) -> None:
        """
        :param column: Name of the column to discretize.
        :param n_bins: Number of intervals into which the data will be binned.
        :param handle_invalid: Strategy for handling NaN values in the data.
            If ``skip`` - filter out rows with invalid values.
            If ``error`` - throw an error.
            If ``keep`` - keep invalid values in a special additional bucket with number = n_bins.
            Default: ``keep``.
        """
        self._n_bins = n_bins
        self._col = column
        self._bins = None
        self._discretizer = None
        self._is_fitted = False

        if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
            msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
            raise ValueError(msg)
        self._handle_invalid = handle_invalid

    @property
    def column(self) -> str:
        return self._col

    @property
    def n_bins(self) -> int:
        return self._n_bins

    def _fit_spark(self, df: SparkDataFrame) -> None:
        discretizer = QuantileDiscretizer(
            numBuckets=self._n_bins, inputCol=self._col, handleInvalid=self._handle_invalid
        )
        self._discretizer = discretizer.fit(df)
        self._bins = self._discretizer.getSplits()

    def _fit_pandas(self, df: PandasDataFrame) -> None:
        discretizer = KBinsDiscretizer(n_bins=self._n_bins, encode="ordinal", strategy="quantile")
        if self._handle_invalid == "keep":
            self._discretizer = discretizer.fit(df.dropna(subset=[self._col], axis=0)[[self._col]])
        else:
            self._discretizer = discretizer.fit(df[[self._col]])
        self._bins = self._discretizer.bin_edges_[0].astype(float).tolist()
        self._bins[0] = -np.inf
        self._bins[-1] = np.inf

    def _fit_polars(self, df: PolarsDataFrame) -> None:
        warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'fit'"
        warnings.warn(warn_msg)
        self._fit_pandas(df.to_pandas())
    def fit(self, df: DataFrameLike) -> "QuantileDiscretizingRule":
        """
        Fits the discretizing rule to the input dataframe.

        :param df: input dataframe.
        :returns: fitted DiscretizingRule.
        """
        if self._is_fitted:
            return self

        df = self._validate_input(df)

        if isinstance(df, PandasDataFrame):
            self._fit_pandas(df)
        elif isinstance(df, SparkDataFrame):
            self._fit_spark(df)
        else:
            self._fit_polars(df)

        self._is_fitted = True
        return self

    def partial_fit(self, df: DataFrameLike) -> "QuantileDiscretizingRule":
        """
        Fits new data to an already fitted DiscretizingRule.

        :param df: input dataframe.
        :returns: fitted DiscretizingRule.
        """
        if not self._is_fitted:
            return self.fit(df)

        msg = f"{self.__class__.__name__} is not implemented for partial_fit yet."
        raise NotImplementedError(msg)

    def _transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
        df_nan_part = df[df[self._col].isna()]
        df_real_part = df[~df[self._col].isna()]

        binned_column = np.digitize(df_real_part[self._col].values, self._bins)
        binned_column -= 1

        df_transformed = df.copy()
        df_transformed.loc[df_real_part.index, self._col] = binned_column
        df_transformed.loc[df_nan_part.index, self._col] = [self._n_bins] * len(df_nan_part)
        return df_transformed

    def _transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
        target_col = self._col + "_discretized"
        bucketizer = Bucketizer(
            splits=self._bins, inputCol=self._col, outputCol=target_col, handleInvalid=self._handle_invalid
        )
        return bucketizer.transform(df).drop(self._col).withColumnRenamed(target_col, self._col)

    def _transform_polars(self, df: PolarsDataFrame) -> PolarsDataFrame:
        warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'transform'"
        warnings.warn(warn_msg)
        return pl.from_pandas(self._transform_pandas(df.to_pandas()))

    def transform(self, df: DataFrameLike) -> DataFrameLike:
        """
        Transforms the input dataframe with the fitted DiscretizingRule.

        :param df: input dataframe.
        :returns: transformed dataframe.
        """
        if not self._is_fitted:
            msg = "Discretizer is not fitted"
            raise RuntimeError(msg)

        df = self._validate_input(df)

        if isinstance(df, PandasDataFrame):
            transformed_df = self._transform_pandas(df)
        elif isinstance(df, SparkDataFrame):
            transformed_df = self._transform_spark(df)
        else:
            transformed_df = self._transform_polars(df)
        return transformed_df
    def set_handle_invalid(self, handle_invalid: HandleInvalidStrategies) -> None:
        """
        Sets the strategy to handle invalid values.

        :param handle_invalid: handle invalid strategy.
        """
        if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
            msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
            raise ValueError(msg)
        self._handle_invalid = handle_invalid

    def _validate_input(self, df: DataFrameLike) -> DataFrameLike:
        if isinstance(df, PandasDataFrame):
            if (self._handle_invalid == "error") and (df[self._col].isna().sum() > 0):
                msg = (
                    "Data contains NaN. 'handle_invalid' param equals 'error'. "
                    "Set 'keep' or 'skip' for processing NaN."
                )
                raise ValueError(msg)
            if self._handle_invalid == "skip":
                df = df.copy().dropna(subset=[self._col], axis=0)

        elif isinstance(df, SparkDataFrame):
            if (self._handle_invalid == "error") and (df.filter(isnan(df[self._col])).count() > 0):
                msg = (
                    "Data contains NaN. 'handle_invalid' param equals 'error'. "
                    "Set 'keep' or 'skip' for processing NaN."
                )
                raise ValueError(msg)
            if self._handle_invalid == "skip":
                df = df.dropna(subset=[self._col])

        elif isinstance(df, PolarsDataFrame):
            if (self._handle_invalid == "error") and (df[self._col].is_null().sum() > 0):
                msg = (
                    "Data contains NaN. 'handle_invalid' param equals 'error'. "
                    "Set 'keep' or 'skip' for processing NaN."
                )
                raise ValueError(msg)
            if self._handle_invalid == "skip":
                df = df.clone().fill_nan(None).drop_nulls(subset=[self._col])

        else:
            msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
            raise NotImplementedError(msg)
        return df

    def save(
        self,
        path: str,
    ) -> None:
        discretizer_rule_dict = {}
        discretizer_rule_dict["_class_name"] = self.__class__.__name__
        discretizer_rule_dict["init_args"] = {
            "n_bins": self._n_bins,
            "column": self._col,
            "handle_invalid": self._handle_invalid,
        }
        discretizer_rule_dict["fitted_args"] = {
            "bins": self._bins,
            "is_fitted": self._is_fitted,
        }

        base_path = Path(path).with_suffix(".replay").resolve()

        if os.path.exists(base_path):  # pragma: no cover
            msg = "There is already a DiscretizingRule object saved at the given path. The file will be overwritten."
            warnings.warn(msg)
        else:  # pragma: no cover
            base_path.mkdir(parents=True, exist_ok=True)

        with open(base_path / "init_args.json", "w+") as file:
            json.dump(discretizer_rule_dict, file)

    @classmethod
    def load(cls, path: str) -> "QuantileDiscretizingRule":
        base_path = Path(path).with_suffix(".replay").resolve()
        with open(base_path / "init_args.json", "r") as file:
            discretizer_rule_dict = json.loads(file.read())

        discretizer_rule = cls(**discretizer_rule_dict["init_args"])
        discretizer_rule._bins = discretizer_rule_dict["fitted_args"]["bins"]
        discretizer_rule._is_fitted = discretizer_rule_dict["fitted_args"]["is_fitted"]
        return discretizer_rule

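A similar sketch for QuantileDiscretizingRule, showing how handle_invalid="keep" routes NaN rows into the extra bucket; the column name, data and import path are again assumptions:

import numpy as np
import pandas as pd

from replay.preprocessing import QuantileDiscretizingRule  # assumed re-export

df = pd.DataFrame({"age": [18.0, 22.0, 25.0, 31.0, 40.0, 47.0, 52.0, np.nan]})

rule = QuantileDiscretizingRule(column="age", n_bins=4, handle_invalid="keep")
binned = rule.fit_transform(df)
# rows with a value get bin ids 0..3; the NaN row gets the extra bucket id 4 (== n_bins)
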
class Discretizer:
    """
    Applies multiple discretizing rules to a dataframe.
    Every sample is assigned a bucket number from the set [0, 1, ..., n_bins - 1].
    """

    def __init__(self, rules: Sequence[BaseDiscretizingRule]):
        """
        :param rules: Sequence of rules.
        """
        self.rules = rules

    def fit(self, df: DataFrameLike) -> "Discretizer":
        """
        Fits the Discretizer to the input dataframe with the given rules.

        :param df: input dataframe.
        :returns: fitted Discretizer.
        """
        for rule in self.rules:
            rule.fit(df)
        return self

    def partial_fit(self, df: DataFrameLike) -> "Discretizer":
        """
        Fits an already fitted Discretizer to a new input dataframe with the given rules.
        If the Discretizer has not been fitted yet, a regular fit is performed.

        :param df: input dataframe.
        :returns: fitted Discretizer.
        """
        for rule in self.rules:
            rule.partial_fit(df)
        return self

    def transform(self, df: DataFrameLike) -> DataFrameLike:
        """
        Transforms the input dataframe.
        If the input dataframe contains NaN values, they are processed according to the handle_invalid strategy.

        :param df: input dataframe.
        :returns: transformed dataframe.
        """
        for rule in self.rules:
            df = rule.transform(df)
        return df

    def fit_transform(self, df: DataFrameLike) -> DataFrameLike:
        """
        Fits the Discretizer to the input dataframe with the given rules and transforms the same dataframe.

        :param df: input dataframe.
        :returns: transformed dataframe.
        """
        return self.fit(df).transform(df)
    def set_handle_invalid(self, handle_invalid_rules: Dict[str, HandleInvalidStrategies]) -> None:
        """
        Modifies the handle_invalid strategy of an already fitted Discretizer.

        :param handle_invalid_rules: Mapping from column name to handle_invalid strategy.

            Example: {"item_id": "keep", "user_id": "skip", "category_column": "error"}

            If ``skip`` - filter out rows with invalid values.
            If ``error`` - throw an error.
            If ``keep`` - keep invalid values in a special additional bucket with number = n_bins.
            Default: ``keep``.
        """
        columns = [i.column for i in self.rules]
        for column, handle_invalid in handle_invalid_rules.items():
            if column not in columns:
                msg = f"Column {column} not found."
                raise ValueError(msg)
            rule = list(filter(lambda x: x.column == column, self.rules))
            rule[0].set_handle_invalid(handle_invalid)

    def save(
        self,
        path: str,
    ) -> None:
        discretizer_dict = {}
        discretizer_dict["_class_name"] = self.__class__.__name__

        base_path = Path(path).with_suffix(".replay").resolve()
        if os.path.exists(base_path):  # pragma: no cover
            msg = "There is already a Discretizer object saved at the given path. The file will be overwritten."
            warnings.warn(msg)
        else:  # pragma: no cover
            base_path.mkdir(parents=True, exist_ok=True)

        discretizer_dict["rule_names"] = []

        for rule in self.rules:
            path_suffix = f"{rule.__class__.__name__}_{rule.column}"
            rule.save(str(base_path) + f"/rules/{path_suffix}")
            discretizer_dict["rule_names"].append(path_suffix)

        with open(base_path / "init_args.json", "w+") as file:
            json.dump(discretizer_dict, file)

    @classmethod
    def load(cls, path: str) -> "Discretizer":
        base_path = Path(path).with_suffix(".replay").resolve()
        with open(base_path / "init_args.json", "r") as file:
            discretizer_dict = json.loads(file.read())
        rules = []
        for root, dirs, files in os.walk(str(base_path) + "/rules/"):
            for d in dirs:
                if d.split(".")[0] in discretizer_dict["rule_names"]:
                    with open(root + d + "/init_args.json", "r") as file:
                        discretizer_rule_dict = json.loads(file.read())
                    rules.append(globals()[discretizer_rule_dict["_class_name"]].load(root + d))

        discretizer = cls(rules=rules)
        return discretizer
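Finally, a sketch of composing several rules in a Discretizer and changing the per-column NaN strategy on the fitted object; column names, paths and the replay.preprocessing import are assumptions:

import pandas as pd

from replay.preprocessing import (  # assumed re-exports
    Discretizer,
    GreedyDiscretizingRule,
    QuantileDiscretizingRule,
)

df = pd.DataFrame(
    {
        "price": [1.0, 2.5, 3.0, 7.5, 10.0, 12.0],
        "age": [18.0, 22.0, 25.0, 31.0, 40.0, 47.0],
    }
)

discretizer = Discretizer(
    rules=[
        GreedyDiscretizingRule(column="price", n_bins=3),
        QuantileDiscretizingRule(column="age", n_bins=3),
    ]
)
binned = discretizer.fit_transform(df)

# The per-column NaN strategy can be changed on the already fitted object
discretizer.set_handle_invalid({"price": "skip", "age": "keep"})

discretizer.save("./discretizer")  # creates discretizer.replay/ with one subfolder per rule
restored = Discretizer.load("./discretizer")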