replay-rec 0.17.0rc0__py3-none-any.whl → 0.17.1rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
replay/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
1
  """ RecSys library """
2
- __version__ = "0.17.0.preview"
2
+ __version__ = "0.17.1.preview"
replay/data/dataset.py CHANGED
@@ -3,11 +3,22 @@
3
3
  """
4
4
  from __future__ import annotations
5
5
 
6
- from typing import Callable, Dict, Iterable, List, Optional, Sequence
6
+ import json
7
+ from pathlib import Path
8
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
7
9
 
8
10
  import numpy as np
9
-
10
- from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
11
+ from pandas import read_parquet as pd_read_parquet
12
+ from polars import read_parquet as pl_read_parquet
13
+
14
+ from replay.utils import (
15
+ PYSPARK_AVAILABLE,
16
+ DataFrameLike,
17
+ PandasDataFrame,
18
+ PolarsDataFrame,
19
+ SparkDataFrame,
20
+ )
21
+ from replay.utils.session_handler import get_spark_session
11
22
 
12
23
  from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
13
24
 
@@ -47,9 +58,7 @@ class Dataset:
47
58
  self._query_features = query_features
48
59
  self._item_features = item_features
49
60
 
50
- self.is_pandas = isinstance(interactions, PandasDataFrame)
51
- self.is_spark = isinstance(interactions, SparkDataFrame)
52
- self.is_polars = isinstance(interactions, PolarsDataFrame)
61
+ self._assign_df_type()
53
62
 
54
63
  self._categorical_encoded = categorical_encoded
55
64
 
@@ -74,16 +83,8 @@ class Dataset:
74
83
  msg = "Interactions and query features should have the same type."
75
84
  raise TypeError(msg)
76
85
 
77
- self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
78
- FeatureSource.INTERACTIONS: self.interactions,
79
- FeatureSource.QUERY_FEATURES: self.query_features,
80
- FeatureSource.ITEM_FEATURES: self.item_features,
81
- }
82
-
83
- self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
84
- FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
85
- FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
86
- }
86
+ self._get_feature_source_map()
87
+ self._get_ids_source_map()
87
88
 
88
89
  self._feature_schema = self._fill_feature_schema(feature_schema)
89
90
 
@@ -92,7 +93,6 @@ class Dataset:
92
93
  self._check_ids_consistency(hint=FeatureHint.QUERY_ID)
93
94
  if self.item_features is not None:
94
95
  self._check_ids_consistency(hint=FeatureHint.ITEM_ID)
95
-
96
96
  if self._categorical_encoded:
97
97
  self._check_encoded()
98
98
 
@@ -189,6 +189,157 @@ class Dataset:
189
189
  """
190
190
  return self._feature_schema
191
191
 
192
+ def _get_df_type(self) -> str:
193
+ """
194
+ :returns: Stored dataframe type.
195
+ """
196
+ if self.is_spark:
197
+ return "spark"
198
+ if self.is_pandas:
199
+ return "pandas"
200
+ if self.is_polars:
201
+ return "polars"
202
+ msg = "No known dataframe types are provided"
203
+ raise ValueError(msg)
204
+
205
+ def _to_parquet(self, df: DataFrameLike, path: Path) -> None:
206
+ """
207
+ Save the content of the dataframe in parquet format to the provided path.
208
+
209
+ :param df: Dataframe to save.
210
+ :param path: Path to save the dataframe to.
211
+ """
212
+ if self.is_spark:
213
+ path = str(path)
214
+ df = df.withColumn("idx", sf.monotonically_increasing_id())
215
+ df.write.mode("overwrite").parquet(path)
216
+ elif self.is_pandas:
217
+ df.to_parquet(path)
218
+ elif self.is_polars:
219
+ df.write_parquet(path)
220
+ else:
221
+ msg = """
222
+ _to_parquet() can only be used to save polars|pandas|spark dataframes;
223
+ No known dataframe types are provided
224
+ """
225
+ raise TypeError(msg)
226
+
227
+ @staticmethod
228
+ def _read_parquet(path: Path, mode: str) -> Union[SparkDataFrame, PandasDataFrame, PolarsDataFrame]:
229
+ """
230
+ Read the parquet file as dataframe.
231
+
232
+ :param path: The parquet file path.
233
+ :param mode: Dataframe type. Can be spark|pandas|polars.
234
+ :returns: The dataframe read from the file.
235
+ """
236
+ if mode == "spark":
237
+ path = str(path)
238
+ spark_session = get_spark_session()
239
+ df = spark_session.read.parquet(path)
240
+ if "idx" in df.columns:
241
+ df = df.orderBy("idx").drop("idx")
242
+ return df
243
+ if mode == "pandas":
244
+ df = pd_read_parquet(path)
245
+ if "idx" in df.columns:
246
+ df = df.set_index("idx").reset_index(drop=True)
247
+ return df
248
+ if mode == "polars":
249
+ df = pl_read_parquet(path, use_pyarrow=True)
250
+ if "idx" in df.columns:
251
+ df = df.sort("idx").drop("idx")
252
+ return df
253
+ msg = f"_read_parquet() can only be used to read polars|pandas|spark dataframes, not {mode}"
254
+ raise TypeError(msg)
255
+
256
+ def save(self, path: str) -> None:
257
+ """
258
+ Save the Dataset to the provided path.
259
+
260
+ :param path: Path to save the Dataset to.
261
+ """
262
+ dataset_dict = {}
263
+ dataset_dict["_class_name"] = self.__class__.__name__
264
+
265
+ interactions_type = self._get_df_type()
266
+ dataset_dict["init_args"] = {
267
+ "feature_schema": [],
268
+ "interactions": interactions_type,
269
+ "item_features": (interactions_type if self.item_features is not None else None),
270
+ "query_features": (interactions_type if self.query_features is not None else None),
271
+ "check_consistency": False,
272
+ "categorical_encoded": self._categorical_encoded,
273
+ }
274
+
275
+ for feature in self.feature_schema.all_features:
276
+ dataset_dict["init_args"]["feature_schema"].append(
277
+ {
278
+ "column": feature.column,
279
+ "feature_type": feature.feature_type.name,
280
+ "feature_hint": (feature.feature_hint.name if feature.feature_hint else None),
281
+ }
282
+ )
283
+
284
+ base_path = Path(path).with_suffix(".replay").resolve()
285
+ base_path.mkdir(parents=True, exist_ok=True)
286
+
287
+ with open(base_path / "init_args.json", "w+") as file:
288
+ json.dump(dataset_dict, file)
289
+
290
+ df_data = {
291
+ "interactions": self.interactions,
292
+ "item_features": self.item_features,
293
+ "query_features": self.query_features,
294
+ }
295
+
296
+ for df_name, df in df_data.items():
297
+ if df is not None:
298
+ df_path = base_path / f"{df_name}.parquet"
299
+ self._to_parquet(df, df_path)
300
+
301
+ @classmethod
302
+ def load(
303
+ cls,
304
+ path: str,
305
+ dataframe_type: Optional[str] = None,
306
+ ) -> Dataset:
307
+ """
308
+ Load the Dataset from the provided path.
309
+
310
+ :param path: The file path
311
+ :param dataframe_type: Dataframe type to use to store internal data.
312
+ Can be spark|pandas|polars|None.
313
+ If not provided, defaults to the type used when the Dataset was saved.
314
+ :returns: Loaded Dataset.
315
+ """
316
+ base_path = Path(path).with_suffix(".replay").resolve()
317
+ with open(base_path / "init_args.json", "r") as file:
318
+ dataset_dict = json.loads(file.read())
319
+
320
+ if dataframe_type not in ["pandas", "spark", "polars", None]:
321
+ msg = f"Argument dataframe_type can be spark|pandas|polars|None, not {dataframe_type}"
322
+ raise ValueError(msg)
323
+
324
+ feature_schema_data = dataset_dict["init_args"]["feature_schema"]
325
+ features_list = []
326
+ for feature_data in feature_schema_data:
327
+ f_type = feature_data["feature_type"]
328
+ f_hint = feature_data["feature_hint"]
329
+ feature_data["feature_type"] = FeatureType[f_type] if f_type else None
330
+ feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
331
+ features_list.append(FeatureInfo(**feature_data))
332
+ dataset_dict["init_args"]["feature_schema"] = FeatureSchema(features_list)
333
+
334
+ for df_name in ["interactions", "query_features", "item_features"]:
335
+ df_type = dataset_dict["init_args"][df_name]
336
+ if df_type:
337
+ df_type = dataframe_type or df_type
338
+ load_path = base_path / f"{df_name}.parquet"
339
+ dataset_dict["init_args"][df_name] = cls._read_parquet(load_path, df_type)
340
+ dataset = cls(**dataset_dict["init_args"])
341
+ return dataset
342
+
192
343
  if PYSPARK_AVAILABLE:
193
344
 
194
345
  def persist(self, storage_level: StorageLevel = StorageLevel(True, True, False, True, 1)) -> None:
@@ -283,6 +434,24 @@ class Dataset:
283
434
  categorical_encoded=self._categorical_encoded,
284
435
  )
285
436
 
437
+ def _get_feature_source_map(self):
438
+ self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
439
+ FeatureSource.INTERACTIONS: self.interactions,
440
+ FeatureSource.QUERY_FEATURES: self.query_features,
441
+ FeatureSource.ITEM_FEATURES: self.item_features,
442
+ }
443
+
444
+ def _get_ids_source_map(self):
445
+ self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
446
+ FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
447
+ FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
448
+ }
449
+
450
+ def _assign_df_type(self):
451
+ self.is_pandas = isinstance(self.interactions, PandasDataFrame)
452
+ self.is_spark = isinstance(self.interactions, SparkDataFrame)
453
+ self.is_polars = isinstance(self.interactions, PolarsDataFrame)
454
+
286
455
  def _get_cardinality(self, feature: FeatureInfo) -> Callable:
287
456
  def callback(column: str) -> int:
288
457
  if feature.feature_hint in [FeatureHint.ITEM_ID, FeatureHint.QUERY_ID]:
@@ -381,7 +550,11 @@ class Dataset:
381
550
  is_consistent = (
382
551
  self.interactions.select(ids_column)
383
552
  .distinct()
384
- .join(features_df.select(ids_column).distinct(), on=[ids_column], how="leftanti")
553
+ .join(
554
+ features_df.select(ids_column).distinct(),
555
+ on=[ids_column],
556
+ how="leftanti",
557
+ )
385
558
  .count()
386
559
  ) == 0
387
560
  else:
@@ -389,7 +562,11 @@ class Dataset:
389
562
  len(
390
563
  self.interactions.select(ids_column)
391
564
  .unique()
392
- .join(features_df.select(ids_column).unique(), on=ids_column, how="anti")
565
+ .join(
566
+ features_df.select(ids_column).unique(),
567
+ on=ids_column,
568
+ how="anti",
569
+ )
393
570
  )
394
571
  == 0
395
572
  )
@@ -399,7 +576,11 @@ class Dataset:
399
576
  raise ValueError(msg)
400
577
 
401
578
  def _check_column_encoded(
402
- self, data: DataFrameLike, column: str, source: FeatureSource, cardinality: Optional[int]
579
+ self,
580
+ data: DataFrameLike,
581
+ column: str,
582
+ source: FeatureSource,
583
+ cardinality: Optional[int],
403
584
  ) -> None:
404
585
  """
405
586
  Checks that IDs are encoded:
@@ -482,6 +663,51 @@ class Dataset:
482
663
  feature.cardinality,
483
664
  )
484
665
 
666
+ def to_pandas(self) -> None:
667
+ """
668
+ Convert internally stored dataframes to pandas.DataFrame.
669
+ """
670
+ from replay.utils.common import convert2pandas
671
+
672
+ self._interactions = convert2pandas(self._interactions)
673
+ if self._query_features is not None:
674
+ self._query_features = convert2pandas(self._query_features)
675
+ if self._item_features is not None:
676
+ self._item_features = convert2pandas(self.item_features)
677
+ self._get_feature_source_map()
678
+ self._get_ids_source_map()
679
+ self._assign_df_type()
680
+
681
+ def to_spark(self):
682
+ """
683
+ Convert internally stored dataframes to pyspark.sql.DataFrame.
684
+ """
685
+ from replay.utils.common import convert2spark
686
+
687
+ self._interactions = convert2spark(self._interactions)
688
+ if self._query_features is not None:
689
+ self._query_features = convert2spark(self._query_features)
690
+ if self._item_features is not None:
691
+ self._item_features = convert2spark(self._item_features)
692
+ self._get_feature_source_map()
693
+ self._get_ids_source_map()
694
+ self._assign_df_type()
695
+
696
+ def to_polars(self):
697
+ """
698
+ Convert internally stored dataframes to polars.DataFrame.
699
+ """
700
+ from replay.utils.common import convert2polars
701
+
702
+ self._interactions = convert2polars(self._interactions)
703
+ if self._query_features is not None:
704
+ self._query_features = convert2polars(self._query_features)
705
+ if self._item_features is not None:
706
+ self._item_features = convert2polars(self._item_features)
707
+ self._get_feature_source_map()
708
+ self._get_ids_source_map()
709
+ self._assign_df_type()
710
+
485
711
 
486
712
  def nunique(data: DataFrameLike, column: str) -> int:
487
713
  """
replay/data/nn/schema.py CHANGED
@@ -408,6 +408,48 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
408
408
  return None
409
409
  return rating_features.item().name
410
410
 
411
+ def _get_object_args(self) -> Dict:
412
+ """
413
+ Returns the schema's features represented as a list of dictionaries.
414
+ """
415
+ features = [
416
+ {
417
+ "name": feature.name,
418
+ "feature_type": feature.feature_type.name,
419
+ "is_seq": feature.is_seq,
420
+ "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
421
+ "feature_sources": [
422
+ {"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources
423
+ ]
424
+ if feature.feature_sources
425
+ else None,
426
+ "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
427
+ "embedding_dim": feature.embedding_dim if feature.feature_type == FeatureType.CATEGORICAL else None,
428
+ "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
429
+ }
430
+ for feature in self.all_features
431
+ ]
432
+ return features
433
+
434
+ @classmethod
435
+ def _create_object_by_args(cls, args: Dict) -> "TensorSchema":
436
+ features_list = []
437
+ for feature_data in args:
438
+ feature_data["feature_sources"] = (
439
+ [
440
+ TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
441
+ for x in feature_data["feature_sources"]
442
+ ]
443
+ if feature_data["feature_sources"]
444
+ else None
445
+ )
446
+ f_type = feature_data["feature_type"]
447
+ f_hint = feature_data["feature_hint"]
448
+ feature_data["feature_type"] = FeatureType[f_type] if f_type else None
449
+ feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
450
+ features_list.append(TensorFeatureInfo(**feature_data))
451
+ return TensorSchema(features_list)
452
+
411
453
  def filter(
412
454
  self,
413
455
  name: Optional[str] = None,
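
The new TensorSchema helpers above (_get_object_args and _create_object_by_args) centralize the dict/JSON representation of a schema that SequenceTokenizer.save() and load() previously built inline. A minimal round-trip sketch, assuming an already-built TensorSchema instance named tensor_schema and that TensorSchema is importable from replay.data.nn; note that both helpers are private.

    import json

    from replay.data.nn import TensorSchema

    # Assumes tensor_schema is an existing TensorSchema instance.
    args = tensor_schema._get_object_args()      # list of plain, JSON-serializable dicts
    payload = json.dumps(args)                   # this is what ends up in init_args.json

    restored = TensorSchema._create_object_by_args(json.loads(payload))
    assert [f.name for f in restored.all_features] == [f.name for f in tensor_schema.all_features]
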
replay/data/nn/sequence_tokenizer.py CHANGED
@@ -24,7 +24,10 @@ SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
24
24
 
25
25
  class SequenceTokenizer:
26
26
  """
27
- Data tokenizer for transformers
27
+ Data tokenizer for transformers.
28
+ Encodes all categorical features (those marked as FeatureType.CATEGORICAL in
29
+ the FeatureSchema) and stores all data as item sequences (sorted by time if a
30
+ feature of type FeatureHint.TIMESTAMP is provided, unsorted otherwise).
28
31
  """
29
32
 
30
33
  def __init__(
@@ -278,17 +281,17 @@ class SequenceTokenizer:
278
281
  ]
279
282
 
280
283
  for tensor_feature in tensor_schema.values():
281
- source = tensor_feature.feature_source
282
- assert source is not None
284
+ for source in tensor_feature.feature_sources:
285
+ assert source is not None
283
286
 
284
- # Some columns already added to encoder, skip them
285
- if source.column in features_subset:
286
- continue
287
+ # Some columns already added to encoder, skip them
288
+ if source.column in features_subset:
289
+ continue
287
290
 
288
- if isinstance(source.source, FeatureSource):
289
- features_subset.append(source.column)
290
- else:
291
- assert False, "Unknown tensor feature source"
291
+ if isinstance(source.source, FeatureSource):
292
+ features_subset.append(source.column)
293
+ else:
294
+ assert False, "Unknown tensor feature source"
292
295
 
293
296
  return set(features_subset)
294
297
 
@@ -404,7 +407,7 @@ class SequenceTokenizer:
404
407
 
405
408
  @classmethod
406
409
  @deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
407
- def load(cls, path: str, use_pickle: bool = False) -> "SequenceTokenizer":
410
+ def load(cls, path: str, use_pickle: bool = False, **kwargs) -> "SequenceTokenizer":
408
411
  """
409
412
  Load tokenizer object from the given path.
410
413
 
@@ -422,18 +425,7 @@ class SequenceTokenizer:
422
425
 
423
426
  # load tensor_schema, tensor_features
424
427
  tensor_schema_data = tokenizer_dict["init_args"]["tensor_schema"]
425
- features_list = []
426
- for feature_data in tensor_schema_data:
427
- feature_data["feature_sources"] = [
428
- TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
429
- for x in feature_data["feature_sources"]
430
- ]
431
- f_type = feature_data["feature_type"]
432
- f_hint = feature_data["feature_hint"]
433
- feature_data["feature_type"] = FeatureType[f_type] if f_type else None
434
- feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
435
- features_list.append(TensorFeatureInfo(**feature_data))
436
- tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema(features_list)
428
+ tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema._create_object_by_args(tensor_schema_data)
437
429
 
438
430
  # Load encoder columns and rules
439
431
  types = list(FeatureHint) + list(FeatureSource)
@@ -447,7 +439,7 @@ class SequenceTokenizer:
447
439
  rule_data = rules_dict[rule]
448
440
  if rule_data["mapping"] and rule_data["is_int"]:
449
441
  rule_data["mapping"] = {int(key): value for key, value in rule_data["mapping"].items()}
450
- del rule_data["is_int"]
442
+ del rule_data["is_int"]
451
443
 
452
444
  tokenizer_dict["encoder"]["encoding_rules"][rule] = LabelEncodingRule(**rule_data)
453
445
 
@@ -478,31 +470,9 @@ class SequenceTokenizer:
478
470
  "allow_collect_to_master": self._allow_collect_to_master,
479
471
  "handle_unknown_rule": self._encoder._handle_unknown_rule,
480
472
  "default_value_rule": self._encoder._default_value_rule,
481
- "tensor_schema": [],
473
+ "tensor_schema": self._tensor_schema._get_object_args(),
482
474
  }
483
475
 
484
- # save tensor schema
485
- for feature in list(self._tensor_schema.values()):
486
- tokenizer_dict["init_args"]["tensor_schema"].append(
487
- {
488
- "name": feature.name,
489
- "feature_type": feature.feature_type.name,
490
- "is_seq": feature.is_seq,
491
- "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
492
- "feature_sources": [
493
- {"source": x.source.name, "column": x.column, "index": x.index}
494
- for x in feature.feature_sources
495
- ]
496
- if feature.feature_sources
497
- else None,
498
- "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
499
- "embedding_dim": feature.embedding_dim
500
- if feature.feature_type == FeatureType.CATEGORICAL
501
- else None,
502
- "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
503
- }
504
- )
505
-
506
476
  # save DatasetLabelEncoder
507
477
  tokenizer_dict["encoder"] = {
508
478
  "features_columns": {key.name: value for key, value in self._encoder._features_columns.items()},
replay/data/nn/sequential_dataset.py CHANGED
@@ -1,7 +1,10 @@
1
1
  import abc
2
+ import json
3
+ from pathlib import Path
2
4
  from typing import Tuple, Union
3
5
 
4
6
  import numpy as np
7
+ import pandas as pd
5
8
  import polars as pl
6
9
  from pandas import DataFrame as PandasDataFrame
7
10
  from polars import DataFrame as PolarsDataFrame
@@ -100,6 +103,23 @@ class SequentialDataset(abc.ABC):
100
103
  rhs_filtered = rhs.filter_by_query_id(common_queries)
101
104
  return lhs_filtered, rhs_filtered
102
105
 
106
+ def save(self, path: str) -> None:
107
+ base_path = Path(path).with_suffix(".replay").resolve()
108
+ base_path.mkdir(parents=True, exist_ok=True)
109
+
110
+ sequential_dict = {}
111
+ sequential_dict["_class_name"] = self.__class__.__name__
112
+ self._sequences.reset_index().to_json(base_path / "sequences.json")
113
+ sequential_dict["init_args"] = {
114
+ "tensor_schema": self._tensor_schema._get_object_args(),
115
+ "query_id_column": self._query_id_column,
116
+ "item_id_column": self._item_id_column,
117
+ "sequences_path": "sequences.json",
118
+ }
119
+
120
+ with open(base_path / "init_args.json", "w+") as file:
121
+ json.dump(sequential_dict, file)
122
+
103
123
 
104
124
  class PandasSequentialDataset(SequentialDataset):
105
125
  """
@@ -174,6 +194,25 @@ class PandasSequentialDataset(SequentialDataset):
174
194
  msg = "Tensor schema does not match with provided data frame"
175
195
  raise ValueError(msg)
176
196
 
197
+ @classmethod
198
+ def load(cls, path: str, **kwargs) -> "PandasSequentialDataset":
199
+ """
200
+ Method for loading PandasSequentialDataset object from `.replay` directory.
201
+ """
202
+ base_path = Path(path).with_suffix(".replay").resolve()
203
+ with open(base_path / "init_args.json", "r") as file:
204
+ sequential_dict = json.loads(file.read())
205
+
206
+ sequences = pd.read_json(base_path / sequential_dict["init_args"]["sequences_path"])
207
+ dataset = cls(
208
+ tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
209
+ query_id_column=sequential_dict["init_args"]["query_id_column"],
210
+ item_id_column=sequential_dict["init_args"]["item_id_column"],
211
+ sequences=sequences,
212
+ )
213
+
214
+ return dataset
215
+
177
216
 
178
217
  class PolarsSequentialDataset(PandasSequentialDataset):
179
218
  """
@@ -199,7 +238,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):
199
238
  self._query_id_column = query_id_column
200
239
  self._item_id_column = item_id_column
201
240
 
202
- self._sequences = sequences.to_pandas()
241
+ self._sequences = self._convert_polars_to_pandas(sequences)
203
242
  if self._sequences.index.name != query_id_column:
204
243
  self._sequences = self._sequences.set_index(query_id_column)
205
244
 
@@ -211,12 +250,47 @@ class PolarsSequentialDataset(PandasSequentialDataset):
211
250
  tensor_schema=self._tensor_schema,
212
251
  query_id_column=self._query_id_column,
213
252
  item_id_column=self._item_id_column,
214
- sequences=pl.from_pandas(filtered_sequences),
253
+ sequences=self._convert_pandas_to_polars(filtered_sequences),
215
254
  )
216
255
 
256
+ def _convert_polars_to_pandas(self, df: PolarsDataFrame) -> PandasDataFrame:
257
+ pandas_df = PandasDataFrame(df.to_dict(as_series=False))
258
+
259
+ for column in pandas_df.select_dtypes(include="object").columns:
260
+ if isinstance(pandas_df[column].iloc[0], list):
261
+ pandas_df[column] = pandas_df[column].apply(lambda x: np.array(x))
262
+
263
+ return pandas_df
264
+
265
+ def _convert_pandas_to_polars(self, df: PandasDataFrame) -> PolarsDataFrame:
266
+ for column in df.select_dtypes(include="object").columns:
267
+ if isinstance(df[column].iloc[0], np.ndarray):
268
+ df[column] = df[column].apply(lambda x: x.tolist())
269
+
270
+ return pl.from_dict(df.to_dict("list"))
271
+
217
272
  @classmethod
218
273
  def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PolarsDataFrame) -> None:
219
274
  for tensor_feature_name in tensor_schema:
220
275
  if tensor_feature_name not in data:
221
276
  msg = "Tensor schema does not match with provided data frame"
222
277
  raise ValueError(msg)
278
+
279
+ @classmethod
280
+ def load(cls, path: str, **kwargs) -> "PolarsSequentialDataset":
281
+ """
282
+ Method for loading PolarsSequentialDataset object from `.replay` directory.
283
+ """
284
+ base_path = Path(path).with_suffix(".replay").resolve()
285
+ with open(base_path / "init_args.json", "r") as file:
286
+ sequential_dict = json.loads(file.read())
287
+
288
+ sequences = pl.DataFrame(pd.read_json(base_path / sequential_dict["init_args"]["sequences_path"]))
289
+ dataset = cls(
290
+ tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
291
+ query_id_column=sequential_dict["init_args"]["query_id_column"],
292
+ item_id_column=sequential_dict["init_args"]["item_id_column"],
293
+ sequences=sequences,
294
+ )
295
+
296
+ return dataset
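
The save()/load() pair above makes sequential datasets savable alongside splitters and the tokenizer: the sequences go to sequences.json and the constructor arguments to init_args.json inside a .replay directory. A minimal sketch, assuming an existing PandasSequentialDataset (for example, one produced by a fitted SequenceTokenizer; the variable and path names are illustrative):

    from replay.data.nn import PandasSequentialDataset

    sequential_dataset.save("train_sequences")                   # writes train_sequences.replay/
    restored = PandasSequentialDataset.load("train_sequences")

PolarsSequentialDataset.load works the same way and converts the stored sequences back to polars on the fly.
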
replay/preprocessing/filters.py CHANGED
@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
5
5
  from datetime import datetime, timedelta
6
6
  from typing import Callable, Optional, Tuple, Union
7
7
 
8
+ import numpy as np
9
+ import pandas as pd
8
10
  import polars as pl
9
11
 
10
12
  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
@@ -357,7 +359,7 @@ class NumInteractionsFilter(_BaseFilter):
357
359
  ... "2020-02-01", "2020-01-01 00:04:15",
358
360
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
359
361
  ... )
360
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
362
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
361
363
  >>> log_sp = convert2spark(log_pd)
362
364
  >>> log_sp.show()
363
365
  +-------+-------+------+-------------------+
@@ -499,7 +501,7 @@ class EntityDaysFilter(_BaseFilter):
499
501
  ... "2020-02-01", "2020-01-01 00:04:15",
500
502
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
501
503
  ... )
502
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
504
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
503
505
  >>> log_sp = convert2spark(log_pd)
504
506
  >>> log_sp.orderBy('user_id', 'item_id').show()
505
507
  +-------+-------+------+-------------------+
@@ -638,7 +640,7 @@ class GlobalDaysFilter(_BaseFilter):
638
640
  ... "2020-02-01", "2020-01-01 00:04:15",
639
641
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
640
642
  ... )
641
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
643
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
642
644
  >>> log_sp = convert2spark(log_pd)
643
645
  >>> log_sp.show()
644
646
  +-------+-------+------+-------------------+
@@ -740,7 +742,7 @@ class TimePeriodFilter(_BaseFilter):
740
742
  ... "2020-02-01", "2020-01-01 00:04:15",
741
743
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
742
744
  ... )
743
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
745
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
744
746
  >>> log_sp = convert2spark(log_pd)
745
747
  >>> log_sp.show()
746
748
  +-------+-------+------+-------------------+
@@ -823,3 +825,166 @@ class TimePeriodFilter(_BaseFilter):
823
825
  return interactions.filter(
824
826
  pl.col(self.timestamp_column).is_between(self.start_date, self.end_date, closed="left")
825
827
  )
828
+
829
+
830
+ class QuantileItemsFilter(_BaseFilter):
831
+ """
832
+ Filter is aimed on undersampling the interactions dataset.
833
+
834
+ Filter algorithm performs undersampling by removing `items_proportion` of interactions
835
+ for each items counts that exceeds the `alpha_quantile` value in distribution. Filter firstly
836
+ removes popular items (items that have most interactions). Filter also keeps the original
837
+ relation of items popularity among each other by removing interactions only in range of
838
+ current item count and quantile count (specified by `alpha_quantile`).
839
+
840
+ >>> import pandas as pd
841
+ >>> from replay.utils.spark_utils import convert2spark
842
+ >>> log_pd = pd.DataFrame({
843
+ ... "user_id": [0, 0, 1, 2, 2, 2, 2],
844
+ ... "item_id": [0, 2, 1, 1, 2, 2, 2]
845
+ ... })
846
+ >>> log_spark = convert2spark(log_pd)
847
+ >>> log_spark.show()
848
+ +-------+-------+
849
+ |user_id|item_id|
850
+ +-------+-------+
851
+ | 0| 0|
852
+ | 0| 2|
853
+ | 1| 1|
854
+ | 2| 1|
855
+ | 2| 2|
856
+ | 2| 2|
857
+ | 2| 2|
858
+ +-------+-------+
859
+ <BLANKLINE>
860
+
861
+ >>> QuantileItemsFilter(query_column="user_id").transform(log_spark).show()
862
+ +-------+-------+
863
+ |user_id|item_id|
864
+ +-------+-------+
865
+ | 0| 0|
866
+ | 1| 1|
867
+ | 2| 1|
868
+ | 2| 2|
869
+ | 2| 2|
870
+ | 0| 2|
871
+ +-------+-------+
872
+ <BLANKLINE>
873
+ """
874
+
875
+ def __init__(
876
+ self,
877
+ alpha_quantile: float = 0.99,
878
+ items_proportion: float = 0.5,
879
+ query_column: str = "query_id",
880
+ item_column: str = "item_id",
881
+ ) -> None:
882
+ """
883
+ :param alpha_quantile: Quantile of the item-count distribution to keep unchanged.
884
+ Every item count that exceeds this value will be undersampled.
885
+ Default: ``0.99``.
886
+ :param items_proportion: Proportion of interactions to remove for items whose count
887
+ exceeds the `alpha_quantile` value, applied within the range between the current
888
+ item count and the quantile count so that the original relation between items is kept.
889
+ Default: ``0.5``.
890
+ :param query_column: query column name.
891
+ Default: ``query_id``.
892
+ :param item_column: item column name.
893
+ Default: ``item_id``.
894
+ """
895
+ if not 0 < alpha_quantile < 1:
896
+ msg = "`alpha_quantile` value must be in (0, 1)"
897
+ raise ValueError(msg)
898
+ if not 0 < items_proportion < 1:
899
+ msg = "`items_proportion` value must be in (0, 1)"
900
+ raise ValueError(msg)
901
+
902
+ self.alpha_quantile = alpha_quantile
903
+ self.items_proportion = items_proportion
904
+ self.query_column = query_column
905
+ self.item_column = item_column
906
+
907
+ def _filter_pandas(self, df: pd.DataFrame):
908
+ items_distribution = df.groupby(self.item_column).size().reset_index().rename(columns={0: "counts"})
909
+ users_distribution = df.groupby(self.query_column).size().reset_index().rename(columns={0: "counts"})
910
+ count_threshold = items_distribution.loc[:, "counts"].quantile(self.alpha_quantile, interpolation="midpoint")
911
+ df_with_counts = df.merge(items_distribution, how="left", on=self.item_column).merge(
912
+ users_distribution, how="left", on=self.query_column, suffixes=["_items", "_users"]
913
+ )
914
+ long_tail = df_with_counts.loc[df_with_counts["counts_items"] <= count_threshold]
915
+ short_tail = df_with_counts.loc[df_with_counts["counts_items"] > count_threshold]
916
+ short_tail["num_items_to_delete"] = self.items_proportion * (
917
+ short_tail["counts_items"] - long_tail["counts_items"].max()
918
+ )
919
+ short_tail["num_items_to_delete"] = short_tail["num_items_to_delete"].astype("int")
920
+ short_tail = short_tail.sort_values("counts_users", ascending=False)
921
+
922
+ def get_mask(x):
923
+ mask = np.ones_like(x)
924
+ threshold = x.iloc[0]
925
+ mask[:threshold] = 0
926
+ return mask
927
+
928
+ mask = short_tail.groupby(self.item_column)["num_items_to_delete"].transform(get_mask).astype(bool)
929
+ return pd.concat([long_tail[df.columns], short_tail.loc[mask][df.columns]])
930
+
931
+ def _filter_polars(self, df: pl.DataFrame):
932
+ items_distribution = df.group_by(self.item_column).len()
933
+ users_distribution = df.group_by(self.query_column).len()
934
+ count_threshold = items_distribution.select("len").quantile(self.alpha_quantile, "midpoint")["len"][0]
935
+ df_with_counts = (
936
+ df.join(items_distribution, how="left", on=self.item_column).join(
937
+ users_distribution, how="left", on=self.query_column
938
+ )
939
+ ).rename({"len": "counts_items", "len_right": "counts_users"})
940
+ long_tail = df_with_counts.filter(pl.col("counts_items") <= count_threshold)
941
+ short_tail = df_with_counts.filter(pl.col("counts_items") > count_threshold)
942
+ max_long_tail_count = long_tail["counts_items"].max()
943
+ items_to_delete = (
944
+ short_tail.select(
945
+ self.query_column,
946
+ self.item_column,
947
+ self.items_proportion * (pl.col("counts_items") - max_long_tail_count),
948
+ )
949
+ .with_columns(pl.col("literal").cast(pl.Int64).alias("num_items_to_delete"))
950
+ .select(self.item_column, "num_items_to_delete")
951
+ .unique(maintain_order=True)
952
+ )
953
+ short_tail = short_tail.join(items_to_delete, how="left", on=self.item_column).sort(
954
+ "counts_users", descending=True
955
+ )
956
+ short_tail = short_tail.with_columns(index=pl.int_range(short_tail.shape[0]))
957
+ grouped = short_tail.group_by(self.item_column, maintain_order=True).agg(
958
+ pl.col("index"), pl.col("num_items_to_delete")
959
+ )
960
+ grouped = grouped.with_columns(
961
+ pl.col("num_items_to_delete").list.get(0),
962
+ (pl.col("index").list.len() - pl.col("num_items_to_delete").list.get(0)).alias("tail"),
963
+ )
964
+ grouped = grouped.with_columns(pl.col("index").list.tail(pl.col("tail")))
965
+ grouped = grouped.explode("index").select("index")
966
+ short_tail = grouped.join(short_tail, how="left", on="index")
967
+ return pl.concat([long_tail.select(df.columns), short_tail.select(df.columns)])
968
+
969
+ def _filter_spark(self, df: SparkDataFrame):
970
+ items_distribution = df.groupBy(self.item_column).agg(sf.count(self.query_column).alias("counts_items"))
971
+ users_distribution = df.groupBy(self.query_column).agg(sf.count(self.item_column).alias("counts_users"))
972
+ count_threshold = items_distribution.toPandas().loc[:, "counts_items"].quantile(self.alpha_quantile, "midpoint")
973
+ df_with_counts = df.join(items_distribution, on=self.item_column).join(users_distribution, on=self.query_column)
974
+ long_tail = df_with_counts.filter(sf.col("counts_items") <= count_threshold)
975
+ short_tail = df_with_counts.filter(sf.col("counts_items") > count_threshold)
976
+ max_long_tail_count = long_tail.agg({"counts_items": "max"}).collect()[0][0]
977
+ items_to_delete = (
978
+ short_tail.withColumn(
979
+ "num_items_to_delete",
980
+ (self.items_proportion * (sf.col("counts_items") - max_long_tail_count)).cast("int"),
981
+ )
982
+ .select(self.item_column, "num_items_to_delete")
983
+ .distinct()
984
+ )
985
+ short_tail = short_tail.join(items_to_delete, on=self.item_column, how="left")
986
+ short_tail = short_tail.withColumn(
987
+ "index", sf.row_number().over(Window.partitionBy(self.item_column).orderBy(sf.col("counts_users").desc()))
988
+ )
989
+ short_tail = short_tail.filter(sf.col("index") > sf.col("num_items_to_delete"))
990
+ return long_tail.select(df.columns).union(short_tail.select(df.columns))
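
QuantileItemsFilter also ships pandas and polars implementations (_filter_pandas/_filter_polars above), so the undersampling shown in the Spark doctest can be applied to an in-memory frame as well. A small sketch, assuming _BaseFilter.transform dispatches on the dataframe type as it does for the other filters:

    import pandas as pd

    from replay.preprocessing.filters import QuantileItemsFilter

    log = pd.DataFrame({"user_id": [0, 0, 1, 2, 2, 2, 2], "item_id": [0, 2, 1, 1, 2, 2, 2]})

    # Item 2 is the only item above the 0.99 quantile of item counts here,
    # so one of its interactions is removed, matching the Spark example above.
    undersampled = QuantileItemsFilter(query_column="user_id").transform(log)
    print(undersampled)
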
replay/splitters/base_splitter.py CHANGED
@@ -85,7 +85,7 @@ class Splitter(ABC):
85
85
  json.dump(splitter_dict, file)
86
86
 
87
87
  @classmethod
88
- def load(cls, path: str) -> "Splitter":
88
+ def load(cls, path: str, **kwargs) -> "Splitter":
89
89
  """
90
90
  Method for loading splitter from `.replay` directory.
91
91
  """
replay/utils/common.py CHANGED
@@ -1,7 +1,12 @@
1
+ import functools
2
+ import inspect
1
3
  import json
2
4
  from pathlib import Path
3
- from typing import Union
5
+ from typing import Any, Callable, Union
4
6
 
7
+ from polars import from_pandas as pl_from_pandas
8
+
9
+ from replay.data.dataset import Dataset
5
10
  from replay.splitters import (
6
11
  ColdUserRandomSplitter,
7
12
  KFolds,
@@ -12,7 +17,16 @@ from replay.splitters import (
12
17
  TimeSplitter,
13
18
  TwoStageSplitter,
14
19
  )
15
- from replay.utils import TORCH_AVAILABLE
20
+ from replay.utils import (
21
+ TORCH_AVAILABLE,
22
+ PandasDataFrame,
23
+ PolarsDataFrame,
24
+ SparkDataFrame,
25
+ )
26
+ from replay.utils.spark_utils import (
27
+ convert2spark as pandas_to_spark,
28
+ spark_to_pandas,
29
+ )
16
30
 
17
31
  SavableObject = Union[
18
32
  ColdUserRandomSplitter,
@@ -23,10 +37,11 @@ SavableObject = Union[
23
37
  RatioSplitter,
24
38
  TimeSplitter,
25
39
  TwoStageSplitter,
40
+ Dataset,
26
41
  ]
27
42
 
28
43
  if TORCH_AVAILABLE:
29
- from replay.data.nn import SequenceTokenizer
44
+ from replay.data.nn import PandasSequentialDataset, PolarsSequentialDataset, SequenceTokenizer
30
45
 
31
46
  SavableObject = Union[
32
47
  ColdUserRandomSplitter,
@@ -38,6 +53,8 @@ if TORCH_AVAILABLE:
38
53
  TimeSplitter,
39
54
  TwoStageSplitter,
40
55
  SequenceTokenizer,
56
+ PandasSequentialDataset,
57
+ PolarsSequentialDataset,
41
58
  ]
42
59
 
43
60
 
@@ -50,7 +67,7 @@ def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
50
67
  obj.save(path)
51
68
 
52
69
 
53
- def load_from_replay(path: Union[str, Path]) -> SavableObject:
70
+ def load_from_replay(path: Union[str, Path], **kwargs) -> SavableObject:
54
71
  """
55
72
  General function to load RePlay models, splitters and tokenizer.
56
73
 
@@ -60,6 +77,91 @@ def load_from_replay(path: Union[str, Path]) -> SavableObject:
60
77
  with open(path / "init_args.json", "r") as file:
61
78
  class_name = json.loads(file.read())["_class_name"]
62
79
  obj_type = globals()[class_name]
63
- obj = obj_type.load(path)
80
+ obj = obj_type.load(path, **kwargs)
64
81
 
65
82
  return obj
83
+
84
+
85
+ def _check_if_dataframe(var: Any):
86
+ if not isinstance(var, (SparkDataFrame, PolarsDataFrame, PandasDataFrame)):
87
+ msg = f"Object of type {type(var)} is not a dataframe of known type (can be pandas|spark|polars)"
88
+ raise ValueError(msg)
89
+
90
+
91
+ def check_if_dataframe(*args_to_check: str) -> Callable[..., Any]:
92
+ def decorator_func(func: Callable[..., Any]) -> Callable[..., Any]:
93
+ @functools.wraps(func)
94
+ def wrap_func(*args: Any, **kwargs: Any) -> Any:
95
+ extended_kwargs = {}
96
+ extended_kwargs.update(kwargs)
97
+ extended_kwargs.update(dict(zip(inspect.signature(func).parameters.keys(), args)))
98
+ # add default param values to dict with arguments
99
+ extended_kwargs.update(
100
+ {
101
+ x.name: x.default
102
+ for x in inspect.signature(func).parameters.values()
103
+ if x.name not in extended_kwargs and x.default is not x.empty
104
+ }
105
+ )
106
+ vals_to_check = [extended_kwargs[_arg] for _arg in args_to_check]
107
+ for val in vals_to_check:
108
+ _check_if_dataframe(val)
109
+ return func(*args, **kwargs)
110
+
111
+ return wrap_func
112
+
113
+ return decorator_func
114
+
115
+
116
+ @check_if_dataframe("data")
117
+ def convert2pandas(
118
+ data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
119
+ ) -> PandasDataFrame:
120
+ """
121
+ Convert the spark|polars DataFrame to a pandas.DataFrame.
122
+ Returns unchanged dataframe if the input is already of type pandas.DataFrame.
123
+
124
+ :param data: The dataframe to convert. Can be polars|spark|pandas DataFrame.
125
+ :param allow_collect_to_master: If set to False (default), raises a warning
126
+ about collecting parallelized data to the master node.
127
+ """
128
+ if isinstance(data, PandasDataFrame):
129
+ return data
130
+ if isinstance(data, PolarsDataFrame):
131
+ return data.to_pandas()
132
+ if isinstance(data, SparkDataFrame):
133
+ return spark_to_pandas(data, allow_collect_to_master, from_constructor=False)
134
+
135
+
136
+ @check_if_dataframe("data")
137
+ def convert2polars(
138
+ data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
139
+ ) -> PolarsDataFrame:
140
+ """
141
+ Convert the spark|pandas DataFrame to a polars.DataFrame.
142
+ Returns unchanged dataframe if the input is already of type polars.DataFrame.
143
+
144
+ :param data: The dataframe to convert. Can be spark|pandas|polars DataFrame.
145
+ :param allow_collect_to_master: If set to False (default), raises a warning
146
+ about collecting parallelized data to the master node.
147
+ """
148
+ if isinstance(data, PandasDataFrame):
149
+ return pl_from_pandas(data)
150
+ if isinstance(data, PolarsDataFrame):
151
+ return data
152
+ if isinstance(data, SparkDataFrame):
153
+ return pl_from_pandas(spark_to_pandas(data, allow_collect_to_master, from_constructor=False))
154
+
155
+
156
+ @check_if_dataframe("data")
157
+ def convert2spark(data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame]) -> SparkDataFrame:
158
+ """
159
+ Convert the pandas|polars DataFrame to a pysaprk.sql.DataFrame.
160
+ Returns unchanged dataframe if the input is already of type pysaprk.sql.DataFrame.
161
+
162
+ :param data: The dataframe to convert. Can be pandas|polars|spark Datarame.
163
+ """
164
+ if isinstance(data, (PandasDataFrame, SparkDataFrame)):
165
+ return pandas_to_spark(data)
166
+ if isinstance(data, PolarsDataFrame):
167
+ return pandas_to_spark(data.to_pandas())
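
The new convert2pandas/convert2polars/convert2spark helpers give a single conversion entry point guarded by the check_if_dataframe decorator, which rejects anything that is not a pandas, polars or Spark dataframe before the conversion runs. A short sketch (the Spark step additionally requires pyspark to be installed and a Spark session to be available):

    import pandas as pd

    from replay.utils.common import convert2pandas, convert2polars, convert2spark

    df = pd.DataFrame({"user_id": [0, 1], "item_id": [1, 2]})

    pl_df = convert2polars(df)       # pandas -> polars
    pd_df = convert2pandas(pl_df)    # polars -> pandas
    spark_df = convert2spark(pd_df)  # pandas -> Spark; requires pyspark

    try:
        convert2pandas([1, 2, 3])
    except ValueError as err:        # raised by the check_if_dataframe decorator
        print(err)

Converting a Spark frame back to pandas or polars without allow_collect_to_master=True emits the SparkCollectToMasterWarning shown in spark_utils below.
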
replay/utils/spark_utils.py CHANGED
@@ -33,7 +33,9 @@ class SparkCollectToMasterWarning(Warning):  # pragma: no cover
33
33
  """
34
34
 
35
35
 
36
- def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False) -> pd.DataFrame: # pragma: no cover
36
+ def spark_to_pandas(
37
+ data: SparkDataFrame, allow_collect_to_master: bool = False, from_constructor: bool = True
38
+ ) -> pd.DataFrame: # pragma: no cover
37
39
  """
38
40
  Convert Spark DataFrame to Pandas DataFrame.
39
41
 
@@ -42,10 +44,15 @@ def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False)
42
44
 
43
45
  :returns: Converted Pandas DataFrame.
44
46
  """
47
+ warn_msg = "Spark Data Frame is collected to master node, this may lead to OOM exception for larger dataset. "
48
+ if from_constructor:
49
+ _msg = "To remove this warning set allow_collect_to_master=True in the recommender constructor."
50
+ else:
51
+ _msg = "To remove this warning set allow_collect_to_master=True."
52
+ warn_msg += _msg
45
53
  if not allow_collect_to_master:
46
54
  warnings.warn(
47
- "Spark Data Frame is collected to master node, this may lead to OOM exception for larger dataset. "
48
- "To remove this warning set allow_collect_to_master=True in the recommender constructor.",
55
+ warn_msg,
49
56
  SparkCollectToMasterWarning,
50
57
  )
51
58
  return data.toPandas()
@@ -169,7 +176,7 @@ if PYSPARK_AVAILABLE:
169
176
  <BLANKLINE>
170
177
  >>> output_data = input_data.select(vector_dot("one", "two").alias("dot"))
171
178
  >>> output_data.schema
172
- StructType(List(StructField(dot,DoubleType,true)))
179
+ StructType([StructField('dot', DoubleType(), True)])
173
180
  >>> output_data.show()
174
181
  +----+
175
182
  | dot|
@@ -207,7 +214,7 @@ if PYSPARK_AVAILABLE:
207
214
  <BLANKLINE>
208
215
  >>> output_data = input_data.select(vector_mult("one", "two").alias("mult"))
209
216
  >>> output_data.schema
210
- StructType(List(StructField(mult,VectorUDT,true)))
217
+ StructType([StructField('mult', VectorUDT(), True)])
211
218
  >>> output_data.show()
212
219
  +---------+
213
220
  | mult|
@@ -244,7 +251,7 @@ if PYSPARK_AVAILABLE:
244
251
  <BLANKLINE>
245
252
  >>> output_data = input_data.select(array_mult("one", "two").alias("mult"))
246
253
  >>> output_data.schema
247
- StructType(List(StructField(mult,ArrayType(DoubleType,true),true)))
254
+ StructType([StructField('mult', ArrayType(DoubleType(), True), True)])
248
255
  >>> output_data.show()
249
256
  +----------+
250
257
  | mult|
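
The from_constructor flag added to spark_to_pandas only changes which hint is appended to the collect-to-master warning, so callers outside a recommender constructor (such as the converters above) get a message that does not mention the constructor. A sketch of observing the warning, assuming pyspark is installed and spark_df is an existing pyspark.sql.DataFrame:

    import warnings

    from replay.utils.spark_utils import spark_to_pandas

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        pdf = spark_to_pandas(spark_df, allow_collect_to_master=False, from_constructor=False)

    # The hint is now simply "set allow_collect_to_master=True" instead of pointing
    # at the recommender constructor.
    print(caught[0].message)
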
replay_rec-0.17.0rc0.dist-info/METADATA → replay_rec-0.17.1rc0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: replay-rec
3
- Version: 0.17.0rc0
3
+ Version: 0.17.1rc0
4
4
  Summary: RecSys Library
5
5
  Home-page: https://sb-ai-lab.github.io/RePlay/
6
6
  License: Apache-2.0
@@ -32,11 +32,11 @@ Requires-Dist: nmslib (==2.1.1)
32
32
  Requires-Dist: numba (>=0.50)
33
33
  Requires-Dist: numpy (>=1.20.0)
34
34
  Requires-Dist: optuna (>=3.2.0,<3.3.0)
35
- Requires-Dist: pandas (>=1.3.5,<2.0.0)
35
+ Requires-Dist: pandas (>=1.3.5,<=2.2.2)
36
36
  Requires-Dist: polars (>=0.20.7,<0.21.0)
37
37
  Requires-Dist: psutil (>=5.9.5,<5.10.0)
38
38
  Requires-Dist: pyarrow (>=12.0.1)
39
- Requires-Dist: pyspark (>=3.0,<3.3) ; extra == "spark" or extra == "all"
39
+ Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
40
40
  Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
41
41
  Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
42
42
  Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
replay_rec-0.17.0rc0.dist-info/RECORD → replay_rec-0.17.1rc0.dist-info/RECORD CHANGED
@@ -1,12 +1,12 @@
1
- replay/__init__.py,sha256=y6Ms_dBdP_0tx6CPUF9QV0jrhb-ogRReafA6edgal_E,54
1
+ replay/__init__.py,sha256=_PQ2zFERSGjgeThzFv3t6MPODgutry1eR82biGhB98o,54
2
2
  replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
3
- replay/data/dataset.py,sha256=ysMTNfx8I2hI9fSugtt3IPhenmutgzQMw-8VcM3oUJk,21299
3
+ replay/data/dataset.py,sha256=cSStvCqIc6WAJNtbmsxncSpcQZ1KfULMsrmf_V0UdPw,29490
4
4
  replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
5
5
  replay/data/dataset_utils/dataset_label_encoder.py,sha256=TEx2zLw5rJdIz1SRBEznyVv5x_Cs7o6QQbzMk-M1LU0,9598
6
6
  replay/data/nn/__init__.py,sha256=WxLsi4rgOuuvGYHN49xBPxP2Srhqf3NYgfBDVH-ZvBo,1122
7
- replay/data/nn/schema.py,sha256=BYU65vLqPDl69OE-rReh59fiQK0ERfs1xbBLWCiIJnw,14258
8
- replay/data/nn/sequence_tokenizer.py,sha256=dXD8l7IfK1dod8p--I6BhvE9af3iUOfpaoW2QBU9hTs,34133
9
- replay/data/nn/sequential_dataset.py,sha256=fqlyBAzDmpH332S-LoMP9PoRYMtgZpxG6Qdahmk5GtE,7840
7
+ replay/data/nn/schema.py,sha256=pO4N7RgmgrqfD1-2d95OTeihKHTZ-5y2BG7CX_wBFi4,16198
8
+ replay/data/nn/sequence_tokenizer.py,sha256=Ambrp3CMOp3JP68PiwmVh0m-_zNXiWzxxVreHkEwOyY,32592
9
+ replay/data/nn/sequential_dataset.py,sha256=jCWxC0Pm1eQ5p8Y6_Bmg4fSEvPaecLrqz1iaWzaICdI,11014
10
10
  replay/data/nn/torch_sequential_dataset.py,sha256=BqrK_PtkhpsaY1zRIWGk4EgwPL31a7IWCc0hLDuwDQc,10984
11
11
  replay/data/nn/utils.py,sha256=YKE9gkIHZDDiwv4THqOWL4PzsdOujnPuM97v79Mwq0E,2769
12
12
  replay/data/schema.py,sha256=F_cv6sYb6l23yuX5xWnbqoJ9oSeUT2NpIM19u8Lf2jA,15606
@@ -148,14 +148,14 @@ replay/optimization/__init__.py,sha256=az6U10rF7X6rPRUUPwLyiM1WFNJ_6kl0imA5xLVWF
148
148
  replay/optimization/optuna_objective.py,sha256=Z-8X0_FT3BicVWj0UhxoLrvZAck3Dhn7jHDGo0i0hxA,7653
149
149
  replay/preprocessing/__init__.py,sha256=TtBysFqYeDy4kZAEnWEaNSwPvbffYdfMkEs71YG51fM,411
150
150
  replay/preprocessing/converter.py,sha256=DczqsVLrwFi6EFhK2HR8rGiIxGCwXeY7QNgWorjA41g,4390
151
- replay/preprocessing/filters.py,sha256=6MaO4IIyKNFP2AR94YA5iQUhQvuCRhAFfj0opI6o4-Q,33744
151
+ replay/preprocessing/filters.py,sha256=wsXWQoZ-2aAecunLkaTxeLWi5ow4e3FAGcElx0iNx0w,41669
152
152
  replay/preprocessing/history_based_fp.py,sha256=tfgKJPKm53LSNqM6VmMXYsVrRDc-rP1Tbzn8s3mbziQ,18751
153
153
  replay/preprocessing/label_encoder.py,sha256=MLBavPD-dB644as0E9ZJSE9-8QxGCB_IHek1w3xtqDI,27040
154
154
  replay/preprocessing/sessionizer.py,sha256=G6i0K3FwqtweRxvcSYraJ-tBWAT2HnV-bWHHlIZJF-s,12217
155
155
  replay/scenarios/__init__.py,sha256=kw2wRkPPinw0IBA20D83XQ3xeSudk3KuYAAA1Wdr8xY,93
156
156
  replay/scenarios/fallback.py,sha256=EeBmIR-5igzKR2m55bQRFyhxTkpJez6ZkCW449n8hWs,7130
157
157
  replay/splitters/__init__.py,sha256=DnqVMelrzLwR8fGQgcWN_8FipGs8T4XGSPOMW-L_x2g,454
158
- replay/splitters/base_splitter.py,sha256=qWW8Sueu0BrYt0WIxMbzooAC4-jhEmyd6pMND_H_qB0,7751
158
+ replay/splitters/base_splitter.py,sha256=hj9_GYDWllzv3XnxN6WHu1JKRRVjXo77vZEOLbF9v-s,7761
159
159
  replay/splitters/cold_user_random_splitter.py,sha256=gVwBVdn_0IOaLGT_UzJoS9AMaPhelZy-FpC5JQS1PhA,4136
160
160
  replay/splitters/k_folds.py,sha256=WH02_DP18A2ae893ysonmfLPB56_i1ETllTAwaCYekg,6218
161
161
  replay/splitters/last_n_splitter.py,sha256=r9kdq2JPi508C9ywjwc68an-iq27KsigMfHWLz0YohE,15346
@@ -165,16 +165,16 @@ replay/splitters/ratio_splitter.py,sha256=8zvuCn16Icc4ntQPKXJ5ArAWuJzCZ9NHZtgWct
165
165
  replay/splitters/time_splitter.py,sha256=iXhuafjBx7dWyJSy-TEVy1IUQBwMpA1gAiF4-GtRe2g,9031
166
166
  replay/splitters/two_stage_splitter.py,sha256=PWozxjjgjrVzdz6Sm9dcDTeH0bOA24reFzkk_N_TgbQ,17734
167
167
  replay/utils/__init__.py,sha256=vDJgOWq81fbBs-QO4ZDpdqR4KDyO1kMOOxBRi-5Gp7E,253
168
- replay/utils/common.py,sha256=6JxR5bFuTFTFWad36J5Zu8dFgpFXoof6VsVpF2sD7h8,1471
168
+ replay/utils/common.py,sha256=s4Pro3QCkPeVBsj-s0vrbhd_pkJD-_-2M_sIguxGzQQ,5411
169
169
  replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
170
170
  replay/utils/distributions.py,sha256=kGGq2KzQZ-yhTuw_vtOsKFXVpXUOQ2l4aIFBcaDufZ8,1202
171
171
  replay/utils/model_handler.py,sha256=V-mHDh8_UexjVSsMBBRA9yrjS_5MPHwYOwv_UrI-Zfs,6466
172
172
  replay/utils/session_handler.py,sha256=ijTvDSNAe1D9R1e-dhtd-r80tFNiIBsFdWZLgw-gLEo,5153
173
- replay/utils/spark_utils.py,sha256=PhNi9fW28ek0ZB90AUg3tsT5BULbQjDhLalxxww9eLE,26700
173
+ replay/utils/spark_utils.py,sha256=k5lUFM2C9QZKQON3dqhgfswyUF4tsgJOn0U2wCKimqM,26901
174
174
  replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
175
175
  replay/utils/types.py,sha256=5sw0A7NG4ZgQKdWORnBy0wBZ5F98sP_Ju8SKQ6zbDS4,651
176
- replay_rec-0.17.0rc0.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
177
- replay_rec-0.17.0rc0.dist-info/METADATA,sha256=8Ki81O8-t1bWieQu4WJFFNWMu4CrvhwBSaU0mcfhh4o,10889
178
- replay_rec-0.17.0rc0.dist-info/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
179
- replay_rec-0.17.0rc0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
180
- replay_rec-0.17.0rc0.dist-info/RECORD,,
176
+ replay_rec-0.17.1rc0.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
177
+ replay_rec-0.17.1rc0.dist-info/METADATA,sha256=FgZduBS6AVq1qSNahVyNFCJILLPdVLVosbxjUxN7WkQ,10890
178
+ replay_rec-0.17.1rc0.dist-info/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
179
+ replay_rec-0.17.1rc0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
180
+ replay_rec-0.17.1rc0.dist-info/RECORD,,