replay-rec 0.17.0rc0__py3-none-any.whl → 0.17.1rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +246 -20
- replay/data/nn/schema.py +42 -0
- replay/data/nn/sequence_tokenizer.py +17 -47
- replay/data/nn/sequential_dataset.py +76 -2
- replay/preprocessing/filters.py +169 -4
- replay/splitters/base_splitter.py +1 -1
- replay/utils/common.py +107 -5
- replay/utils/spark_utils.py +13 -6
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1rc0.dist-info}/METADATA +3 -3
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1rc0.dist-info}/RECORD +14 -14
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1rc0.dist-info}/LICENSE +0 -0
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1rc0.dist-info}/NOTICE +0 -0
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1rc0.dist-info}/WHEEL +0 -0
replay/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
""" RecSys library """
|
|
2
|
-
__version__ = "0.17.
|
|
2
|
+
__version__ = "0.17.1.preview"
|
replay/data/dataset.py
CHANGED
|
@@ -3,11 +3,22 @@
|
|
|
3
3
|
"""
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
import json
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
|
|
7
9
|
|
|
8
10
|
import numpy as np
|
|
9
|
-
|
|
10
|
-
from
|
|
11
|
+
from pandas import read_parquet as pd_read_parquet
|
|
12
|
+
from polars import read_parquet as pl_read_parquet
|
|
13
|
+
|
|
14
|
+
from replay.utils import (
|
|
15
|
+
PYSPARK_AVAILABLE,
|
|
16
|
+
DataFrameLike,
|
|
17
|
+
PandasDataFrame,
|
|
18
|
+
PolarsDataFrame,
|
|
19
|
+
SparkDataFrame,
|
|
20
|
+
)
|
|
21
|
+
from replay.utils.session_handler import get_spark_session
|
|
11
22
|
|
|
12
23
|
from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
|
|
13
24
|
|
|
@@ -47,9 +58,7 @@ class Dataset:
|
|
|
47
58
|
self._query_features = query_features
|
|
48
59
|
self._item_features = item_features
|
|
49
60
|
|
|
50
|
-
self.
|
|
51
|
-
self.is_spark = isinstance(interactions, SparkDataFrame)
|
|
52
|
-
self.is_polars = isinstance(interactions, PolarsDataFrame)
|
|
61
|
+
self._assign_df_type()
|
|
53
62
|
|
|
54
63
|
self._categorical_encoded = categorical_encoded
|
|
55
64
|
|
|
@@ -74,16 +83,8 @@ class Dataset:
|
|
|
74
83
|
msg = "Interactions and query features should have the same type."
|
|
75
84
|
raise TypeError(msg)
|
|
76
85
|
|
|
77
|
-
self.
|
|
78
|
-
|
|
79
|
-
FeatureSource.QUERY_FEATURES: self.query_features,
|
|
80
|
-
FeatureSource.ITEM_FEATURES: self.item_features,
|
|
81
|
-
}
|
|
82
|
-
|
|
83
|
-
self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
|
|
84
|
-
FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
|
|
85
|
-
FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
|
|
86
|
-
}
|
|
86
|
+
self._get_feature_source_map()
|
|
87
|
+
self._get_ids_source_map()
|
|
87
88
|
|
|
88
89
|
self._feature_schema = self._fill_feature_schema(feature_schema)
|
|
89
90
|
|
|
@@ -92,7 +93,6 @@ class Dataset:
|
|
|
92
93
|
self._check_ids_consistency(hint=FeatureHint.QUERY_ID)
|
|
93
94
|
if self.item_features is not None:
|
|
94
95
|
self._check_ids_consistency(hint=FeatureHint.ITEM_ID)
|
|
95
|
-
|
|
96
96
|
if self._categorical_encoded:
|
|
97
97
|
self._check_encoded()
|
|
98
98
|
|
|
@@ -189,6 +189,157 @@ class Dataset:
|
|
|
189
189
|
"""
|
|
190
190
|
return self._feature_schema
|
|
191
191
|
|
|
192
|
+
def _get_df_type(self) -> str:
    """
    Report which dataframe backend the stored data uses.

    The ``is_spark``, ``is_pandas`` and ``is_polars`` flags are inspected in
    that order and the first one that is set wins.

    :returns: One of ``"spark"``, ``"pandas"`` or ``"polars"``.
    :raises ValueError: If none of the flags is set.
    """
    for backend in ("spark", "pandas", "polars"):
        if getattr(self, f"is_{backend}"):
            return backend
    msg = "No known dataframe types are provided"
    raise ValueError(msg)
|
|
204
|
+
|
|
205
|
+
def _to_parquet(self, df: DataFrameLike, path: Path) -> None:
    """
    Save the content of the dataframe in parquet format to the provided path.

    The writer is chosen from the Dataset's own ``is_*`` flags, so ``df`` is
    expected to be of the same type as the stored interactions. For Spark an
    ``idx`` column built with ``monotonically_increasing_id`` is added before
    writing (so ``_read_parquet`` can restore the row order) and existing
    output at ``path`` is overwritten.

    :param df: Dataframe to save.
    :param path: Path to save the dataframe to.
    :raises TypeError: If the Dataset holds no pandas, polars or Spark dataframe.
    """
    if self.is_spark:
        # Spark writes a directory of part files; keep a row id so the
        # original order can be recovered on load.
        ordered_df = df.withColumn("idx", sf.monotonically_increasing_id())
        ordered_df.write.mode("overwrite").parquet(str(path))
    elif self.is_pandas:
        df.to_parquet(path)
    elif self.is_polars:
        df.write_parquet(path)
    else:
        # Single-line message: the former indented triple-quoted string
        # leaked newlines and indentation into the exception text.
        msg = (
            "_to_parquet() can only be used to save polars|pandas|spark dataframes; "
            "no known dataframe types are provided"
        )
        raise TypeError(msg)
|
|
226
|
+
|
|
227
|
+
@staticmethod
def _read_parquet(path: Path, mode: str) -> Union[SparkDataFrame, PandasDataFrame, PolarsDataFrame]:
    """
    Read the parquet file as dataframe.

    If the data contains the ``idx`` helper column that ``_to_parquet`` adds
    for Spark dataframes, rows are sorted by it and the column is dropped, so
    the saved row order is restored whichever backend reads the data.

    :param path: The parquet file path.
    :param mode: Dataframe type. Can be spark|pandas|polars.
    :returns: The dataframe read from the file.
    :raises TypeError: If ``mode`` is not a supported dataframe type.
    """
    if mode == "spark":
        spark_session = get_spark_session()
        df = spark_session.read.parquet(str(path))
        if "idx" in df.columns:
            df = df.orderBy("idx").drop("idx")
        return df
    if mode == "pandas":
        df = pd_read_parquet(path)
        if "idx" in df.columns:
            # Restore the saved order before discarding the helper column,
            # mirroring the spark and polars branches.
            df = df.sort_values("idx").drop(columns="idx").reset_index(drop=True)
        return df
    if mode == "polars":
        df = pl_read_parquet(path, use_pyarrow=True)
        if "idx" in df.columns:
            df = df.sort("idx").drop("idx")
        return df
    msg = f"_read_parquet() can only be used to read polars|pandas|spark dataframes, not {mode}"
    raise TypeError(msg)
|
|
255
|
+
|
|
256
|
+
def save(self, path: str) -> None:
    """
    Persist the Dataset to a ``.replay`` directory.

    ``init_args.json`` receives the class name and the constructor arguments:
    the feature schema as plain dicts, and each dataframe recorded only by its
    backend name (``None`` for absent feature frames). Every present dataframe
    is then written as ``<name>.parquet`` via ``_to_parquet``.

    :param path: Path to save the Dataset to; its suffix is replaced by ``.replay``.
    """
    backend = self._get_df_type()

    def _backend_if_present(frame):
        # Optional frames are recorded by backend name only when they exist.
        return None if frame is None else backend

    init_args = {
        "feature_schema": [
            {
                "column": info.column,
                "feature_type": info.feature_type.name,
                "feature_hint": info.feature_hint.name if info.feature_hint else None,
            }
            for info in self.feature_schema.all_features
        ],
        "interactions": backend,
        "item_features": _backend_if_present(self.item_features),
        "query_features": _backend_if_present(self.query_features),
        "check_consistency": False,
        "categorical_encoded": self._categorical_encoded,
    }

    target_dir = Path(path).with_suffix(".replay").resolve()
    target_dir.mkdir(parents=True, exist_ok=True)

    with open(target_dir / "init_args.json", "w+") as file:
        json.dump({"_class_name": self.__class__.__name__, "init_args": init_args}, file)

    frames = (
        ("interactions", self.interactions),
        ("item_features", self.item_features),
        ("query_features", self.query_features),
    )
    for name, frame in frames:
        if frame is None:
            continue
        self._to_parquet(frame, target_dir / f"{name}.parquet")
