replay-rec 0.18.0__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +27 -1
  3. replay/data/dataset_utils/dataset_label_encoder.py +6 -3
  4. replay/data/nn/schema.py +37 -16
  5. replay/data/nn/sequence_tokenizer.py +313 -165
  6. replay/data/nn/torch_sequential_dataset.py +17 -8
  7. replay/data/nn/utils.py +14 -7
  8. replay/data/schema.py +10 -6
  9. replay/metrics/offline_metrics.py +2 -2
  10. replay/models/__init__.py +1 -0
  11. replay/models/base_rec.py +18 -21
  12. replay/models/lin_ucb.py +407 -0
  13. replay/models/nn/sequential/bert4rec/dataset.py +17 -4
  14. replay/models/nn/sequential/bert4rec/lightning.py +121 -54
  15. replay/models/nn/sequential/bert4rec/model.py +21 -0
  16. replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
  17. replay/models/nn/sequential/compiled/__init__.py +5 -0
  18. replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
  19. replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
  20. replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
  21. replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
  22. replay/models/nn/sequential/sasrec/dataset.py +17 -1
  23. replay/models/nn/sequential/sasrec/lightning.py +126 -50
  24. replay/models/nn/sequential/sasrec/model.py +3 -4
  25. replay/preprocessing/__init__.py +7 -1
  26. replay/preprocessing/discretizer.py +719 -0
  27. replay/preprocessing/label_encoder.py +384 -52
  28. replay/splitters/cold_user_random_splitter.py +1 -1
  29. replay/utils/__init__.py +1 -0
  30. replay/utils/common.py +7 -8
  31. replay/utils/session_handler.py +3 -4
  32. replay/utils/spark_utils.py +15 -1
  33. replay/utils/types.py +8 -0
  34. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
  35. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
  36. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
  37. {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
replay/preprocessing/discretizer.py
@@ -0,0 +1,719 @@
+ import abc
+ import json
+ import os
+ import warnings
+ from pathlib import Path
+ from typing import Dict, List, Literal, Sequence
+
+ import numpy as np
+ import polars as pl
+ from sklearn.preprocessing import KBinsDiscretizer
+
+ from replay.utils import (
+     PYSPARK_AVAILABLE,
+     DataFrameLike,
+     PandasDataFrame,
+     PolarsDataFrame,
+     SparkDataFrame,
+ )
+
+ if PYSPARK_AVAILABLE:  # pragma: no cover
+     from pyspark.ml.feature import Bucketizer, QuantileDiscretizer
+     from pyspark.sql.functions import isnan
+
+ HandleInvalidStrategies = Literal["error", "skip", "keep"]
+
+
+ class BaseDiscretizingRule(abc.ABC):  # pragma: no cover
+     """
+     Interface of the discretizing rule.
+     """
+
+     @property
+     @abc.abstractmethod
+     def column(self) -> str:
+         raise NotImplementedError()
+
+     @property
+     @abc.abstractmethod
+     def n_bins(self) -> int:
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def fit(self, df: DataFrameLike) -> "BaseDiscretizingRule":
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def partial_fit(self, df: DataFrameLike) -> "BaseDiscretizingRule":
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def transform(self, df: DataFrameLike) -> DataFrameLike:
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def set_handle_invalid(self, handle_invalid: HandleInvalidStrategies) -> None:
+         raise NotImplementedError()
+
+     def fit_transform(self, df: DataFrameLike) -> DataFrameLike:
+         return self.fit(df).transform(df)
+
+
+ class GreedyDiscretizingRule(BaseDiscretizingRule):
+     """
+     Implementation of the discretizing rule for a column of PySpark, Polars and Pandas DataFrames.
+     Discretizes column values according to the Greedy binning strategy:
+     https://github.com/microsoft/LightGBM/blob/master/src/io/bin.cpp#L78
+     It is recommended to use it together with the Discretizer.
+     """
+
+     _HANDLE_INVALID_STRATEGIES = ("error", "skip", "keep")
+
+     def __init__(
+         self,
+         column: str,
+         n_bins: int,
+         min_data_in_bin: int = 1,
+         handle_invalid: HandleInvalidStrategies = "keep",
+     ) -> None:
+         """
+         :param column: Name of the column to discretize.
+         :param n_bins: Number of intervals into which the data will be binned.
+         :param min_data_in_bin: Minimum number of samples in one bin.
+         :param handle_invalid: handle_invalid rule;
+             indicates how to process NaN values in the data.
+             If ``skip`` - filter out rows with invalid values.
+             If ``error`` - throw an error.
+             If ``keep`` - keep invalid values in a special additional bucket with number = n_bins.
+             Default: ``keep``.
+         """
+         self._n_bins = n_bins
+         self._col = column
+         self._min_data_in_bin = min_data_in_bin
+         self._bins = None
+         self._is_fitted = False
+
+         if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
+             msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
+             raise ValueError(msg)
+         self._handle_invalid = handle_invalid
+
+     @property
+     def column(self) -> str:
+         return self._col
+
+     @property
+     def n_bins(self) -> int:
+         return self._n_bins
+
+     def _greedy_bin_find(
+         self,
+         distinct_values: np.ndarray,
+         counts: np.ndarray,
+         num_distinct_values: int,
+         max_bin: int,
+         total_cnt: int,
+         min_data_in_bin: int,
+     ) -> List[float]:
+         """
+         Computes upper bounds for the bins.
+
+         :param distinct_values: Array of unique values.
+         :param counts: Number of samples corresponding to every unique value.
+         :param num_distinct_values: Number of unique values.
+         :param max_bin: Maximum bin number.
+         :param total_cnt: Total number of samples.
+         :param min_data_in_bin: Minimum number of samples in one bin.
+         """
+         bin_upper_bound = []
+         assert max_bin > 0
+
+         if total_cnt < max_bin * min_data_in_bin:
+             warn_msg = f"Expected at least {max_bin*min_data_in_bin} samples (n_bins*min_data_in_bin) \
+ = ({self._n_bins}*{min_data_in_bin}). Got {total_cnt}. The resulting number of bins will be smaller."
+             warnings.warn(warn_msg)
+         if num_distinct_values <= max_bin:
+             cur_cnt_inbin = 0
+             for i in range(num_distinct_values - 1):
+                 cur_cnt_inbin += counts[i]
+                 if cur_cnt_inbin >= min_data_in_bin:
+                     bin_upper_bound.append((distinct_values[i] + distinct_values[i + 1]) / 2.0)
+                     cur_cnt_inbin = 0
+
+             cur_cnt_inbin += counts[num_distinct_values - 1]
+             bin_upper_bound.append(float("Inf"))
+
+         else:
+             if min_data_in_bin > 0:
+                 max_bin = min(max_bin, total_cnt // min_data_in_bin)
+                 max_bin = max(max_bin, 1)
+             mean_bin_size = total_cnt / max_bin
+             rest_bin_cnt = max_bin
+             rest_sample_cnt = total_cnt
+
+             is_big_count_value = counts >= mean_bin_size
+             rest_bin_cnt -= np.sum(is_big_count_value)
+             rest_sample_cnt -= np.sum(counts[is_big_count_value])
+
+             mean_bin_size = rest_sample_cnt / rest_bin_cnt
+             upper_bounds = [float("Inf")] * max_bin
+             lower_bounds = [float("Inf")] * max_bin
+
+             bin_cnt = 0
+             lower_bounds[bin_cnt] = distinct_values[0]
+             cur_cnt_inbin = 0
+
+             for i in range(num_distinct_values - 1):  # pragma: no cover
+                 if not is_big_count_value[i]:
+                     rest_sample_cnt -= counts[i]
+
+                 cur_cnt_inbin += counts[i]
+
+                 if (
+                     is_big_count_value[i]
+                     or cur_cnt_inbin >= mean_bin_size
+                     or is_big_count_value[i + 1]
+                     and cur_cnt_inbin >= max(1.0, mean_bin_size * 0.5)
+                 ):
+                     upper_bounds[bin_cnt] = distinct_values[i]
+                     bin_cnt += 1
+                     lower_bounds[bin_cnt] = distinct_values[i + 1]
+                     if bin_cnt >= max_bin - 1:
+                         break
+                     cur_cnt_inbin = 0
+                     if not is_big_count_value[i]:
+                         rest_bin_cnt -= 1
+                         mean_bin_size = rest_sample_cnt / rest_bin_cnt
+
+             bin_upper_bound = [(upper_bounds[i] + lower_bounds[i + 1]) / 2.0 for i in range(bin_cnt - 1)]
+             bin_upper_bound.append(float("Inf"))
+         return bin_upper_bound
+
+     def _fit_spark(self, df: SparkDataFrame) -> None:
+         warn_msg = "DataFrame will be partially converted to the Pandas type during internal calculations in 'fit'"
+         warnings.warn(warn_msg)
+         value_counts = df.groupBy(self._col).count().orderBy(self._col).toPandas()
+         bins = [-float("inf")]
+         bins += self._greedy_bin_find(
+             value_counts[self._col].values,
+             value_counts["count"].values,
+             value_counts.shape[0],
+             self._n_bins + 1,
+             df.count(),
+             self._min_data_in_bin,
+         )
+         self._bins = bins
+
+     def _fit_pandas(self, df: PandasDataFrame) -> None:
+         vc = df[self._col].value_counts().sort_index()
+         bins = self._greedy_bin_find(
+             vc.index.values, vc.values, len(vc), self._n_bins + 1, vc.sum(), self._min_data_in_bin
+         )
+         self._bins = [-np.inf, *bins]
+
+     def _fit_polars(self, df: PolarsDataFrame) -> None:
+         warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'fit'"
+         warnings.warn(warn_msg)
+         self._fit_pandas(df.to_pandas())
+
+     def fit(self, df: DataFrameLike) -> "GreedyDiscretizingRule":
+         """
+         Fits the DiscretizingRule to the input dataframe.
+
+         :param df: input dataframe.
+         :returns: fitted DiscretizingRule.
+         """
+         if self._is_fitted:
+             return self
+
+         df = self._validate_input(df)
+
+         if isinstance(df, PandasDataFrame):
+             self._fit_pandas(df)
+         elif isinstance(df, SparkDataFrame):
+             self._fit_spark(df)
+         else:
+             self._fit_polars(df)
+
+         self._is_fitted = True
+         return self
+
+     def partial_fit(self, df: DataFrameLike) -> "GreedyDiscretizingRule":
+         """
+         Fits new data to an already fitted DiscretizingRule.
+
+         :param df: input dataframe.
+         :returns: fitted DiscretizingRule.
+         """
+         if not self._is_fitted:
+             return self.fit(df)
+
+         msg = f"{self.__class__.__name__} is not implemented for partial_fit yet."
+         raise NotImplementedError(msg)
+
+     def _transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
+         binned_column = np.digitize(df[self._col].values, self._bins)
+         binned_column -= 1
+         df_transformed = df.copy()
+         df_transformed.loc[:, self._col] = binned_column
+         return df_transformed
+
+     def _transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
+         target_col = self._col + "_discretized"
+         bucketizer = Bucketizer(
+             splits=self._bins, inputCol=self._col, outputCol=target_col, handleInvalid=self._handle_invalid
+         )
+         return bucketizer.transform(df).drop(self._col).withColumnRenamed(target_col, self._col)
+
+     def _transform_polars(self, df: PolarsDataFrame) -> PolarsDataFrame:
+         warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'transform'"
+         warnings.warn(warn_msg)
+         return pl.from_pandas(self._transform_pandas(df.to_pandas()))
+
+     def transform(self, df: DataFrameLike) -> DataFrameLike:
+         """
+         Transforms input dataframe with fitted DiscretizingRule.
+
+         :param df: input dataframe.
+         :returns: transformed dataframe.
+         """
+         if not self._is_fitted:
+             msg = "Discretizer is not fitted"
+             raise RuntimeError(msg)
+
+         df = self._validate_input(df)
+
+         if isinstance(df, PandasDataFrame):
+             transformed_df = self._transform_pandas(df)
+         elif isinstance(df, SparkDataFrame):
+             transformed_df = self._transform_spark(df)
+         else:
+             transformed_df = self._transform_polars(df)
+         return transformed_df
+
+     def set_handle_invalid(self, handle_invalid: HandleInvalidStrategies) -> None:
+         """
+         Sets strategy to handle invalid values.
+
+         :param handle_invalid: handle invalid strategy.
+         """
+         if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
+             msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
+             raise ValueError(msg)
+         self._handle_invalid = handle_invalid
+
+     def _validate_input(self, df: DataFrameLike) -> DataFrameLike:
+         if isinstance(df, PandasDataFrame):
+             df_val = df.copy()
+             if (self._handle_invalid == "error") and (df_val[self._col].isna().sum() > 0):
+                 msg = "Data contains NaN. 'handle_invalid' param equals 'error'. \
+ Set 'keep' or 'skip' for processing NaN."
+                 raise ValueError(msg)
+             if self._handle_invalid == "skip":
+                 df_val = df_val.dropna(subset=[self._col], axis=0)
+             return df_val
+
+         elif isinstance(df, SparkDataFrame):
+             if (self._handle_invalid == "error") and (df.filter(isnan(df[self._col])).count() > 0):
+                 msg = "Data contains NaN. 'handle_invalid' param equals 'error'. \
+ Set 'keep' or 'skip' for processing NaN."
+                 raise ValueError(msg)
+             return df
+
+         elif isinstance(df, PolarsDataFrame):
+             if (self._handle_invalid == "error") and (df[self._col].is_null().sum() > 0):
+                 msg = "Data contains NaN. 'handle_invalid' param equals 'error'. \
+ Set 'keep' or 'skip' for processing NaN."
+                 raise ValueError(msg)
+             if self._handle_invalid == "skip":
+                 df = df.clone().fill_nan(None).drop_nulls(subset=[self._col])
+             return df
+
+         else:
+             msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+             raise NotImplementedError(msg)
+
+     def save(
+         self,
+         path: str,
+     ) -> None:
+         discretizer_rule_dict = {}
+         discretizer_rule_dict["_class_name"] = self.__class__.__name__
+         discretizer_rule_dict["init_args"] = {
+             "n_bins": self._n_bins,
+             "column": self._col,
+             "min_data_in_bin": self._min_data_in_bin,
+             "handle_invalid": self._handle_invalid,
+         }
+         discretizer_rule_dict["fitted_args"] = {
+             "bins": self._bins,
+             "is_fitted": self._is_fitted,
+         }
+
+         base_path = Path(path).with_suffix(".replay").resolve()
+
+         if os.path.exists(base_path):  # pragma: no cover
+             msg = "There is already a DiscretizingRule object saved at the given path. The file will be overwritten."
+             warnings.warn(msg)
+         else:  # pragma: no cover
+             base_path.mkdir(parents=True, exist_ok=True)
+
+         with open(base_path / "init_args.json", "w+") as file:
+             json.dump(discretizer_rule_dict, file)
+
+     @classmethod
+     def load(cls, path: str) -> "GreedyDiscretizingRule":
+         base_path = Path(path).with_suffix(".replay").resolve()
+         with open(base_path / "init_args.json", "r") as file:
+             discretizer_rule_dict = json.loads(file.read())
+
+         discretizer_rule = cls(**discretizer_rule_dict["init_args"])
+         discretizer_rule._bins = discretizer_rule_dict["fitted_args"]["bins"]
+         discretizer_rule._is_fitted = discretizer_rule_dict["fitted_args"]["is_fitted"]
+         return discretizer_rule
+
+
+ class QuantileDiscretizingRule(BaseDiscretizingRule):
+     """
+     Implementation of the discretizing rule for a column of PySpark, Polars and Pandas DataFrames.
+     Discretizes column values according to the Quantile strategy: the data is distributed
+     into buckets of approximately the same size.
+     It is recommended to use it together with the Discretizer.
+     """
+
+     _HANDLE_INVALID_STRATEGIES = ("error", "skip", "keep")
+
+     def __init__(
+         self,
+         column: str,
+         n_bins: int,
+         handle_invalid: HandleInvalidStrategies = "keep",
+     ) -> None:
+         """
+         :param column: Name of the column to discretize.
+         :param n_bins: Number of intervals into which the data will be binned.
+         :param handle_invalid: handle_invalid rule;
+             indicates how to process NaN values in the data.
+             If ``skip`` - filter out rows with invalid values.
+             If ``error`` - throw an error.
+             If ``keep`` - keep invalid values in a special additional bucket with number = n_bins.
+             Default: ``keep``.
+         """
+         self._n_bins = n_bins
+         self._col = column
+         self._bins = None
+         self._discretizer = None
+         self._is_fitted = False
+
+         if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
+             msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
+             raise ValueError(msg)
+         self._handle_invalid = handle_invalid
+
+     @property
+     def column(self) -> str:
+         return self._col
+
+     @property
+     def n_bins(self) -> int:
+         return self._n_bins
+
+     def _fit_spark(self, df: SparkDataFrame) -> None:
+         discretizer = QuantileDiscretizer(
+             numBuckets=self._n_bins, inputCol=self._col, handleInvalid=self._handle_invalid
+         )
+         self._discretizer = discretizer.fit(df)
+         self._bins = self._discretizer.getSplits()
+
+     def _fit_pandas(self, df: PandasDataFrame) -> None:
+         discretizer = KBinsDiscretizer(n_bins=self._n_bins, encode="ordinal", strategy="quantile")
+         if self._handle_invalid == "keep":
+             self._discretizer = discretizer.fit(df.dropna(subset=[self._col], axis=0)[[self._col]])
+         else:
+             self._discretizer = discretizer.fit(df[[self._col]])
+         self._bins = self._discretizer.bin_edges_[0].astype(float).tolist()
+         self._bins[0] = -np.inf
+         self._bins[-1] = np.inf
+
+     def _fit_polars(self, df: PolarsDataFrame) -> None:
+         warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'fit'"
+         warnings.warn(warn_msg)
+         self._fit_pandas(df.to_pandas())
+
+     def fit(self, df: DataFrameLike) -> "QuantileDiscretizingRule":
+         """
+         Fits the DiscretizingRule to the input dataframe.
+
+         :param df: input dataframe.
+         :returns: fitted DiscretizingRule.
+         """
+         if self._is_fitted:
+             return self
+
+         df = self._validate_input(df)
+
+         if isinstance(df, PandasDataFrame):
+             self._fit_pandas(df)
+         elif isinstance(df, SparkDataFrame):
+             self._fit_spark(df)
+         else:
+             self._fit_polars(df)
+
+         self._is_fitted = True
+         return self
+
+     def partial_fit(self, df: DataFrameLike) -> "QuantileDiscretizingRule":
+         """
+         Fits new data to an already fitted DiscretizingRule.
+
+         :param df: input dataframe.
+         :returns: fitted DiscretizingRule.
+         """
+         if not self._is_fitted:
+             return self.fit(df)
+
+         msg = f"{self.__class__.__name__} is not implemented for partial_fit yet."
+         raise NotImplementedError(msg)
+
+     def _transform_pandas(self, df: PandasDataFrame) -> PandasDataFrame:
+         df_nan_part = df[df[self._col].isna()]
+         df_real_part = df[~df[self._col].isna()]
+
+         binned_column = np.digitize(df_real_part[self._col].values, self._bins)
+         binned_column -= 1
+
+         df_transformed = df.copy()
+         df_transformed.loc[df_real_part.index, self._col] = binned_column
+         df_transformed.loc[df_nan_part.index, self._col] = [self._n_bins] * len(df_nan_part)
+         return df_transformed
+
+     def _transform_spark(self, df: SparkDataFrame) -> SparkDataFrame:
+         target_col = self._col + "_discretized"
+         bucketizer = Bucketizer(
+             splits=self._bins, inputCol=self._col, outputCol=target_col, handleInvalid=self._handle_invalid
+         )
+         return bucketizer.transform(df).drop(self._col).withColumnRenamed(target_col, self._col)
+
+     def _transform_polars(self, df: PolarsDataFrame) -> PolarsDataFrame:
+         warn_msg = "DataFrame will be converted to the Pandas type during internal calculations in 'transform'"
+         warnings.warn(warn_msg)
+         return pl.from_pandas(self._transform_pandas(df.to_pandas()))
+
+     def transform(self, df: DataFrameLike) -> DataFrameLike:
+         """
+         Transforms input dataframe with fitted DiscretizingRule.
+
+         :param df: input dataframe.
+         :returns: transformed dataframe.
+         """
+         if not self._is_fitted:
+             msg = "Discretizer is not fitted"
+             raise RuntimeError(msg)
+
+         df = self._validate_input(df)
+
+         if isinstance(df, PandasDataFrame):
+             transformed_df = self._transform_pandas(df)
+         elif isinstance(df, SparkDataFrame):
+             transformed_df = self._transform_spark(df)
+         else:
+             transformed_df = self._transform_polars(df)
+         return transformed_df
+
+     def set_handle_invalid(self, handle_invalid: HandleInvalidStrategies) -> None:
+         """
+         Sets strategy to handle invalid values.
+
+         :param handle_invalid: handle invalid strategy.
+         """
+         if handle_invalid not in self._HANDLE_INVALID_STRATEGIES:
+             msg = f"handle_invalid should be either 'error' or 'skip' or 'keep', got {handle_invalid}."
+             raise ValueError(msg)
+         self._handle_invalid = handle_invalid
+
+     def _validate_input(self, df: DataFrameLike) -> DataFrameLike:
+         if isinstance(df, PandasDataFrame):
+             if (self._handle_invalid == "error") and (df[self._col].isna().sum() > 0):
+                 msg = "Data contains NaN. 'handle_invalid' param equals 'error'. \
+ Set 'keep' or 'skip' for processing NaN."
+                 raise ValueError(msg)
+             if self._handle_invalid == "skip":
+                 df = df.copy().dropna(subset=[self._col], axis=0)
+
+         elif isinstance(df, SparkDataFrame):
+             if (self._handle_invalid == "error") and (df.filter(isnan(df[self._col])).count() > 0):
+                 msg = "Data contains NaN. 'handle_invalid' param equals 'error'. \
+ Set 'keep' or 'skip' for processing NaN."
+                 raise ValueError(msg)
+             if self._handle_invalid == "skip":
+                 df = df.dropna(subset=[self._col])
+
+         elif isinstance(df, PolarsDataFrame):
+             if (self._handle_invalid == "error") and (df[self._col].is_null().sum() > 0):
+                 msg = "Data contains NaN. 'handle_invalid' param equals 'error'. \
+ Set 'keep' or 'skip' for processing NaN."
+                 raise ValueError(msg)
+             if self._handle_invalid == "skip":
+                 df = df.clone().fill_nan(None).drop_nulls(subset=[self._col])
+
+         else:
+             msg = f"{self.__class__.__name__} is not implemented for {type(df)}"
+             raise NotImplementedError(msg)
+         return df
+
+     def save(
+         self,
+         path: str,
+     ) -> None:
+         discretizer_rule_dict = {}
+         discretizer_rule_dict["_class_name"] = self.__class__.__name__
+         discretizer_rule_dict["init_args"] = {
+             "n_bins": self._n_bins,
+             "column": self._col,
+             "handle_invalid": self._handle_invalid,
+         }
+         discretizer_rule_dict["fitted_args"] = {
+             "bins": self._bins,
+             "is_fitted": self._is_fitted,
+         }
+
+         base_path = Path(path).with_suffix(".replay").resolve()
+
+         if os.path.exists(base_path):  # pragma: no cover
+             msg = "There is already a DiscretizingRule object saved at the given path. The file will be overwritten."
+             warnings.warn(msg)
+         else:  # pragma: no cover
+             base_path.mkdir(parents=True, exist_ok=True)
+
+         with open(base_path / "init_args.json", "w+") as file:
+             json.dump(discretizer_rule_dict, file)
+
+     @classmethod
+     def load(cls, path: str) -> "QuantileDiscretizingRule":
+         base_path = Path(path).with_suffix(".replay").resolve()
+         with open(base_path / "init_args.json", "r") as file:
+             discretizer_rule_dict = json.loads(file.read())
+
+         discretizer_rule = cls(**discretizer_rule_dict["init_args"])
+         discretizer_rule._bins = discretizer_rule_dict["fitted_args"]["bins"]
+         discretizer_rule._is_fitted = discretizer_rule_dict["fitted_args"]["is_fitted"]
+         return discretizer_rule
+
+
+ class Discretizer:
+     """
+     Applies multiple discretizing rules to the dataframe.
+     Every sample is assigned to a bucket with a number from the set [0, 1, ..., n_bins-1].
+     """
+
+     def __init__(self, rules: Sequence[BaseDiscretizingRule]):
+         """
+         :param rules: Sequence of rules.
+         """
+         self.rules = rules
+
+     def fit(self, df: DataFrameLike) -> "Discretizer":
+         """
+         Fits the Discretizer to the input dataframe with the given rules.
+
+         :param df: input dataframe.
+         :returns: fitted Discretizer.
+         """
+         for rule in self.rules:
+             rule.fit(df)
+         return self
+
+     def partial_fit(self, df: DataFrameLike) -> "Discretizer":
+         """
+         Fits an already fitted Discretizer to a new input dataframe with the given rules.
+         If the Discretizer has not been fitted yet, performs a regular fit.
+
+         :param df: input dataframe.
+         :returns: fitted Discretizer.
+         """
+         for rule in self.rules:
+             rule.partial_fit(df)
+         return self
+
+     def transform(self, df: DataFrameLike) -> DataFrameLike:
+         """
+         Transforms the input dataframe.
+         If the input dataframe contains NaN values, they are processed according to the handle_invalid strategy.
+
+         :param df: input dataframe.
+         :returns: transformed dataframe.
+         """
+         for rule in self.rules:
+             df = rule.transform(df)
+         return df
+
+     def fit_transform(self, df: DataFrameLike) -> DataFrameLike:
+         """
+         Fits the Discretizer to the input dataframe with the given rules and transforms the input dataframe.
+
+         :param df: input dataframe.
+         :returns: transformed dataframe.
+         """
+         return self.fit(df).transform(df)
+
+     def set_handle_invalid(self, handle_invalid_rules: Dict[str, HandleInvalidStrategies]) -> None:
+         """
+         Modifies the handle_invalid strategy on an already fitted Discretizer.
+
+         :param handle_invalid_rules: mapping from column name to the handle_invalid strategy to apply.
+
+             Example: {"item_id" : "keep", "user_id" : "skip", "category_column" : "error"}
+
+             Possible values:
+             If ``skip`` - filter out rows with invalid values.
+             If ``error`` - throw an error.
+             If ``keep`` - keep invalid values in a special additional bucket with number = n_bins.
+             Default: ``keep``.
+         """
+         columns = [rule.column for rule in self.rules]
+         for column, handle_invalid in handle_invalid_rules.items():
+             if column not in columns:
+                 msg = f"Column {column} not found."
+                 raise ValueError(msg)
+             rule = list(filter(lambda x: x.column == column, self.rules))
+             rule[0].set_handle_invalid(handle_invalid)
+
+     def save(
+         self,
+         path: str,
+     ) -> None:
+         discretizer_dict = {}
+         discretizer_dict["_class_name"] = self.__class__.__name__
+
+         base_path = Path(path).with_suffix(".replay").resolve()
+         if os.path.exists(base_path):  # pragma: no cover
+             msg = "There is already a Discretizer object saved at the given path. The file will be overwritten."
+             warnings.warn(msg)
+         else:  # pragma: no cover
+             base_path.mkdir(parents=True, exist_ok=True)
+
+         discretizer_dict["rule_names"] = []
+
+         for rule in self.rules:
+             path_suffix = f"{rule.__class__.__name__}_{rule.column}"
+             rule.save(str(base_path) + f"/rules/{path_suffix}")
+             discretizer_dict["rule_names"].append(path_suffix)
+
+         with open(base_path / "init_args.json", "w+") as file:
+             json.dump(discretizer_dict, file)
+
+     @classmethod
+     def load(cls, path: str) -> "Discretizer":
+         base_path = Path(path).with_suffix(".replay").resolve()
+         with open(base_path / "init_args.json", "r") as file:
+             discretizer_dict = json.loads(file.read())
+         rules = []
+         for root, dirs, files in os.walk(str(base_path) + "/rules/"):
+             for d in dirs:
+                 if d.split(".")[0] in discretizer_dict["rule_names"]:
+                     with open(root + d + "/init_args.json", "r") as file:
+                         discretizer_rule_dict = json.loads(file.read())
+                     rules.append(globals()[discretizer_rule_dict["_class_name"]].load(root + d))
+
+         discretizer = cls(rules=rules)
+         return discretizer
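For context, here is a minimal usage sketch of the Discretizer API added in this release. It assumes the new classes are re-exported from replay.preprocessing (the change to replay/preprocessing/__init__.py suggests this; otherwise import them from replay.preprocessing.discretizer). The dataframe, column names, and bin counts are illustrative, not part of the package.

# Minimal sketch: bin two numeric columns of a pandas DataFrame with the new rules.
# Assumption: Discretizer, GreedyDiscretizingRule and QuantileDiscretizingRule are
# importable from replay.preprocessing; data below is made up for illustration.
import numpy as np
import pandas as pd

from replay.preprocessing import Discretizer, GreedyDiscretizingRule, QuantileDiscretizingRule

interactions = pd.DataFrame(
    {
        "rating": [1.0, 2.0, 2.5, 3.0, 4.5, np.nan],
        "timestamp": [10, 20, 30, 40, 50, 60],
    }
)

discretizer = Discretizer(
    rules=[
        # NaN ratings fall into the extra bucket with index n_bins because handle_invalid="keep".
        QuantileDiscretizingRule("rating", n_bins=3, handle_invalid="keep"),
        GreedyDiscretizingRule("timestamp", n_bins=2, min_data_in_bin=1),
    ]
)

# Each configured column is replaced by its bucket index in [0, n_bins - 1].
binned = discretizer.fit_transform(interactions)
print(binned)

# The fitted state can be saved and restored; save() creates a "<path>.replay" directory.
discretizer.save("./discretizer")
restored = Discretizer.load("./discretizer")

The same Discretizer works on Polars and PySpark dataframes, with the caveat (warned about in the code above) that Polars data, and parts of Spark data for the greedy rule, are converted to pandas during fit/transform.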