replay-rec 0.20.1rc0-py3-none-any.whl → 0.20.2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (61)
  1. replay/__init__.py +1 -1
  2. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/METADATA +18 -12
  3. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/RECORD +6 -61
  4. replay/experimental/__init__.py +0 -0
  5. replay/experimental/metrics/__init__.py +0 -62
  6. replay/experimental/metrics/base_metric.py +0 -603
  7. replay/experimental/metrics/coverage.py +0 -97
  8. replay/experimental/metrics/experiment.py +0 -175
  9. replay/experimental/metrics/hitrate.py +0 -26
  10. replay/experimental/metrics/map.py +0 -30
  11. replay/experimental/metrics/mrr.py +0 -18
  12. replay/experimental/metrics/ncis_precision.py +0 -31
  13. replay/experimental/metrics/ndcg.py +0 -49
  14. replay/experimental/metrics/precision.py +0 -22
  15. replay/experimental/metrics/recall.py +0 -25
  16. replay/experimental/metrics/rocauc.py +0 -49
  17. replay/experimental/metrics/surprisal.py +0 -90
  18. replay/experimental/metrics/unexpectedness.py +0 -76
  19. replay/experimental/models/__init__.py +0 -50
  20. replay/experimental/models/admm_slim.py +0 -257
  21. replay/experimental/models/base_neighbour_rec.py +0 -200
  22. replay/experimental/models/base_rec.py +0 -1386
  23. replay/experimental/models/base_torch_rec.py +0 -234
  24. replay/experimental/models/cql.py +0 -454
  25. replay/experimental/models/ddpg.py +0 -932
  26. replay/experimental/models/dt4rec/__init__.py +0 -0
  27. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  28. replay/experimental/models/dt4rec/gpt1.py +0 -401
  29. replay/experimental/models/dt4rec/trainer.py +0 -127
  30. replay/experimental/models/dt4rec/utils.py +0 -264
  31. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  32. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  33. replay/experimental/models/hierarchical_recommender.py +0 -331
  34. replay/experimental/models/implicit_wrap.py +0 -131
  35. replay/experimental/models/lightfm_wrap.py +0 -303
  36. replay/experimental/models/mult_vae.py +0 -332
  37. replay/experimental/models/neural_ts.py +0 -986
  38. replay/experimental/models/neuromf.py +0 -406
  39. replay/experimental/models/scala_als.py +0 -293
  40. replay/experimental/models/u_lin_ucb.py +0 -115
  41. replay/experimental/nn/data/__init__.py +0 -1
  42. replay/experimental/nn/data/schema_builder.py +0 -102
  43. replay/experimental/preprocessing/__init__.py +0 -3
  44. replay/experimental/preprocessing/data_preparator.py +0 -839
  45. replay/experimental/preprocessing/padder.py +0 -229
  46. replay/experimental/preprocessing/sequence_generator.py +0 -208
  47. replay/experimental/scenarios/__init__.py +0 -1
  48. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  49. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  50. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  51. replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  52. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  53. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  54. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  55. replay/experimental/utils/__init__.py +0 -0
  56. replay/experimental/utils/logger.py +0 -24
  57. replay/experimental/utils/model_handler.py +0 -186
  58. replay/experimental/utils/session_handler.py +0 -44
  59. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/WHEEL +0 -0
  60. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/licenses/LICENSE +0 -0
  61. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/licenses/NOTICE +0 -0
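Taken together with the RECORD change (item 3, +6 -61), items 4-58 remove the entire replay.experimental subpackage from the wheel. Downstream code that imports any of these modules will start failing on 0.20.2; whether the experimental API moved elsewhere or was dropped outright is not visible from this diff alone. A minimal, hypothetical guard (the import path is file 44 above; the error message and the pinning advice are illustrative):

    try:
        # Import path as it existed in the 0.20.1rc0 wheel (file 44 in the list above)
        from replay.experimental.preprocessing.data_preparator import DataPreparator
    except ImportError as exc:
        # The 0.20.2 wheel no longer contains replay.experimental
        raise ImportError(
            "replay.experimental is not shipped in the replay-rec 0.20.2 wheel; "
            "pin replay-rec==0.20.1rc0 or migrate off the experimental API"
        ) from exc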
replay/experimental/preprocessing/data_preparator.py
@@ -1,839 +0,0 @@
-"""
-Contains classes for data preparation and categorical features transformation.
-``DataPreparator`` is used to transform DataFrames to a library format.
-``Indexed`` is used to convert user and item ids to numeric format.
-``CatFeaturesTransformer`` transforms categorical features with one-hot encoding.
-``ToNumericFeatureTransformer`` leaves only numerical features
-by one-hot encoding of some features and deleting the others.
-"""
-
-import json
-import logging
-import string
-from os.path import join
-from typing import Optional
-
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, MissingImport, SparkDataFrame
-from replay.utils.session_handler import State
-
-if PYSPARK_AVAILABLE:
-    from pyspark.ml import Estimator, Transformer
-    from pyspark.ml.feature import IndexToString, StringIndexer, StringIndexerModel
-    from pyspark.ml.util import DefaultParamsWriter, MLReadable, MLReader, MLWritable, MLWriter
-    from pyspark.sql import functions as sf
-    from pyspark.sql.types import DoubleType, IntegerType, NumericType, StructField, StructType
-
-    from replay.utils.spark_utils import convert2spark, process_timestamp_column
-
-
-LOG_COLUMNS = ["user_id", "item_id", "timestamp", "relevance"]
-
-
-if PYSPARK_AVAILABLE:
-
-    class Indexer:
-        """
-        This class is used to convert arbitrary id to numerical idx and back.
-        """
-
-        user_indexer: StringIndexerModel
-        item_indexer: StringIndexerModel
-        inv_user_indexer: IndexToString
-        inv_item_indexer: IndexToString
-        user_type: None
-        item_type: None
-        suffix = "inner"
-
-        def __init__(self, user_col="user_id", item_col="item_id"):
-            """
-            Provide column names for indexer to use
-            """
-            self.user_col = user_col
-            self.item_col = item_col
-
-        @property
-        def _init_args(self):
-            return {
-                "user_col": self.user_col,
-                "item_col": self.item_col,
-            }
-
-        def fit(
-            self,
-            users: SparkDataFrame,
-            items: SparkDataFrame,
-        ) -> None:
-            """
-            Creates indexers to map raw id to numerical idx so that spark can handle them.
-            :param users: SparkDataFrame containing user column
-            :param items: SparkDataFrame containing item column
-            :return:
-            """
-            users = users.select(self.user_col).withColumnRenamed(self.user_col, f"{self.user_col}_{self.suffix}")
-            items = items.select(self.item_col).withColumnRenamed(self.item_col, f"{self.item_col}_{self.suffix}")
-
-            self.user_type = users.schema[f"{self.user_col}_{self.suffix}"].dataType
-            self.item_type = items.schema[f"{self.item_col}_{self.suffix}"].dataType
-
-            self.user_indexer = StringIndexer(inputCol=f"{self.user_col}_{self.suffix}", outputCol="user_idx").fit(
-                users
-            )
-            self.item_indexer = StringIndexer(inputCol=f"{self.item_col}_{self.suffix}", outputCol="item_idx").fit(
-                items
-            )
-            self.inv_user_indexer = IndexToString(
-                inputCol=f"{self.user_col}_{self.suffix}",
-                outputCol=self.user_col,
-                labels=self.user_indexer.labels,
-            )
-            self.inv_item_indexer = IndexToString(
-                inputCol=f"{self.item_col}_{self.suffix}",
-                outputCol=self.item_col,
-                labels=self.item_indexer.labels,
-            )
-
-        def transform(self, df: SparkDataFrame) -> Optional[SparkDataFrame]:
-            """
-            Convert raw ``user_col`` and ``item_col`` to numerical ``user_idx`` and ``item_idx``
-
-            :param df: dataframe with raw indexes
-            :return: dataframe with converted indexes
-            """
-            if self.item_col in df.columns:
-                remaining_cols = df.drop(self.item_col).columns
-                df = df.withColumnRenamed(self.item_col, f"{self.item_col}_{self.suffix}")
-                self._reindex(df, "item")
-                df = self.item_indexer.transform(df).select(
-                    sf.col("item_idx").cast("int").alias("item_idx"),
-                    *remaining_cols,
-                )
-            if self.user_col in df.columns:
-                remaining_cols = df.drop(self.user_col).columns
-                df = df.withColumnRenamed(self.user_col, f"{self.user_col}_{self.suffix}")
-                self._reindex(df, "user")
-                df = self.user_indexer.transform(df).select(
-                    sf.col("user_idx").cast("int").alias("user_idx"),
-                    *remaining_cols,
-                )
-            return df
-
-        def inverse_transform(self, df: SparkDataFrame) -> SparkDataFrame:
-            """
-            Convert SparkDataFrame to the initial indexes.
-
-            :param df: SparkDataFrame with numerical ``user_idx/item_idx`` columns
-            :return: SparkDataFrame with original user/item columns
-            """
-            res = df
-            if "item_idx" in df.columns:
-                remaining_cols = res.drop("item_idx").columns
-                res = self.inv_item_indexer.transform(
-                    res.withColumnRenamed("item_idx", f"{self.item_col}_{self.suffix}")
-                ).select(
-                    sf.col(self.item_col).cast(self.item_type).alias(self.item_col),
-                    *remaining_cols,
-                )
-            if "user_idx" in df.columns:
-                remaining_cols = res.drop("user_idx").columns
-                res = self.inv_user_indexer.transform(
-                    res.withColumnRenamed("user_idx", f"{self.user_col}_{self.suffix}")
-                ).select(
-                    sf.col(self.user_col).cast(self.user_type).alias(self.user_col),
-                    *remaining_cols,
-                )
-            return res
-
-        def _reindex(self, df: SparkDataFrame, entity: str):
-            """
-            Update indexer with new entries.
-
-            :param df: SparkDataFrame with users/items
-            :param entity: user or item
-            """
-            indexer = getattr(self, f"{entity}_indexer")
-            inv_indexer = getattr(self, f"inv_{entity}_indexer")
-            new_objects = set(
-                map(
-                    str,
-                    df.select(indexer.getInputCol()).distinct().toPandas()[indexer.getInputCol()],
-                )
-            ).difference(indexer.labels)
-            if new_objects:
-                new_labels = indexer.labels + list(new_objects)
-                setattr(
-                    self,
-                    f"{entity}_indexer",
-                    indexer.from_labels(
-                        new_labels,
-                        inputCol=indexer.getInputCol(),
-                        outputCol=indexer.getOutputCol(),
-                        handleInvalid="error",
-                    ),
-                )
-                inv_indexer.setLabels(new_labels)
-
-    # We need to inherit it from DefaultParamsWriter to make it being saved correctly within Pipeline
-    class JoinIndexerMLWriter(DefaultParamsWriter):
-        """Implements saving the JoinIndexerTransformer instance to disk.
-        Used when saving a trained pipeline.
-        Implements MLWriter.saveImpl(path) method.
-        """
-
-        def __init__(self, instance):
-            super().__init__(instance)
-            self.instance = instance
-
-        def saveImpl(self, path: str) -> None:  # noqa: N802
-            """Save implementation"""
-            super().saveImpl(path)
-
-            spark = State().session
-
-            init_args = self.instance._init_args
-            sc = spark.sparkContext
-            df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
-            df.coalesce(1).write.mode("overwrite").json(join(path, "init_args.json"))
-
-            self.instance.user_col_2_index_map.write.mode("overwrite").save(join(path, "user_col_2_index_map.parquet"))
-            self.instance.item_col_2_index_map.write.mode("overwrite").save(join(path, "item_col_2_index_map.parquet"))
-
-    class JoinIndexerMLReader(MLReader):
-        """Implements reading the JoinIndexerTransformer instance from disk.
-        Used when loading a trained pipeline.
-        """
-
-        def load(self, path):
-            """Load the ML instance from the input path."""
-            spark = State().session
-            args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
-            user_col_2_index_map = spark.read.parquet(join(path, "user_col_2_index_map.parquet"))
-            item_col_2_index_map = spark.read.parquet(join(path, "item_col_2_index_map.parquet"))
-
-            indexer = JoinBasedIndexerTransformer(
-                user_col=args["user_col"],
-                user_type=args["user_type"],
-                user_col_2_index_map=user_col_2_index_map,
-                item_col=args["item_col"],
-                item_type=args["item_type"],
-                item_col_2_index_map=item_col_2_index_map,
-            )
-
-            return indexer
-
-    class JoinBasedIndexerTransformer(Transformer, MLWritable, MLReadable):
-        """
-        JoinBasedIndexer, that index user column and item column in input dataframe
-        """
-
-        def __init__(
-            self,
-            user_col: str,
-            item_col: str,
-            user_type: str,
-            item_type: str,
-            user_col_2_index_map: SparkDataFrame,
-            item_col_2_index_map: SparkDataFrame,
-            update_map_on_transform: bool = True,
-            force_broadcast_on_mapping_joins: bool = True,
-        ):
-            super().__init__()
-            self.user_col = user_col
-            self.item_col = item_col
-            self.user_type = user_type
-            self.item_type = item_type
-            self.user_col_2_index_map = user_col_2_index_map
-            self.item_col_2_index_map = item_col_2_index_map
-            self.update_map_on_transform = update_map_on_transform
-            self.force_broadcast_on_mapping_joins = force_broadcast_on_mapping_joins
-
-        @property
-        def _init_args(self):
-            return {
-                "user_col": self.user_col,
-                "item_col": self.item_col,
-                "user_type": self.user_type,
-                "item_type": self.item_type,
-                "update_map_on_transform": self.update_map_on_transform,
-                "force_broadcast_on_mapping_joins": self.force_broadcast_on_mapping_joins,
-            }
-
-        def set_update_map_on_transform(self, value: bool):
-            """Sets 'update_map_on_transform' flag"""
-            self.update_map_on_transform = value
-
-        def set_force_broadcast_on_mapping_joins(self, value: bool):
-            """Sets 'force_broadcast_on_mapping_joins' flag"""
-            self.force_broadcast_on_mapping_joins = value
-
-        def _get_item_mapping(self) -> SparkDataFrame:
-            if self.force_broadcast_on_mapping_joins:
-                mapping = sf.broadcast(self.item_col_2_index_map)
-            else:
-                mapping = self.item_col_2_index_map
-            return mapping
-
-        def _get_user_mapping(self) -> SparkDataFrame:
-            if self.force_broadcast_on_mapping_joins:
-                mapping = sf.broadcast(self.user_col_2_index_map)
-            else:
-                mapping = self.user_col_2_index_map
-            return mapping
-
-        def write(self) -> MLWriter:
-            """Returns MLWriter instance that can save the Transformer instance."""
-            return JoinIndexerMLWriter(self)
-
-        @classmethod
-        def read(cls):
-            """Returns an MLReader instance for this class."""
-            return JoinIndexerMLReader()
-
-        def _update_maps(self, df: SparkDataFrame):
-            new_items = (
-                df.join(self._get_item_mapping(), on=self.item_col, how="left_anti").select(self.item_col).distinct()
-            )
-            prev_item_count = self.item_col_2_index_map.count()
-            new_items_map = JoinBasedIndexerEstimator.get_map(new_items, self.item_col, "item_idx").select(
-                self.item_col, (sf.col("item_idx") + prev_item_count).alias("item_idx")
-            )
-            self.item_col_2_index_map = self.item_col_2_index_map.union(new_items_map)
-
-            new_users = (
-                df.join(self._get_user_mapping(), on=self.user_col, how="left_anti").select(self.user_col).distinct()
-            )
-            prev_user_count = self.user_col_2_index_map.count()
-            new_users_map = JoinBasedIndexerEstimator.get_map(new_users, self.user_col, "user_idx").select(
-                self.user_col, (sf.col("user_idx") + prev_user_count).alias("user_idx")
-            )
-            self.user_col_2_index_map = self.user_col_2_index_map.union(new_users_map)
-
-        def _transform(self, dataset: SparkDataFrame) -> SparkDataFrame:
-            if self.update_map_on_transform:
-                self._update_maps(dataset)
-
-            if self.item_col in dataset.columns:
-                remaining_cols = dataset.drop(self.item_col).columns
-                dataset = dataset.join(self._get_item_mapping(), on=self.item_col, how="left").select(
-                    sf.col("item_idx").cast("int").alias("item_idx"),
-                    *remaining_cols,
-                )
-            if self.user_col in dataset.columns:
-                remaining_cols = dataset.drop(self.user_col).columns
-                dataset = dataset.join(self._get_user_mapping(), on=self.user_col, how="left").select(
-                    sf.col("user_idx").cast("int").alias("user_idx"),
-                    *remaining_cols,
-                )
-            return dataset
-
-        def inverse_transform(self, df: SparkDataFrame) -> SparkDataFrame:
-            """
-            Convert SparkDataFrame to the initial indexes.
-
-            :param df: SparkDataFrame with numerical ``user_idx/item_idx`` columns
-            :return: SparkDataFrame with original user/item columns
-            """
-            if "item_idx" in df.columns:
-                remaining_cols = df.drop("item_idx").columns
-                df = df.join(self._get_item_mapping(), on="item_idx", how="left").select(
-                    self.item_col,
-                    *remaining_cols,
-                )
-            if "user_idx" in df.columns:
-                remaining_cols = df.drop("user_idx").columns
-                df = df.join(self._get_user_mapping(), on="user_idx", how="left").select(
-                    self.user_col,
-                    *remaining_cols,
-                )
-            return df
-
-    class JoinBasedIndexerEstimator(Estimator):
-        """
-        Estimator that produces JoinBasedIndexerTransformer
-        """
-
-        def __init__(self, user_col="user_id", item_col="item_id"):
-            """
-            Provide column names for indexer to use
-            """
-            self.user_col = user_col
-            self.item_col = item_col
-            self.user_col_2_index_map = None
-            self.item_col_2_index_map = None
-            self.user_type = None
-            self.item_type = None
-
-        @staticmethod
-        def get_map(df: SparkDataFrame, col_name: str, idx_col_name: str) -> SparkDataFrame:
-            """Creates indexes [0, .., k] for values from `col_name` column.
-
-            :param df: input dataframe
-            :param col_name: column name from `df` that need to index
-            :param idx_col_name: column name with indexes
-            :return: SparkDataFrame with map "col_name" -> "idx_col_name"
-            """
-            uid_rdd = df.select(col_name).distinct().rdd.map(lambda x: x[col_name]).zipWithIndex()
-
-            return uid_rdd.toDF(
-                StructType(
-                    [
-                        df.schema[col_name],
-                        StructField(idx_col_name, IntegerType(), False),
-                    ]
-                )
-            )
-
-        def _fit(self, dataset: SparkDataFrame) -> Transformer:
-            """
-            Creates indexers to map raw id to numerical idx so that spark can handle them.
-            :param df: SparkDataFrame containing user column and item column
-            :return:
-            """
-
-            self.user_col_2_index_map = self.get_map(dataset, self.user_col, "user_idx")
-            self.item_col_2_index_map = self.get_map(dataset, self.item_col, "item_idx")
-
-            self.user_type = dataset.schema[self.user_col].dataType
-            self.item_type = dataset.schema[self.item_col].dataType
-
-            return JoinBasedIndexerTransformer(
-                user_col=self.user_col,
-                user_type=str(self.user_type),
-                item_col=self.item_col,
-                item_type=str(self.item_type),
-                user_col_2_index_map=self.user_col_2_index_map,
-                item_col_2_index_map=self.item_col_2_index_map,
-            )
-
-    class DataPreparator:
-        """Transforms data to a library format:
-        - read as a spark dataframe/ convert pandas dataframe to spark
-        - check for nulls
-        - create relevance/timestamp columns if absent
-        - convert dates to TimestampType
-
-        Examples:
-
-        Loading log DataFrame
-
-        >>> import pandas as pd
-        >>> from replay.experimental.preprocessing.data_preparator import DataPreparator
-        >>>
-        >>> log = pd.DataFrame({"user": [2, 2, 2, 1],
-        ...                     "item_id": [1, 2, 3, 3],
-        ...                     "rel": [5, 5, 5, 5]}
-        ...                    )
-        >>> dp = DataPreparator()
-        >>> correct_log = dp.transform(data=log,
-        ...                            columns_mapping={"user_id": "user",
-        ...                                             "item_id": "item_id",
-        ...                                             "relevance": "rel"}
-        ...                            )
-        >>> correct_log.show(2)
-        +-------+-------+---------+-------------------+
-        |user_id|item_id|relevance|          timestamp|
-        +-------+-------+---------+-------------------+
-        |      2|      1|      5.0|2099-01-01 00:00:00|
-        |      2|      2|      5.0|2099-01-01 00:00:00|
-        +-------+-------+---------+-------------------+
-        only showing top 2 rows
-        <BLANKLINE>
-
-
-        Loading user features
-
-        >>> import pandas as pd
-        >>> from replay.experimental.preprocessing.data_preparator import DataPreparator
-        >>>
-        >>> log = pd.DataFrame({"user": ["user1", "user1", "user2"],
-        ...                     "f0": ["feature1","feature2","feature1"],
-        ...                     "f1": ["left","left","center"],
-        ...                     "ts": ["2019-01-01","2019-01-01","2019-01-01"]}
-        ...                    )
-        >>> dp = DataPreparator()
-        >>> correct_log = dp.transform(data=log,
-        ...                            columns_mapping={"user_id": "user"},
-        ...                            )
-        >>> correct_log.show(3)
-        +-------+--------+------+----------+
-        |user_id|      f0|    f1|        ts|
-        +-------+--------+------+----------+
-        |  user1|feature1|  left|2019-01-01|
-        |  user1|feature2|  left|2019-01-01|
-        |  user2|feature1|center|2019-01-01|
-        +-------+--------+------+----------+
-        <BLANKLINE>
-
-        """
-
-        _logger: Optional[logging.Logger] = None
-
-        @property
-        def logger(self) -> logging.Logger:
-            """
-            :returns: get library logger
-            """
-            if self._logger is None:
-                self._logger = logging.getLogger("replay")
-            return self._logger
-
-        @staticmethod
-        def read_as_spark_df(
-            data: Optional[DataFrameLike] = None,
-            path: Optional[str] = None,
-            format_type: Optional[str] = None,
-            **kwargs,
-        ) -> SparkDataFrame:
-            """
-            Read spark dataframe from file of transform pandas dataframe.
-
-            :param data: DataFrame to process (``pass`` or ``data`` should be defined)
-            :param path: path to data (``pass`` or ``data`` should be defined)
-            :param format_type: file type, one of ``[csv , parquet , json , table]``
-            :param kwargs: extra arguments passed to
-                ``spark.read.<format>(path, **reader_kwargs)``
-            :return: spark DataFrame
-            """
-            if data is not None:
-                dataframe = convert2spark(data)
-            elif path and format_type:
-                spark = State().session
-                if format_type == "csv":
-                    dataframe = spark.read.csv(path, inferSchema=True, **kwargs)
-                elif format_type == "parquet":
-                    dataframe = spark.read.parquet(path)
-                elif format_type == "json":
-                    dataframe = spark.read.json(path, **kwargs)
-                elif format_type == "table":
-                    dataframe = spark.read.table(path)
-                else:
-                    msg = f"Invalid value of format_type='{format_type}'"
-                    raise ValueError(msg)
-            else:
-                msg = "Either data or path parameters must not be None"
-                raise ValueError(msg)
-            return dataframe
-
-        def check_df(self, dataframe: SparkDataFrame, columns_mapping: dict[str, str]) -> None:
-            """
-            Check:
-            - if dataframe is not empty,
-            - if columns from ``columns_mapping`` are present in dataframe
-            - warn about nulls in columns from ``columns_mapping``
-            - warn about absent of ``timestamp/relevance`` columns for interactions log
-            - warn about wrong relevance DataType
-
-            :param dataframe: spark DataFrame to process
-            :param columns_mapping: dictionary mapping "key: column name in input DataFrame".
-                Possible keys: ``[user_id, user_id, timestamp, relevance]``
-                ``columns_mapping`` values specifies the nature of the DataFrame:
-                - if both ``[user_id, item_id]`` are present,
-                  then the dataframe is a log of interactions.
-                  Specify ``timestamp, relevance`` columns in mapping if available.
-                - if ether ``user_id`` or ``item_id`` is present,
-                  then the dataframe is a dataframe of user/item features
-            """
-            if not dataframe.head(1):
-                msg = "DataFrame is empty"
-                raise ValueError(msg)
-
-            for value in columns_mapping.values():
-                if value not in dataframe.columns:
-                    msg = f"Column `{value}` stated in mapping is absent in dataframe"
-                    raise ValueError(msg)
-
-            for column in columns_mapping.values():
-                if dataframe.where(sf.col(column).isNull()).count() > 0:
-                    self.logger.info(
-                        "Column `%s` has NULL values. Handle NULL values before "
-                        "the next data preprocessing/model training steps",
-                        column,
-                    )
-
-            if "user_id" in columns_mapping and "item_id" in columns_mapping:
-                absent_cols = set(LOG_COLUMNS).difference(columns_mapping.keys())
-                if len(absent_cols) > 0:
-                    self.logger.info(
-                        "Columns %s are absent, but may be required for models training. "
-                        "Add them with DataPreparator().generate_absent_log_cols",
-                        list(absent_cols),
-                    )
-                if "relevance" in columns_mapping and not isinstance(
-                    dataframe.schema[columns_mapping["relevance"]].dataType,
-                    NumericType,
-                ):
-                    self.logger.info(
-                        "Relevance column `%s` should be numeric, but it is %s",
-                        columns_mapping["relevance"],
-                        dataframe.schema[columns_mapping["relevance"]].dataType,
-                    )
-
-        @staticmethod
-        def add_absent_log_cols(
-            dataframe: SparkDataFrame,
-            columns_mapping: dict[str, str],
-            default_relevance: float = 1.0,
-            default_ts: str = "2099-01-01",
-        ):
-            """
-            Add ``relevance`` and ``timestamp`` columns with default values if
-            ``relevance`` or ``timestamp`` is absent among mapping keys.
-
-            :param dataframe: interactions log to process
-            :param columns_mapping: dictionary mapping "key: column name in input DataFrame".
-                Possible keys: ``[user_id, user_id, timestamp, relevance]``
-            :param default_relevance: default value for generated `relevance` column
-            :param default_ts: str, default value for generated `timestamp` column
-            :return: spark DataFrame with generated ``timestamp`` and ``relevance`` columns
-                if absent in original dataframe
-            """
-            absent_cols = set(LOG_COLUMNS).difference(columns_mapping.keys())
-            if "relevance" in absent_cols:
-                dataframe = dataframe.withColumn("relevance", sf.lit(default_relevance).cast(DoubleType()))
-            if "timestamp" in absent_cols:
-                dataframe = dataframe.withColumn("timestamp", sf.to_timestamp(sf.lit(default_ts)))
-            return dataframe
-
-        @staticmethod
-        def _rename(df: SparkDataFrame, mapping: dict) -> Optional[SparkDataFrame]:
-            """
-            rename dataframe columns based on mapping
-            """
-            if df is None or mapping is None:
-                return df
-            for out_col, in_col in mapping.items():
-                if in_col in df.columns:
-                    df = df.withColumnRenamed(in_col, out_col)
-            return df
-
-        def transform(
-            self,
-            columns_mapping: dict[str, str],
-            data: Optional[DataFrameLike] = None,
-            path: Optional[str] = None,
-            format_type: Optional[str] = None,
-            date_format: Optional[str] = None,
-            reader_kwargs: Optional[dict] = None,
-        ) -> SparkDataFrame:
-            """
-            Transforms log, user or item features into a Spark DataFrame
-            ``[user_id, user_id, timestamp, relevance]``,
-            ``[user_id, *features]``, or ``[item_id, *features]``.
-            Input is either file of ``format_type``
-            at ``path``, or ``pandas.DataFrame`` or ``spark.DataFrame``.
-            Transform performs:
-            - dataframe reading/convert to spark DataFrame format
-            - check dataframe (nulls, columns_mapping)
-            - rename columns from mapping to standard names (user_id, user_id, timestamp, relevance)
-            - for interactions log: create absent columns,
-              convert ``timestamp`` column to TimestampType and ``relevance`` to DoubleType
-
-            :param columns_mapping: dictionary mapping "key: column name in input DataFrame".
-                Possible keys: ``[user_id, user_id, timestamp, relevance]``
-                ``columns_mapping`` values specifies the nature of the DataFrame:
-                - if both ``[user_id, item_id]`` are present,
-                  then the dataframe is a log of interactions.
-                  Specify ``timestamp, relevance`` columns in mapping if present.
-                - if ether ``user_id`` or ``item_id`` is present,
-                  then the dataframe is a dataframe of user/item features
-
-            :param data: DataFrame to process
-            :param path: path to data
-            :param format_type: file type, one of ``[csv , parquet , json , table]``
-            :param date_format: format for the ``timestamp`` column
-            :param reader_kwargs: extra arguments passed to
-                ``spark.read.<format>(path, **reader_kwargs)``
-            :return: processed DataFrame
-            """
-            is_log = False
-            if "user_id" in columns_mapping and "item_id" in columns_mapping:
-                self.logger.info(
-                    "Columns with ids of users or items are present in mapping. "
-                    "The dataframe will be treated as an interactions log."
-                )
-                is_log = True
-            elif "user_id" not in columns_mapping and "item_id" not in columns_mapping:
-                msg = "Mapping either for user ids or for item ids is not stated in `columns_mapping`"
-                raise ValueError(msg)
-            else:
-                self.logger.info(
-                    "Column with ids of users or items is absent in mapping. "
-                    "The dataframe will be treated as a users'/items' features dataframe."
-                )
-            reader_kwargs = {} if reader_kwargs is None else reader_kwargs
-            dataframe = self.read_as_spark_df(data=data, path=path, format_type=format_type, **reader_kwargs)
-            self.check_df(dataframe, columns_mapping=columns_mapping)
-            dataframe = self._rename(df=dataframe, mapping=columns_mapping)
-            if is_log:
-                dataframe = self.add_absent_log_cols(dataframe=dataframe, columns_mapping=columns_mapping)
-                dataframe = dataframe.withColumn("relevance", sf.col("relevance").cast(DoubleType()))
-                dataframe = process_timestamp_column(
-                    dataframe=dataframe,
-                    column_name="timestamp",
-                    date_format=date_format,
-                )
-
-            return dataframe
-
-    class CatFeaturesTransformer:
-        """Transform categorical features in ``cat_cols_list``
-        with one-hot encoding and remove original columns."""
-
-        def __init__(
-            self,
-            cat_cols_list: list,
-            alias: str = "ohe",
-        ):
-            """
-            :param cat_cols_list: list of categorical columns
-            :param alias: prefix for one-hot encoding columns
-            """
-            self.cat_cols_list = cat_cols_list
-            self.expressions_list = []
-            self.alias = alias
-
-        def fit(self, spark_df: Optional[SparkDataFrame]) -> None:
-            """
-            Save categories for each column
-            :param spark_df: Spark DataFrame with features
-            """
-            if spark_df is None:
-                return
-
-            cat_feat_values_dict = {
-                name: (spark_df.select(sf.collect_set(sf.col(name))).first()[0]) for name in self.cat_cols_list
-            }
-            self.expressions_list = [
-                sf.when(sf.col(col_name) == cur_name, 1)
-                .otherwise(0)
-                .alias(
-                    f"""{self.alias}_{col_name}_{str(cur_name).translate(
-                        str.maketrans(
-                            "", "", string.punctuation + string.whitespace
-                        )
-                    )[:30]}"""
-                )
-                for col_name, col_values in cat_feat_values_dict.items()
-                for cur_name in col_values
-            ]
-
-        def transform(self, spark_df: Optional[SparkDataFrame]):
-            """
-            Transform categorical columns.
-            If there are any new categories that were not present at fit stage, they will be ignored.
-            :param spark_df: feature DataFrame
-            :return: transformed DataFrame
-            """
-            if spark_df is None:
-                return None
-            return spark_df.select(*spark_df.columns, *self.expressions_list).drop(*self.cat_cols_list)
-
-    class ToNumericFeatureTransformer:
-        """Transform user/item features to numeric types:
-        - numeric features stays as is
-        - categorical features:
-            if threshold is defined:
-                - all non-numeric columns with less unique values than threshold are one-hot encoded
-                - remaining columns are dropped
-            else all non-numeric columns are one-hot encoded
-        """
-
-        cat_feat_transformer: Optional[CatFeaturesTransformer]
-        cols_to_ohe: Optional[list]
-        cols_to_del: Optional[list]
-        all_columns: Optional[list]
-
-        def __init__(self, threshold: Optional[int] = 100):
-            self.threshold = threshold
-            self.fitted = False
-
-        def fit(self, features: Optional[SparkDataFrame]) -> None:
-            """
-            Determine categorical columns for one-hot encoding.
-            Non categorical columns with more values than threshold will be deleted.
-            Saves categories for each column.
-            :param features: input DataFrame
-            """
-            self.cat_feat_transformer = None
-            self.cols_to_del = []
-            self.fitted = True
-
-            if features is None:
-                self.all_columns = None
-                return
-
-            self.all_columns = sorted(features.columns)
-
-            spark_df_non_numeric_cols = [
-                col
-                for col in features.columns
-                if (not isinstance(features.schema[col].dataType, NumericType))
-                and (col not in {"user_idx", "item_idx"})
-            ]
-
-            # numeric only
-            if len(spark_df_non_numeric_cols) == 0:
-                self.cols_to_ohe = []
-                return
-
-            if self.threshold is None:
-                self.cols_to_ohe = spark_df_non_numeric_cols
-            else:
-                counts_pd = (
-                    features.agg(*[sf.approx_count_distinct(sf.col(c)).alias(c) for c in spark_df_non_numeric_cols])
-                    .toPandas()
-                    .T
-                )
-                self.cols_to_ohe = (counts_pd[counts_pd[0] <= self.threshold]).index.values
-
-            self.cols_to_del = [col for col in spark_df_non_numeric_cols if col not in set(self.cols_to_ohe)]
-
-            if self.cols_to_del:
-                State().logger.warning(
-                    "%s columns contain more that threshold unique values and will be deleted",
-                    self.cols_to_del,
-                )
-
-            if len(self.cols_to_ohe) > 0:
-                self.cat_feat_transformer = CatFeaturesTransformer(cat_cols_list=self.cols_to_ohe)
-                self.cat_feat_transformer.fit(features.drop(*self.cols_to_del))
-
-        def transform(self, spark_df: Optional[SparkDataFrame]) -> Optional[SparkDataFrame]:
-            """
-            Transform categorical features.
-            Use one hot encoding for columns with the amount of unique values smaller
-            than threshold and delete other columns.
-            :param spark_df: input DataFrame
-            :return: processed DataFrame
-            """
-            if not self.fitted:
-                msg = "Call fit before running transform"
-                raise AttributeError(msg)
-
-            if spark_df is None or self.all_columns is None:
-                return None
-
-            if self.cat_feat_transformer is None:
-                return spark_df.drop(*self.cols_to_del)
-
-            if sorted(spark_df.columns) != self.all_columns:
-                msg = (
-                    f"Columns from fit do not match "
-                    f"columns in transform. "
-                    f"Fit columns: {self.all_columns},"
-                    f"Transform columns: {spark_df.columns}"
-                )
-                raise ValueError(msg)
-
-            return self.cat_feat_transformer.transform(spark_df.drop(*self.cols_to_del))
-
-        def fit_transform(self, spark_df: SparkDataFrame) -> SparkDataFrame:
-            """
-            :param spark_df: input DataFrame
-            :return: output DataFrame
-            """
-            self.fit(spark_df)
-            return self.transform(spark_df)
-
-else:
-    Indexer = MissingImport
-    DataPreparator = MissingImport
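
For context on what the deleted module provided: DataPreparator converts a pandas or Spark DataFrame into the library's log format, driven by a columns_mapping dict, and JoinBasedIndexerEstimator / JoinBasedIndexerTransformer follow the usual pyspark Estimator/Transformer pattern for mapping raw ids to integer indices. A condensed usage sketch, distilled from the docstrings and signatures above; it only runs with pyspark installed and a replay-rec version that still ships replay.experimental (0.20.1rc0 or earlier):

    import pandas as pd

    from replay.experimental.preprocessing.data_preparator import (
        DataPreparator,
        JoinBasedIndexerEstimator,
    )

    # columns_mapping-driven conversion, as in the removed module's own doctest
    log = pd.DataFrame({"user": [2, 2, 2, 1], "item_id": [1, 2, 3, 3], "rel": [5, 5, 5, 5]})
    spark_log = DataPreparator().transform(
        data=log,
        columns_mapping={"user_id": "user", "item_id": "item_id", "relevance": "rel"},
    )

    # fit() builds the id -> idx maps, transform() applies them via joins
    indexer = JoinBasedIndexerEstimator(user_col="user_id", item_col="item_id").fit(spark_log)
    indexed = indexer.transform(spark_log)         # adds integer user_idx / item_idx
    restored = indexer.inverse_transform(indexed)  # joins the original ids back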