|
|
300
|
+
|
|
301
|
+
@classmethod
def load(
    cls,
    path: str,
    dataframe_type: Optional[str] = None,
) -> Dataset:
    """
    Load the Dataset from the provided path.

    :param path: The path the Dataset was saved to; its suffix is replaced by ``.replay``.
    :param dataframe_type: Dataframe type to use to store internal data.
        Can be spark|pandas|polars|None.
        If not provided, the type used when the Dataset was saved is kept.
    :returns: Loaded Dataset.
    :raises ValueError: If ``dataframe_type`` is not one of the supported values.
    """
    # Validate the argument before any filesystem access so an invalid value
    # is reported as such instead of being masked by an I/O error.
    if dataframe_type not in ["pandas", "spark", "polars", None]:
        msg = f"Argument dataframe_type can be spark|pandas|polars|None, not {dataframe_type}"
        raise ValueError(msg)

    base_path = Path(path).with_suffix(".replay").resolve()
    with open(base_path / "init_args.json", "r") as file:
        dataset_dict = json.loads(file.read())
    init_args = dataset_dict["init_args"]

    # Rebuild FeatureInfo objects from the enum names stored by ``save``.
    features_list = []
    for feature_data in init_args["feature_schema"]:
        f_type = feature_data["feature_type"]
        f_hint = feature_data["feature_hint"]
        feature_data["feature_type"] = FeatureType[f_type] if f_type else None
        feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
        features_list.append(FeatureInfo(**feature_data))
    init_args["feature_schema"] = FeatureSchema(features_list)

    for df_name in ["interactions", "query_features", "item_features"]:
        saved_type = init_args[df_name]
        # A falsy entry means the frame was absent when saving; keep it absent.
        if saved_type:
            load_path = base_path / f"{df_name}.parquet"
            init_args[df_name] = cls._read_parquet(load_path, dataframe_type or saved_type)
    return cls(**init_args)
|
|
342
|
+
|
|
192
343
|
if PYSPARK_AVAILABLE:
|
|
193
344
|
|
|
194
345
|
def persist(self, storage_level: StorageLevel = StorageLevel(True, True, False, True, 1)) -> None:
|
|
@@ -283,6 +434,24 @@ class Dataset:
|
|
|
283
434
|
categorical_encoded=self._categorical_encoded,
|
|
284
435
|
)
|
|
285
436
|
|
|
437
|
+
def _get_feature_source_map(self):
    """
    (Re)build ``self._feature_source_map``.

    The map links each ``FeatureSource`` to the dataframe currently stored for
    it; query/item entries are ``None`` when those features are absent.
    """
    sources = (
        FeatureSource.INTERACTIONS,
        FeatureSource.QUERY_FEATURES,
        FeatureSource.ITEM_FEATURES,
    )
    frames = (self.interactions, self.query_features, self.item_features)
    self._feature_source_map: Dict[FeatureSource, DataFrameLike] = dict(zip(sources, frames))
|
|
443
|
+
|
|
444
|
+
def _get_ids_source_map(self):
    """
    (Re)build ``self._ids_feature_map``.

    For each ID hint the map holds the dataframe to read those IDs from: the
    matching features dataframe when one was provided, otherwise interactions.
    """

    def _or_interactions(frame):
        return self.interactions if frame is None else frame

    self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
        FeatureHint.QUERY_ID: _or_interactions(self.query_features),
        FeatureHint.ITEM_ID: _or_interactions(self.item_features),
    }
|
|
449
|
+
|
|
450
|
+
def _assign_df_type(self):
    """
    Refresh the ``is_pandas`` / ``is_spark`` / ``is_polars`` flags from the
    type of the stored interactions dataframe.
    """
    interactions = self.interactions
    for flag, frame_cls in (
        ("is_pandas", PandasDataFrame),
        ("is_spark", SparkDataFrame),
        ("is_polars", PolarsDataFrame),
    ):
        setattr(self, flag, isinstance(interactions, frame_cls))
|
|
454
|
+
|
|
286
455
|
def _get_cardinality(self, feature: FeatureInfo) -> Callable:
|
|
287
456
|
def callback(column: str) -> int:
|
|
288
457
|
if feature.feature_hint in [FeatureHint.ITEM_ID, FeatureHint.QUERY_ID]:
|
|
@@ -381,7 +550,11 @@ class Dataset:
|
|
|
381
550
|
is_consistent = (
|
|
382
551
|
self.interactions.select(ids_column)
|
|
383
552
|
.distinct()
|
|
384
|
-
.join(
|
|
553
|
+
.join(
|
|
554
|
+
features_df.select(ids_column).distinct(),
|
|
555
|
+
on=[ids_column],
|
|
556
|
+
how="leftanti",
|
|
557
|
+
)
|
|
385
558
|
.count()
|
|
386
559
|
) == 0
|
|
387
560
|
else:
|
|
@@ -389,7 +562,11 @@ class Dataset:
|
|
|
389
562
|
len(
|
|
390
563
|
self.interactions.select(ids_column)
|
|
391
564
|
.unique()
|
|
392
|
-
.join(
|
|
565
|
+
.join(
|
|
566
|
+
features_df.select(ids_column).unique(),
|
|
567
|
+
on=ids_column,
|
|
568
|
+
how="anti",
|
|
569
|
+
)
|
|
393
570
|
)
|
|
394
571
|
== 0
|
|
395
572
|
)
|
|
@@ -399,7 +576,11 @@ class Dataset:
|
|
|
399
576
|
raise ValueError(msg)
|
|
400
577
|
|
|
401
578
|
def _check_column_encoded(
|
|
402
|
-
self,
|
|
579
|
+
self,
|
|
580
|
+
data: DataFrameLike,
|
|
581
|
+
column: str,
|
|
582
|
+
source: FeatureSource,
|
|
583
|
+
cardinality: Optional[int],
|
|
403
584
|
) -> None:
|
|
404
585
|
"""
|
|
405
586
|
Checks that IDs are encoded:
|
|
@@ -482,6 +663,51 @@ class Dataset:
|
|
|
482
663
|
feature.cardinality,
|
|
483
664
|
)
|
|
484
665
|
|
|
666
|
+
def to_pandas(self) -> None:
    """
    Convert internally stored dataframes to pandas.DataFrame.

    All present dataframes are converted before any of them is stored, so a
    failing conversion leaves the Dataset untouched rather than half-converted.
    Afterwards the feature/ID lookup maps and the ``is_*`` flags are rebuilt.
    """
    from replay.utils.common import convert2pandas

    interactions = convert2pandas(self._interactions)
    query_features = None if self._query_features is None else convert2pandas(self._query_features)
    item_features = None if self._item_features is None else convert2pandas(self._item_features)

    self._interactions = interactions
    self._query_features = query_features
    self._item_features = item_features
    self._get_feature_source_map()
    self._get_ids_source_map()
    self._assign_df_type()
|
|
680
|
+
|
|
681
|
+
def to_spark(self) -> None:
    """
    Convert internally stored dataframes to pyspark.sql.DataFrame.

    All present dataframes are converted before any of them is stored, so a
    failing conversion leaves the Dataset untouched rather than half-converted.
    Afterwards the feature/ID lookup maps and the ``is_*`` flags are rebuilt.
    """
    from replay.utils.common import convert2spark

    interactions = convert2spark(self._interactions)
    query_features = None if self._query_features is None else convert2spark(self._query_features)
    item_features = None if self._item_features is None else convert2spark(self._item_features)

    self._interactions = interactions
    self._query_features = query_features
    self._item_features = item_features
    self._get_feature_source_map()
    self._get_ids_source_map()
    self._assign_df_type()
|
|
695
|
+
|
|
696
|
+
def to_polars(self) -> None:
    """
    Convert internally stored dataframes to polars.DataFrame.

    All present dataframes are converted before any of them is stored, so a
    failing conversion leaves the Dataset untouched rather than half-converted.
    Afterwards the feature/ID lookup maps and the ``is_*`` flags are rebuilt.
    """
    from replay.utils.common import convert2polars

    interactions = convert2polars(self._interactions)
    query_features = None if self._query_features is None else convert2polars(self._query_features)
    item_features = None if self._item_features is None else convert2polars(self._item_features)

    self._interactions = interactions
    self._query_features = query_features
    self._item_features = item_features
    self._get_feature_source_map()
    self._get_ids_source_map()
    self._assign_df_type()
|
|
710
|
+
|
|
485
711
|
|
|
486
712
|
def nunique(data: DataFrameLike, column: str) -> int:
|
|
487
713
|
"""
|
replay/data/nn/schema.py
CHANGED
|
@@ -408,6 +408,48 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
|
|
|
408
408
|
return None
|
|
409
409
|
return rating_features.item().name
|
|
410
410
|
|
|
411
|
+
def _get_object_args(self) -> list:
    """
    Serialize the schema into JSON-friendly data.

    Each feature becomes a dict with its name, the enum names of its type and
    hint, its ``is_seq`` flag and its sources (``None`` when it has none).
    ``cardinality`` and ``embedding_dim`` are kept only for categorical
    features and ``tensor_dim`` only for numerical ones; otherwise they are
    ``None``. ``TensorSchema._create_object_by_args`` performs the reverse
    conversion.

    :returns: List of features represented as dictionaries.
    """
    features = []
    for feature in self.all_features:
        is_categorical = feature.feature_type == FeatureType.CATEGORICAL
        is_numerical = feature.feature_type == FeatureType.NUMERICAL
        sources = (
            [{"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources]
            if feature.feature_sources
            else None
        )
        features.append(
            {
                "name": feature.name,
                "feature_type": feature.feature_type.name,
                "is_seq": feature.is_seq,
                "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
                "feature_sources": sources,
                "cardinality": feature.cardinality if is_categorical else None,
                "embedding_dim": feature.embedding_dim if is_categorical else None,
                "tensor_dim": feature.tensor_dim if is_numerical else None,
            }
        )
    return features
|
|
433
|
+
|
|
434
|
+
@classmethod
def _create_object_by_args(cls, args: list) -> "TensorSchema":
    """
    Build a schema from the data produced by ``_get_object_args``.

    Each feature dict is shallow-copied before its enum names and sources are
    turned back into objects, so the caller's ``args`` is left unchanged and
    can be deserialized more than once.

    :param args: List of feature dictionaries (e.g. loaded from ``init_args.json``).
    :returns: The reconstructed schema.
    """
    features_list = []
    for raw_feature in args:
        feature_data = dict(raw_feature)
        raw_sources = feature_data["feature_sources"]
        feature_data["feature_sources"] = (
            [
                TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
                for x in raw_sources
            ]
            if raw_sources
            else None
        )
        f_type = feature_data["feature_type"]
        f_hint = feature_data["feature_hint"]
        feature_data["feature_type"] = FeatureType[f_type] if f_type else None
        feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
        features_list.append(TensorFeatureInfo(**feature_data))
    return cls(features_list)
|
|
452
|
+
|
|
411
453
|
def filter(
|
|
412
454
|
self,
|
|
413
455
|
name: Optional[str] = None,
|
|
@@ -24,7 +24,10 @@ SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
|
|
|
24
24
|
|
|
25
25
|
class SequenceTokenizer:
|
|
26
26
|
"""
|
|
27
|
-
Data tokenizer for transformers
|
|
27
|
+
Data tokenizer for transformers;
|
|
28
|
+
Encodes all categorical features (the ones marked as FeatureType.CATEGORICAL in
|
|
29
|
+
the FeatureSchema) and stores all data as items sequences (sorted by time if a
|
|
30
|
+
feature of type FeatureHint.TIMESTAMP is provided, unsorted otherwise).
|
|
28
31
|
"""
|
|
29
32
|
|
|
30
33
|
def __init__(
|
|
@@ -278,17 +281,17 @@ class SequenceTokenizer:
|
|
|
278
281
|
]
|
|
279
282
|
|
|
280
283
|
for tensor_feature in tensor_schema.values():
|
|
281
|
-
source
|
|
282
|
-
|
|
284
|
+
for source in tensor_feature.feature_sources:
|
|
285
|
+
assert source is not None
|
|
283
286
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
+
# Some columns already added to encoder, skip them
|
|
288
|
+
if source.column in features_subset:
|
|
289
|
+
continue
|
|
287
290
|
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
291
|
+
if isinstance(source.source, FeatureSource):
|
|
292
|
+
features_subset.append(source.column)
|
|
293
|
+
else:
|
|
294
|
+
assert False, "Unknown tensor feature source"
|
|
292
295
|
|
|
293
296
|
return set(features_subset)
|
|
294
297
|
|
|
@@ -404,7 +407,7 @@ class SequenceTokenizer:
|
|
|
404
407
|
|
|
405
408
|
@classmethod
|
|
406
409
|
@deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
|
|
407
|
-
def load(cls, path: str, use_pickle: bool = False) -> "SequenceTokenizer":
|
|
410
|
+
def load(cls, path: str, use_pickle: bool = False, **kwargs) -> "SequenceTokenizer":
|
|
408
411
|
"""
|
|
409
412
|
Load tokenizer object from the given path.
|
|
410
413
|
|
|
@@ -422,18 +425,7 @@ class SequenceTokenizer:
|
|
|
422
425
|
|
|
423
426
|
# load tensor_schema, tensor_features
|
|
424
427
|
tensor_schema_data = tokenizer_dict["init_args"]["tensor_schema"]
|
|
425
|
-
|
|
426
|
-
for feature_data in tensor_schema_data:
|
|
427
|
-
feature_data["feature_sources"] = [
|
|
428
|
-
TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
|
|
429
|
-
for x in feature_data["feature_sources"]
|
|
430
|
-
]
|
|
431
|
-
f_type = feature_data["feature_type"]
|
|
432
|
-
f_hint = feature_data["feature_hint"]
|
|
433
|
-
feature_data["feature_type"] = FeatureType[f_type] if f_type else None
|
|
434
|
-
feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
|
|
435
|
-
features_list.append(TensorFeatureInfo(**feature_data))
|
|
436
|
-
tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema(features_list)
|
|
428
|
+
tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema._create_object_by_args(tensor_schema_data)
|
|
437
429
|
|
|
438
430
|
# Load encoder columns and rules
|
|
439
431
|
types = list(FeatureHint) + list(FeatureSource)
|
|
@@ -447,7 +439,7 @@ class SequenceTokenizer:
|
|
|
447
439
|
rule_data = rules_dict[rule]
|
|
448
440
|
if rule_data["mapping"] and rule_data["is_int"]:
|
|
449
441
|
rule_data["mapping"] = {int(key): value for key, value in rule_data["mapping"].items()}
|
|
450
|
-
|
|
442
|
+
del rule_data["is_int"]
|
|
451
443
|
|
|
452
444
|
tokenizer_dict["encoder"]["encoding_rules"][rule] = LabelEncodingRule(**rule_data)
|
|
453
445
|
|
|
@@ -478,31 +470,9 @@ class SequenceTokenizer:
|
|
|
478
470
|
"allow_collect_to_master": self._allow_collect_to_master,
|
|
479
471
|
"handle_unknown_rule": self._encoder._handle_unknown_rule,
|
|
480
472
|
"default_value_rule": self._encoder._default_value_rule,
|
|
481
|
-
"tensor_schema":
|
|
473
|
+
"tensor_schema": self._tensor_schema._get_object_args(),
|
|
482
474
|
}
|
|
483
475
|
|
|
484
|
-
# save tensor schema
|
|
485
|
-
for feature in list(self._tensor_schema.values()):
|
|
486
|
-
tokenizer_dict["init_args"]["tensor_schema"].append(
|
|
487
|
-
{
|
|
488
|
-
"name": feature.name,
|
|
489
|
-
"feature_type": feature.feature_type.name,
|
|
490
|
-
"is_seq": feature.is_seq,
|
|
491
|
-
"feature_hint": feature.feature_hint.name if feature.feature_hint else None,
|
|
492
|
-
"feature_sources": [
|
|
493
|
-
{"source": x.source.name, "column": x.column, "index": x.index}
|
|
494
|
-
for x in feature.feature_sources
|
|
495
|
-
]
|
|
496
|
-
if feature.feature_sources
|
|
497
|
-
else None,
|
|
498
|
-
"cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
|
|
499
|
-
"embedding_dim": feature.embedding_dim
|
|
500
|
-
if feature.feature_type == FeatureType.CATEGORICAL
|
|
501
|
-
else None,
|
|
502
|
-
"tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
|
|
503
|
-
}
|
|
504
|
-
)
|
|
505
|
-
|
|
506
476
|
# save DatasetLabelEncoder
|
|
507
477
|
tokenizer_dict["encoder"] = {
|
|
508
478
|
"features_columns": {key.name: value for key, value in self._encoder._features_columns.items()},
|
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import abc
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
2
4
|
from typing import Tuple, Union
|
|
3
5
|
|
|
4
6
|
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
5
8
|
import polars as pl
|
|
6
9
|
from pandas import DataFrame as PandasDataFrame
|
|
7
10
|
from polars import DataFrame as PolarsDataFrame
|
|
@@ -100,6 +103,23 @@ class SequentialDataset(abc.ABC):
|
|
|
100
103
|
rhs_filtered = rhs.filter_by_query_id(common_queries)
|
|
101
104
|
return lhs_filtered, rhs_filtered
|
|
102
105
|
|
|
106
|
+
def save(self, path: str) -> None:
    """
    Persist the dataset to a ``.replay`` directory.

    Writes ``sequences.json`` (the sequences with their index reset into a
    column) and ``init_args.json`` holding the class name, the serialized
    tensor schema, the ID column names and the sequences file name.

    :param path: Target path; its suffix is replaced by ``.replay``.
    """
    target = Path(path).with_suffix(".replay").resolve()
    target.mkdir(parents=True, exist_ok=True)

    sequences_file = "sequences.json"
    self._sequences.reset_index().to_json(target / sequences_file)

    payload = {
        "_class_name": self.__class__.__name__,
        "init_args": {
            "tensor_schema": self._tensor_schema._get_object_args(),
            "query_id_column": self._query_id_column,
            "item_id_column": self._item_id_column,
            "sequences_path": sequences_file,
        },
    }
    with open(target / "init_args.json", "w+") as file:
        json.dump(payload, file)
|
|
122
|
+
|
|
103
123
|
|
|
104
124
|
class PandasSequentialDataset(SequentialDataset):
|
|
105
125
|
"""
|
|
@@ -174,6 +194,25 @@ class PandasSequentialDataset(SequentialDataset):
|
|
|
174
194
|
msg = "Tensor schema does not match with provided data frame"
|
|
175
195
|
raise ValueError(msg)
|
|
176
196
|
|
|
197
|
+
@classmethod
def load(cls, path: str, **kwargs) -> "PandasSequentialDataset":
    """
    Method for loading PandasSequentialDataset object from `.replay` directory.

    Reads ``init_args.json`` written by ``SequentialDataset.save``, loads the
    referenced sequences file with ``pd.read_json`` and rebuilds the tensor
    schema via ``TensorSchema._create_object_by_args``.
    NOTE(review): ``pd.read_json`` yields list cells for sequence columns rather
    than numpy arrays — confirm downstream consumers accept lists.

    :param path: Path the dataset was saved to; its suffix is replaced by ``.replay``.
    :param kwargs: Accepted for interface compatibility; unused.
    :returns: The loaded dataset.
    """
    base_path = Path(path).with_suffix(".replay").resolve()
    init_args = json.loads((base_path / "init_args.json").read_text())["init_args"]

    sequences = pd.read_json(base_path / init_args["sequences_path"])
    return cls(
        tensor_schema=TensorSchema._create_object_by_args(init_args["tensor_schema"]),
        query_id_column=init_args["query_id_column"],
        item_id_column=init_args["item_id_column"],
        sequences=sequences,
    )
|
|
215
|
+
|
|
177
216
|
|
|
178
217
|
class PolarsSequentialDataset(PandasSequentialDataset):
|
|
179
218
|
"""
|
|
@@ -199,7 +238,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
199
238
|
self._query_id_column = query_id_column
|
|
200
239
|
self._item_id_column = item_id_column
|
|
201
240
|
|
|
202
|
-
self._sequences =
|
|
241
|
+
self._sequences = self._convert_polars_to_pandas(sequences)
|
|
203
242
|
if self._sequences.index.name != query_id_column:
|
|
204
243
|
self._sequences = self._sequences.set_index(query_id_column)
|
|
205
244
|
|
|
@@ -211,12 +250,47 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
211
250
|
tensor_schema=self._tensor_schema,
|
|
212
251
|
query_id_column=self._query_id_column,
|
|
213
252
|
item_id_column=self._item_id_column,
|
|
214
|
-
sequences=
|
|
253
|
+
sequences=self._convert_pandas_to_polars(filtered_sequences),
|
|
215
254
|
)
|
|
216
255
|
|
|
256
|
+
def _convert_polars_to_pandas(self, df: PolarsDataFrame) -> PandasDataFrame:
    """
    Convert a polars dataframe to pandas, turning list cells into numpy arrays.

    An object-dtype column is treated as a sequence column when its first
    non-null value is a ``list``; in such a column every ``list`` cell becomes
    ``np.array(cell)`` while other cells (e.g. ``None``) are left unchanged.
    Empty dataframes are converted without error.

    :param df: Polars dataframe to convert.
    :returns: The equivalent pandas dataframe.
    """
    pandas_df = PandasDataFrame(df.to_dict(as_series=False))

    for column in pandas_df.select_dtypes(include="object").columns:
        # Sample the first non-null value: ``iloc[0]`` on the raw column raises
        # on empty frames and misses list columns that start with a null.
        non_null = pandas_df[column].dropna()
        if not non_null.empty and isinstance(non_null.iloc[0], list):
            pandas_df[column] = pandas_df[column].apply(lambda x: np.array(x) if isinstance(x, list) else x)

    return pandas_df
|
|
264
|
+
|
|
265
|
+
def _convert_pandas_to_polars(self, df: PandasDataFrame) -> PolarsDataFrame:
    """
    Convert a pandas dataframe to polars, turning numpy-array cells into lists.

    The conversion works on the column data returned by ``df.to_dict("list")``,
    so the caller's dataframe is not modified, and it tolerates empty frames
    and null cells in object columns. Only columns are transferred: the pandas
    index is not part of the result.
    NOTE(review): if ``df`` is indexed by the query id, that column is not
    carried over — confirm callers reset the index when it must be kept.

    :param df: Pandas dataframe to convert.
    :returns: The equivalent polars dataframe.
    """
    object_columns = set(df.select_dtypes(include="object").columns)
    data = df.to_dict("list")
    for column in object_columns:
        data[column] = [value.tolist() if isinstance(value, np.ndarray) else value for value in data[column]]
    return pl.from_dict(data)
|
|
271
|
+
|
|
217
272
|
@classmethod
def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PolarsDataFrame) -> None:
    """
    Ensure every feature name of ``tensor_schema`` is present in ``data``.

    :param tensor_schema: Schema whose feature names are looked up.
    :param data: Dataframe checked with ``in`` for each feature name.
    :raises ValueError: As soon as a feature name is missing from ``data``.
    """
    if any(feature_name not in data for feature_name in tensor_schema):
        msg = "Tensor schema does not match with provided data frame"
        raise ValueError(msg)
|
|
278
|
+
|
|
279
|
+
@classmethod
def load(cls, path: str, **kwargs) -> "PolarsSequentialDataset":
    """
    Method for loading PolarsSequentialDataset object from `.replay` directory.

    Reads ``init_args.json`` written by ``SequentialDataset.save``, loads the
    referenced sequences file with ``pd.read_json`` and wraps it into a polars
    dataframe, then rebuilds the tensor schema via
    ``TensorSchema._create_object_by_args``.

    :param path: Path the dataset was saved to; its suffix is replaced by ``.replay``.
    :param kwargs: Accepted for interface compatibility; unused.
    :returns: The loaded dataset.
    """
    base_path = Path(path).with_suffix(".replay").resolve()
    with open(base_path / "init_args.json", "r") as file:
        sequential_dict = json.loads(file.read())

    sequences = pl.DataFrame(pd.read_json(base_path / sequential_dict["init_args"]["sequences_path"]))
    dataset = cls(
        tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
        query_id_column=sequential_dict["init_args"]["query_id_column"],
        item_id_column=sequential_dict["init_args"]["item_id_column"],
        sequences=sequences,
    )

    return dataset
|
replay/preprocessing/filters.py
CHANGED
|
@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
|
|
|
5
5
|
from datetime import datetime, timedelta
|
|
6
6
|
from typing import Callable, Optional, Tuple, Union
|
|
7
7
|
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
8
10
|
import polars as pl
|
|
9
11
|
|
|
10
12
|
from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
|
|
@@ -357,7 +359,7 @@ class NumInteractionsFilter(_BaseFilter):
|
|
|
357
359
|
... "2020-02-01", "2020-01-01 00:04:15",
|
|
358
360
|
... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
|
|
359
361
|
... )
|
|
360
|
-
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
|
|
362
|
+
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
|
|
361
363
|
>>> log_sp = convert2spark(log_pd)
|
|
362
364
|
>>> log_sp.show()
|
|
363
365
|
+-------+-------+------+-------------------+
|
|
@@ -499,7 +501,7 @@ class EntityDaysFilter(_BaseFilter):
|
|
|
499
501
|
... "2020-02-01", "2020-01-01 00:04:15",
|
|
500
502
|
... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
|
|
501
503
|
... )
|
|
502
|
-
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
|
|
504
|
+
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
|
|
503
505
|
>>> log_sp = convert2spark(log_pd)
|
|
504
506
|
>>> log_sp.orderBy('user_id', 'item_id').show()
|
|
505
507
|
+-------+-------+------+-------------------+
|
|
@@ -638,7 +640,7 @@ class GlobalDaysFilter(_BaseFilter):
|
|
|
638
640
|
... "2020-02-01", "2020-01-01 00:04:15",
|
|
639
641
|
... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
|
|
640
642
|
... )
|
|
641
|
-
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
|
|
643
|
+
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
|
|
642
644
|
>>> log_sp = convert2spark(log_pd)
|
|
643
645
|
>>> log_sp.show()
|
|
644
646
|
+-------+-------+------+-------------------+
|
|
@@ -740,7 +742,7 @@ class TimePeriodFilter(_BaseFilter):
|
|
|
740
742
|
... "2020-02-01", "2020-01-01 00:04:15",
|
|
741
743
|
... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
|
|
742
744
|
... )
|
|
743
|
-
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
|
|
745
|
+
>>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
|
|
744
746
|
>>> log_sp = convert2spark(log_pd)
|
|
745
747
|
>>> log_sp.show()
|
|
746
748
|
+-------+-------+------+-------------------+
|
|
@@ -823,3 +825,166 @@ class TimePeriodFilter(_BaseFilter):
|
|
|
823
825
|
return interactions.filter(
|
|
824
826
|
pl.col(self.timestamp_column).is_between(self.start_date, self.end_date, closed="left")
|
|
825
827
|
)
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
class QuantileItemsFilter(_BaseFilter):
    """
    Filter aimed at undersampling the interactions dataset.

    The filter undersamples popular items, i.e. items whose number of interactions exceeds
    the `alpha_quantile` quantile of the items counts distribution. For every such item it
    removes `items_proportion` of the interactions lying between the item's count and the
    largest count among items that are not undersampled. Because only this excess is reduced,
    undersampled items always keep more interactions than any item that is not undersampled,
    and their relative popularity order is never inverted. Within an undersampled item,
    interactions of the most active queries (queries with the most interactions overall)
    are removed first.

    >>> import pandas as pd
    >>> from replay.utils.spark_utils import convert2spark
    >>> log_pd = pd.DataFrame({
    ...     "user_id": [0, 0, 1, 2, 2, 2, 2],
    ...     "item_id": [0, 2, 1, 1, 2, 2, 2]
    ... })
    >>> log_spark = convert2spark(log_pd)
    >>> log_spark.show()
    +-------+-------+
    |user_id|item_id|
    +-------+-------+
    |      0|      0|
    |      0|      2|
    |      1|      1|
    |      2|      1|
    |      2|      2|
    |      2|      2|
    |      2|      2|
    +-------+-------+
    <BLANKLINE>

    >>> QuantileItemsFilter(query_column="user_id").transform(log_spark).show()
    +-------+-------+
    |user_id|item_id|
    +-------+-------+
    |      0|      0|
    |      1|      1|
    |      2|      1|
    |      2|      2|
    |      2|      2|
    |      0|      2|
    +-------+-------+
    <BLANKLINE>
    """

    def __init__(
        self,
        alpha_quantile: float = 0.99,
        items_proportion: float = 0.5,
        query_column: str = "query_id",
        item_column: str = "item_id",
    ) -> None:
        """
        :param alpha_quantile: Quantile of the items counts distribution that is kept unchanged.
            Every item whose count exceeds this quantile is undersampled.
            Must be in (0, 1). Default: ``0.99``.
        :param items_proportion: Proportion of interactions to remove for undersampled items,
            taken from the range between the item's current count and the largest count
            among items that are not undersampled, so the original relation between items
            is kept. Must be in (0, 1). Default: ``0.5``.
        :param query_column: query column name.
            Default: ``query_id``.
        :param item_column: item column name.
            Default: ``item_id``.
        :raises ValueError: if ``alpha_quantile`` or ``items_proportion`` is not in (0, 1).
        """
        if not 0 < alpha_quantile < 1:
            msg = "`alpha_quantile` value must be in (0, 1)"
            raise ValueError(msg)
        if not 0 < items_proportion < 1:
            msg = "`items_proportion` value must be in (0, 1)"
            raise ValueError(msg)

        self.alpha_quantile = alpha_quantile
        self.items_proportion = items_proportion
        self.query_column = query_column
        self.item_column = item_column

    def _filter_pandas(self, df: pd.DataFrame):
        """Undersample popular items of a pandas DataFrame."""
        items_distribution = df.groupby(self.item_column).size().reset_index().rename(columns={0: "counts"})
        users_distribution = df.groupby(self.query_column).size().reset_index().rename(columns={0: "counts"})
        count_threshold = items_distribution.loc[:, "counts"].quantile(self.alpha_quantile, interpolation="midpoint")
        df_with_counts = df.merge(items_distribution, how="left", on=self.item_column).merge(
            users_distribution, how="left", on=self.query_column, suffixes=["_items", "_users"]
        )
        long_tail = df_with_counts.loc[df_with_counts["counts_items"] <= count_threshold]
        # ``.copy()`` detaches the slice from ``df_with_counts`` so that adding a column below
        # does not trigger pandas' SettingWithCopyWarning.
        short_tail = df_with_counts.loc[df_with_counts["counts_items"] > count_threshold].copy()
        max_long_tail_count = long_tail["counts_items"].max()
        short_tail["num_items_to_delete"] = (
            self.items_proportion * (short_tail["counts_items"] - max_long_tail_count)
        ).astype("int")
        # The most active queries come first inside every item group, so they lose interactions first.
        short_tail = short_tail.sort_values("counts_users", ascending=False)
        # Position of every interaction inside its item group (in the sorted order);
        # the first ``num_items_to_delete`` positions of each item are dropped.
        position_in_item = short_tail.groupby(self.item_column).cumcount()
        keep = position_in_item >= short_tail["num_items_to_delete"]
        return pd.concat([long_tail[df.columns], short_tail.loc[keep, df.columns]])

    def _filter_polars(self, df: pl.DataFrame):
        """Undersample popular items of a polars DataFrame."""
        items_distribution = df.group_by(self.item_column).len()
        users_distribution = df.group_by(self.query_column).len()
        count_threshold = items_distribution.select("len").quantile(self.alpha_quantile, "midpoint")["len"][0]
        df_with_counts = (
            df.join(items_distribution, how="left", on=self.item_column).join(
                users_distribution, how="left", on=self.query_column
            )
        ).rename({"len": "counts_items", "len_right": "counts_users"})
        long_tail = df_with_counts.filter(pl.col("counts_items") <= count_threshold)
        short_tail = df_with_counts.filter(pl.col("counts_items") > count_threshold)
        max_long_tail_count = long_tail["counts_items"].max()
        # The amount to delete depends only on the item count, so it is computed row-wise.
        # The explicit alias avoids relying on polars' auto-generated ``literal`` output name.
        short_tail = short_tail.with_columns(
            (self.items_proportion * (pl.col("counts_items") - max_long_tail_count))
            .cast(pl.Int64)
            .alias("num_items_to_delete")
        ).sort("counts_users", descending=True)
        short_tail = short_tail.with_columns(index=pl.int_range(short_tail.shape[0]))
        grouped = short_tail.group_by(self.item_column, maintain_order=True).agg(
            pl.col("index"), pl.col("num_items_to_delete").first()
        )
        # Keep only the last ``len - num_items_to_delete`` rows of every item group,
        # i.e. drop the interactions of the most active queries.
        grouped = grouped.with_columns(
            pl.col("index").list.tail(pl.col("index").list.len() - pl.col("num_items_to_delete"))
        )
        grouped = grouped.explode("index").select("index")
        short_tail = grouped.join(short_tail, how="left", on="index")
        return pl.concat([long_tail.select(df.columns), short_tail.select(df.columns)])

    def _filter_spark(self, df: SparkDataFrame):
        """Undersample popular items of a spark DataFrame."""
        items_distribution = df.groupBy(self.item_column).agg(sf.count(self.query_column).alias("counts_items"))
        users_distribution = df.groupBy(self.query_column).agg(sf.count(self.item_column).alias("counts_users"))
        # Only the counts column is needed on the driver to compute the quantile.
        count_threshold = (
            items_distribution.select("counts_items")
            .toPandas()["counts_items"]
            .quantile(self.alpha_quantile, "midpoint")
        )
        df_with_counts = df.join(items_distribution, on=self.item_column).join(users_distribution, on=self.query_column)
        long_tail = df_with_counts.filter(sf.col("counts_items") <= count_threshold)
        short_tail = df_with_counts.filter(sf.col("counts_items") > count_threshold)
        max_long_tail_count = long_tail.agg({"counts_items": "max"}).collect()[0][0]
        # The amount to delete depends only on the item count, so no extra distinct/join is needed.
        short_tail = short_tail.withColumn(
            "num_items_to_delete",
            (self.items_proportion * (sf.col("counts_items") - max_long_tail_count)).cast("int"),
        )
        # Rank interactions of every item by query activity; the top ``num_items_to_delete`` are dropped.
        short_tail = short_tail.withColumn(
            "index", sf.row_number().over(Window.partitionBy(self.item_column).orderBy(sf.col("counts_users").desc()))
        )
        short_tail = short_tail.filter(sf.col("index") > sf.col("num_items_to_delete"))
        return long_tail.select(df.columns).union(short_tail.select(df.columns))
|
replay/utils/common.py
CHANGED
|
@@ -1,7 +1,12 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import inspect
|
|
1
3
|
import json
|
|
2
4
|
from pathlib import Path
|
|
3
|
-
from typing import Union
|
|
5
|
+
from typing import Any, Callable, Union
|
|
4
6
|
|
|
7
|
+
from polars import from_pandas as pl_from_pandas
|
|
8
|
+
|
|
9
|
+
from replay.data.dataset import Dataset
|
|
5
10
|
from replay.splitters import (
|
|
6
11
|
ColdUserRandomSplitter,
|
|
7
12
|
KFolds,
|
|
@@ -12,7 +17,16 @@ from replay.splitters import (
|
|
|
12
17
|
TimeSplitter,
|
|
13
18
|
TwoStageSplitter,
|
|
14
19
|
)
|
|
15
|
-
from replay.utils import
|
|
20
|
+
from replay.utils import (
|
|
21
|
+
TORCH_AVAILABLE,
|
|
22
|
+
PandasDataFrame,
|
|
23
|
+
PolarsDataFrame,
|
|
24
|
+
SparkDataFrame,
|
|
25
|
+
)
|
|
26
|
+
from replay.utils.spark_utils import (
|
|
27
|
+
convert2spark as pandas_to_spark,
|
|
28
|
+
spark_to_pandas,
|
|
29
|
+
)
|
|
16
30
|
|
|
17
31
|
SavableObject = Union[
|
|
18
32
|
ColdUserRandomSplitter,
|
|
@@ -23,10 +37,11 @@ SavableObject = Union[
|
|
|
23
37
|
RatioSplitter,
|
|
24
38
|
TimeSplitter,
|
|
25
39
|
TwoStageSplitter,
|
|
40
|
+
Dataset,
|
|
26
41
|
]
|
|
27
42
|
|
|
28
43
|
if TORCH_AVAILABLE:
|
|
29
|
-
from replay.data.nn import SequenceTokenizer
|
|
44
|
+
from replay.data.nn import PandasSequentialDataset, PolarsSequentialDataset, SequenceTokenizer
|
|
30
45
|
|
|
31
46
|
SavableObject = Union[
|
|
32
47
|
ColdUserRandomSplitter,
|
|
@@ -38,6 +53,8 @@ if TORCH_AVAILABLE:
|
|
|
38
53
|
TimeSplitter,
|
|
39
54
|
TwoStageSplitter,
|
|
40
55
|
SequenceTokenizer,
|
|
56
|
+
PandasSequentialDataset,
|
|
57
|
+
PolarsSequentialDataset,
|
|
41
58
|
]
|
|
42
59
|
|
|
43
60
|
|
|
@@ -50,7 +67,7 @@ def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
|
|
|
50
67
|
obj.save(path)
|
|
51
68
|
|
|
52
69
|
|
|
53
|
-
def load_from_replay(path: Union[str, Path]) -> SavableObject:
|
|
70
|
+
def load_from_replay(path: Union[str, Path], **kwargs) -> SavableObject:
|
|
54
71
|
"""
|
|
55
72
|
General function to load RePlay models, splitters and tokenizer.
|
|
56
73
|
|
|
@@ -60,6 +77,91 @@ def load_from_replay(path: Union[str, Path]) -> SavableObject:
|
|
|
60
77
|
with open(path / "init_args.json", "r") as file:
|
|
61
78
|
class_name = json.loads(file.read())["_class_name"]
|
|
62
79
|
obj_type = globals()[class_name]
|
|
63
|
-
obj = obj_type.load(path)
|
|
80
|
+
obj = obj_type.load(path, **kwargs)
|
|
64
81
|
|
|
65
82
|
return obj
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _check_if_dataframe(var: Any):
    """
    Make sure that ``var`` is a dataframe of one of the supported backends.

    :param var: object to validate.
    :raises ValueError: if ``var`` is not a pandas, polars or spark DataFrame.
    """
    supported_types = (SparkDataFrame, PolarsDataFrame, PandasDataFrame)
    if isinstance(var, supported_types):
        return
    msg = f"Object of type {type(var)} is not a dataframe of known type (can be pandas|spark|polars)"
    raise ValueError(msg)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def check_if_dataframe(*args_to_check: str) -> Callable[..., Any]:
    """
    Build a decorator that checks that the listed arguments of the wrapped function are dataframes.

    On every call, the names in ``args_to_check`` are resolved against the wrapped function's
    signature: positional arguments, keyword arguments and default values alike. Each resolved
    value is passed to ``_check_if_dataframe``, which raises ``ValueError`` for anything that is
    not a pandas, polars or spark DataFrame.

    :param args_to_check: names of the wrapped function's parameters to validate.
    :returns: a decorator that applies the check before delegating to the wrapped function.
    :raises TypeError: (at call time) if the arguments do not match the wrapped function's signature.
    """

    def decorator_func(func: Callable[..., Any]) -> Callable[..., Any]:
        # The signature is fixed for a given function, so inspect it once at decoration time
        # instead of twice on every call.
        signature = inspect.signature(func)
        var_keyword_name = next(
            (
                param.name
                for param in signature.parameters.values()
                if param.kind is inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )

        @functools.wraps(func)
        def wrap_func(*args: Any, **kwargs: Any) -> Any:
            # ``bind`` maps arguments to parameter names exactly as the call itself would, so extra
            # positionals captured by ``*args`` are no longer mis-assigned to later parameter names.
            # ``apply_defaults`` fills in the default values of omitted parameters.
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            arguments = dict(bound.arguments)
            if var_keyword_name is not None:
                # Names accepted only through ``**kwargs`` stay addressable by their own name.
                arguments.update(arguments.pop(var_keyword_name))
            for arg_name in args_to_check:
                _check_if_dataframe(arguments[arg_name])
            return func(*args, **kwargs)

        return wrap_func

    return decorator_func
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@check_if_dataframe("data")
def convert2pandas(
    data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
) -> PandasDataFrame:
    """
    Turn a spark or polars DataFrame into a pandas.DataFrame.
    A pandas.DataFrame input is passed through untouched.

    :param data: dataframe to convert; pandas, polars and spark DataFrames are accepted.
    :param allow_collect_to_master: when False (the default), converting a spark DataFrame
        emits a warning about collecting distributed data to the master node.
    :returns: the data as a pandas.DataFrame.
    """
    converted = None
    if isinstance(data, PandasDataFrame):
        converted = data
    elif isinstance(data, PolarsDataFrame):
        converted = data.to_pandas()
    elif isinstance(data, SparkDataFrame):
        converted = spark_to_pandas(data, allow_collect_to_master, from_constructor=False)
    return converted
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@check_if_dataframe("data")
def convert2polars(
    data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
) -> PolarsDataFrame:
    """
    Turn a spark or pandas DataFrame into a polars.DataFrame.
    A polars.DataFrame input is passed through untouched.

    :param data: dataframe to convert; spark, pandas and polars DataFrames are accepted.
    :param allow_collect_to_master: when False (the default), converting a spark DataFrame
        emits a warning about collecting distributed data to the master node.
    :returns: the data as a polars.DataFrame.
    """
    converted = None
    if isinstance(data, PandasDataFrame):
        converted = pl_from_pandas(data)
    elif isinstance(data, PolarsDataFrame):
        converted = data
    elif isinstance(data, SparkDataFrame):
        # spark data goes through pandas on the driver before becoming a polars frame
        pandas_data = spark_to_pandas(data, allow_collect_to_master, from_constructor=False)
        converted = pl_from_pandas(pandas_data)
    return converted
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@check_if_dataframe("data")
def convert2spark(data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame]) -> SparkDataFrame:
    """
    Turn a pandas or polars DataFrame into a pyspark.sql.DataFrame.
    A pyspark.sql.DataFrame input is returned unchanged.

    :param data: dataframe to convert; pandas, polars and spark DataFrames are accepted.
    :returns: the data as a pyspark.sql.DataFrame.
    """
    if isinstance(data, (PandasDataFrame, SparkDataFrame)):
        source = data
    elif isinstance(data, PolarsDataFrame):
        # polars frames are handed to spark through their pandas representation
        source = data.to_pandas()
    else:
        return None
    return pandas_to_spark(source)
|
replay/utils/spark_utils.py
CHANGED
|
@@ -33,7 +33,9 @@ class SparkCollectToMasterWarning(Warning): # pragma: no cover
|
|
|
33
33
|
"""
|
|
34
34
|
|
|
35
35
|
|
|
36
|
-
def spark_to_pandas(
|
|
36
|
+
def spark_to_pandas(
|
|
37
|
+
data: SparkDataFrame, allow_collect_to_master: bool = False, from_constructor: bool = True
|
|
38
|
+
) -> pd.DataFrame: # pragma: no cover
|
|
37
39
|
"""
|
|
38
40
|
Convert Spark DataFrame to Pandas DataFrame.
|
|
39
41
|
|
|
@@ -42,10 +44,15 @@ def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False)
|
|
|
42
44
|
|
|
43
45
|
:returns: Converted Pandas DataFrame.
|
|
44
46
|
"""
|
|
47
|
+
warn_msg = "Spark Data Frame is collected to master node, this may lead to OOM exception for larger dataset. "
|
|
48
|
+
if from_constructor:
|
|
49
|
+
_msg = "To remove this warning set allow_collect_to_master=True in the recommender constructor."
|
|
50
|
+
else:
|
|
51
|
+
_msg = "To remove this warning set allow_collect_to_master=True."
|
|
52
|
+
warn_msg += _msg
|
|
45
53
|
if not allow_collect_to_master:
|
|
46
54
|
warnings.warn(
|
|
47
|
-
|
|
48
|
-
"To remove this warning set allow_collect_to_master=True in the recommender constructor.",
|
|
55
|
+
warn_msg,
|
|
49
56
|
SparkCollectToMasterWarning,
|
|
50
57
|
)
|
|
51
58
|
return data.toPandas()
|
|
@@ -169,7 +176,7 @@ if PYSPARK_AVAILABLE:
|
|
|
169
176
|
<BLANKLINE>
|
|
170
177
|
>>> output_data = input_data.select(vector_dot("one", "two").alias("dot"))
|
|
171
178
|
>>> output_data.schema
|
|
172
|
-
StructType(
|
|
179
|
+
StructType([StructField('dot', DoubleType(), True)])
|
|
173
180
|
>>> output_data.show()
|
|
174
181
|
+----+
|
|
175
182
|
| dot|
|
|
@@ -207,7 +214,7 @@ if PYSPARK_AVAILABLE:
|
|
|
207
214
|
<BLANKLINE>
|
|
208
215
|
>>> output_data = input_data.select(vector_mult("one", "two").alias("mult"))
|
|
209
216
|
>>> output_data.schema
|
|
210
|
-
StructType(
|
|
217
|
+
StructType([StructField('mult', VectorUDT(), True)])
|
|
211
218
|
>>> output_data.show()
|
|
212
219
|
+---------+
|
|
213
220
|
| mult|
|
|
@@ -244,7 +251,7 @@ if PYSPARK_AVAILABLE:
|
|
|
244
251
|
<BLANKLINE>
|
|
245
252
|
>>> output_data = input_data.select(array_mult("one", "two").alias("mult"))
|
|
246
253
|
>>> output_data.schema
|
|
247
|
-
StructType(
|
|
254
|
+
StructType([StructField('mult', ArrayType(DoubleType(), True), True)])
|
|
248
255
|
>>> output_data.show()
|
|
249
256
|
+----------+
|
|
250
257
|
| mult|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: replay-rec
|
|
3
|
-
Version: 0.17.
|
|
3
|
+
Version: 0.17.1rc0
|
|
4
4
|
Summary: RecSys Library
|
|
5
5
|
Home-page: https://sb-ai-lab.github.io/RePlay/
|
|
6
6
|
License: Apache-2.0
|
|
@@ -32,11 +32,11 @@ Requires-Dist: nmslib (==2.1.1)
|
|
|
32
32
|
Requires-Dist: numba (>=0.50)
|
|
33
33
|
Requires-Dist: numpy (>=1.20.0)
|
|
34
34
|
Requires-Dist: optuna (>=3.2.0,<3.3.0)
|
|
35
|
-
Requires-Dist: pandas (>=1.3.5
|
|
35
|
+
Requires-Dist: pandas (>=1.3.5,<=2.2.2)
|
|
36
36
|
Requires-Dist: polars (>=0.20.7,<0.21.0)
|
|
37
37
|
Requires-Dist: psutil (>=5.9.5,<5.10.0)
|
|
38
38
|
Requires-Dist: pyarrow (>=12.0.1)
|
|
39
|
-
Requires-Dist: pyspark (>=3.0,<3.
|
|
39
|
+
Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
|
|
40
40
|
Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
|
|
41
41
|
Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
|
|
42
42
|
Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
replay/__init__.py,sha256=
|
|
1
|
+
replay/__init__.py,sha256=_PQ2zFERSGjgeThzFv3t6MPODgutry1eR82biGhB98o,54
|
|
2
2
|
replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
|
|
3
|
-
replay/data/dataset.py,sha256=
|
|
3
|
+
replay/data/dataset.py,sha256=cSStvCqIc6WAJNtbmsxncSpcQZ1KfULMsrmf_V0UdPw,29490
|
|
4
4
|
replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
|
|
5
5
|
replay/data/dataset_utils/dataset_label_encoder.py,sha256=TEx2zLw5rJdIz1SRBEznyVv5x_Cs7o6QQbzMk-M1LU0,9598
|
|
6
6
|
replay/data/nn/__init__.py,sha256=WxLsi4rgOuuvGYHN49xBPxP2Srhqf3NYgfBDVH-ZvBo,1122
|
|
7
|
-
replay/data/nn/schema.py,sha256=
|
|
8
|
-
replay/data/nn/sequence_tokenizer.py,sha256=
|
|
9
|
-
replay/data/nn/sequential_dataset.py,sha256=
|
|
7
|
+
replay/data/nn/schema.py,sha256=pO4N7RgmgrqfD1-2d95OTeihKHTZ-5y2BG7CX_wBFi4,16198
|
|
8
|
+
replay/data/nn/sequence_tokenizer.py,sha256=Ambrp3CMOp3JP68PiwmVh0m-_zNXiWzxxVreHkEwOyY,32592
|
|
9
|
+
replay/data/nn/sequential_dataset.py,sha256=jCWxC0Pm1eQ5p8Y6_Bmg4fSEvPaecLrqz1iaWzaICdI,11014
|
|
10
10
|
replay/data/nn/torch_sequential_dataset.py,sha256=BqrK_PtkhpsaY1zRIWGk4EgwPL31a7IWCc0hLDuwDQc,10984
|
|
11
11
|
replay/data/nn/utils.py,sha256=YKE9gkIHZDDiwv4THqOWL4PzsdOujnPuM97v79Mwq0E,2769
|
|
12
12
|
replay/data/schema.py,sha256=F_cv6sYb6l23yuX5xWnbqoJ9oSeUT2NpIM19u8Lf2jA,15606
|
|
@@ -148,14 +148,14 @@ replay/optimization/__init__.py,sha256=az6U10rF7X6rPRUUPwLyiM1WFNJ_6kl0imA5xLVWF
|
|
|
148
148
|
replay/optimization/optuna_objective.py,sha256=Z-8X0_FT3BicVWj0UhxoLrvZAck3Dhn7jHDGo0i0hxA,7653
|
|
149
149
|
replay/preprocessing/__init__.py,sha256=TtBysFqYeDy4kZAEnWEaNSwPvbffYdfMkEs71YG51fM,411
|
|
150
150
|
replay/preprocessing/converter.py,sha256=DczqsVLrwFi6EFhK2HR8rGiIxGCwXeY7QNgWorjA41g,4390
|
|
151
|
-
replay/preprocessing/filters.py,sha256=
|
|
151
|
+
replay/preprocessing/filters.py,sha256=wsXWQoZ-2aAecunLkaTxeLWi5ow4e3FAGcElx0iNx0w,41669
|
|
152
152
|
replay/preprocessing/history_based_fp.py,sha256=tfgKJPKm53LSNqM6VmMXYsVrRDc-rP1Tbzn8s3mbziQ,18751
|
|
153
153
|
replay/preprocessing/label_encoder.py,sha256=MLBavPD-dB644as0E9ZJSE9-8QxGCB_IHek1w3xtqDI,27040
|
|
154
154
|
replay/preprocessing/sessionizer.py,sha256=G6i0K3FwqtweRxvcSYraJ-tBWAT2HnV-bWHHlIZJF-s,12217
|
|
155
155
|
replay/scenarios/__init__.py,sha256=kw2wRkPPinw0IBA20D83XQ3xeSudk3KuYAAA1Wdr8xY,93
|
|
156
156
|
replay/scenarios/fallback.py,sha256=EeBmIR-5igzKR2m55bQRFyhxTkpJez6ZkCW449n8hWs,7130
|
|
157
157
|
replay/splitters/__init__.py,sha256=DnqVMelrzLwR8fGQgcWN_8FipGs8T4XGSPOMW-L_x2g,454
|
|
158
|
-
replay/splitters/base_splitter.py,sha256=
|
|
158
|
+
replay/splitters/base_splitter.py,sha256=hj9_GYDWllzv3XnxN6WHu1JKRRVjXo77vZEOLbF9v-s,7761
|
|
159
159
|
replay/splitters/cold_user_random_splitter.py,sha256=gVwBVdn_0IOaLGT_UzJoS9AMaPhelZy-FpC5JQS1PhA,4136
|
|
160
160
|
replay/splitters/k_folds.py,sha256=WH02_DP18A2ae893ysonmfLPB56_i1ETllTAwaCYekg,6218
|
|
161
161
|
replay/splitters/last_n_splitter.py,sha256=r9kdq2JPi508C9ywjwc68an-iq27KsigMfHWLz0YohE,15346
|
|
@@ -165,16 +165,16 @@ replay/splitters/ratio_splitter.py,sha256=8zvuCn16Icc4ntQPKXJ5ArAWuJzCZ9NHZtgWct
|
|
|
165
165
|
replay/splitters/time_splitter.py,sha256=iXhuafjBx7dWyJSy-TEVy1IUQBwMpA1gAiF4-GtRe2g,9031
|
|
166
166
|
replay/splitters/two_stage_splitter.py,sha256=PWozxjjgjrVzdz6Sm9dcDTeH0bOA24reFzkk_N_TgbQ,17734
|
|
167
167
|
replay/utils/__init__.py,sha256=vDJgOWq81fbBs-QO4ZDpdqR4KDyO1kMOOxBRi-5Gp7E,253
|
|
168
|
-
replay/utils/common.py,sha256=
|
|
168
|
+
replay/utils/common.py,sha256=s4Pro3QCkPeVBsj-s0vrbhd_pkJD-_-2M_sIguxGzQQ,5411
|
|
169
169
|
replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
|
|
170
170
|
replay/utils/distributions.py,sha256=kGGq2KzQZ-yhTuw_vtOsKFXVpXUOQ2l4aIFBcaDufZ8,1202
|
|
171
171
|
replay/utils/model_handler.py,sha256=V-mHDh8_UexjVSsMBBRA9yrjS_5MPHwYOwv_UrI-Zfs,6466
|
|
172
172
|
replay/utils/session_handler.py,sha256=ijTvDSNAe1D9R1e-dhtd-r80tFNiIBsFdWZLgw-gLEo,5153
|
|
173
|
-
replay/utils/spark_utils.py,sha256=
|
|
173
|
+
replay/utils/spark_utils.py,sha256=k5lUFM2C9QZKQON3dqhgfswyUF4tsgJOn0U2wCKimqM,26901
|
|
174
174
|
replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
|
|
175
175
|
replay/utils/types.py,sha256=5sw0A7NG4ZgQKdWORnBy0wBZ5F98sP_Ju8SKQ6zbDS4,651
|
|
176
|
-
replay_rec-0.17.
|
|
177
|
-
replay_rec-0.17.
|
|
178
|
-
replay_rec-0.17.
|
|
179
|
-
replay_rec-0.17.
|
|
180
|
-
replay_rec-0.17.
|
|
176
|
+
replay_rec-0.17.1rc0.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
|
|
177
|
+
replay_rec-0.17.1rc0.dist-info/METADATA,sha256=FgZduBS6AVq1qSNahVyNFCJILLPdVLVosbxjUxN7WkQ,10890
|
|
178
|
+
replay_rec-0.17.1rc0.dist-info/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
|
|
179
|
+
replay_rec-0.17.1rc0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
180
|
+
replay_rec-0.17.1rc0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|