replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/preprocessing/converter.py CHANGED

@@ -3,13 +3,12 @@ from typing import Optional
 import numpy as np
 from scipy.sparse import csr_matrix
 
-from replay.utils import DataFrameLike, SparkDataFrame
+from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     from replay.utils.spark_utils import spark_to_pandas
 
 
-# pylint: disable=too-few-public-methods
 class CSRConverter:
     """
     Convert input data to csr sparse matrix.
@@ -47,7 +46,6 @@ class CSRConverter:
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
    def __init__(
         self,
         first_dim_column: str,
@@ -96,10 +94,7 @@ class CSRConverter:
 
         rows_data = data[self.first_dim_column].values
         cols_data = data[self.second_dim_column].values
-        if self.data_column is not None:
-            data = data[self.data_column].values
-        else:
-            data = np.ones(data.shape[0])
+        data = data[self.data_column].values if self.data_column is not None else np.ones(data.shape[0])
 
         def _get_max(data: np.ndarray) -> int:
             return np.max(data) if data.shape[0] > 0 else 0
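The rewritten branch collapses the four-line if/else into a conditional expression with identical behavior: explicit values when data_column is set, implicit ones otherwise. A standalone sketch of that fallback on a toy pandas log (column names here are illustrative, not from the diff):

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

# Toy interaction log; "rating" would play the role of CSRConverter's data_column.
log = pd.DataFrame({"user_id": [0, 0, 1], "item_id": [0, 2, 1], "rating": [5, 3, 4]})

rows = log["user_id"].values
cols = log["item_id"].values
data_column = None  # set to "rating" to keep explicit values instead
# The 0.17.0 one-liner: explicit values if a data column is configured, else implicit ones.
data = log[data_column].values if data_column is not None else np.ones(log.shape[0])

matrix = csr_matrix((data, (rows, cols)), shape=(rows.max() + 1, cols.max() + 1))
print(matrix.toarray())  # [[1. 0. 1.] [0. 1. 0.]]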
replay/preprocessing/filters.py CHANGED
@@ -1,22 +1,23 @@
 """
 Select or remove data by some criteria
 """
-import polars as pl
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
-from typing import Callable, Optional,
+from typing import Callable, Optional, Tuple, Union
 
-
+import polars as pl
 
+from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
-    from pyspark.sql import Window, functions as sf
-
+    from pyspark.sql import (
+        Window,
+        functions as sf,
+    )
     from pyspark.sql.functions import col
     from pyspark.sql.types import TimestampType
 
 
-# pylint: disable=too-few-public-methods
 class _BaseFilter(ABC):
     def transform(self, interactions: DataFrameLike) -> DataFrameLike:
         r"""Filter interactions.
@@ -32,7 +33,8 @@ class _BaseFilter(ABC):
         elif isinstance(interactions, PolarsDataFrame):
             return self._filter_polars(interactions)
         else:
-            raise NotImplementedError(f"{self.__class__.__name__} is not implemented for {type(interactions)}")
+            msg = f"{self.__class__.__name__} is not implemented for {type(interactions)}"
+            raise NotImplementedError(msg)
 
     @abstractmethod
     def _filter_spark(self, interactions: SparkDataFrame):  # pragma: no cover
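Functionally, this hunk only moves the error text into a variable before raising. A toy dispatcher mirroring the pattern (the class and its pandas/polars-only dispatch below are made up for illustration; the real _BaseFilter also handles SparkDataFrame):

import pandas as pd
import polars as pl

class ToyFilter:
    def transform(self, interactions):
        if isinstance(interactions, pd.DataFrame):
            return self._filter_pandas(interactions)
        if isinstance(interactions, pl.DataFrame):
            return self._filter_polars(interactions)
        # 0.17.0 style: build the message first, then raise it.
        msg = f"{self.__class__.__name__} is not implemented for {type(interactions)}"
        raise NotImplementedError(msg)

    def _filter_pandas(self, interactions):
        return interactions

    def _filter_polars(self, interactions):
        return interactions

print(ToyFilter().transform(pd.DataFrame({"a": [1]})))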
@@ -47,7 +49,6 @@ class _BaseFilter(ABC):
         pass
 
 
-# pylint: disable=too-few-public-methods, too-many-instance-attributes
 class InteractionEntriesFilter(_BaseFilter):
     """
     Remove interactions less than minimum constraint value and greater
@@ -81,7 +82,6 @@ class InteractionEntriesFilter(_BaseFilter):
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         query_column: str = "user_id",
@@ -166,11 +166,10 @@ class InteractionEntriesFilter(_BaseFilter):
                 interactions, interaction_count, agg_column, non_agg_column, min_inter, max_inter
             )
             is_dropped_user_item[current_index] = bool(dropped_interact)
-            current_index = (current_index + 1) % 2
+            current_index = (current_index + 1) % 2  # current_index only in (0, 1)
 
         return interactions
 
-    # pylint: disable=no-self-use
     def _filter_column_pandas(
         self,
         interactions: PandasDataFrame,
@@ -196,7 +195,6 @@ class InteractionEntriesFilter(_BaseFilter):
 
         return filtered_interactions, different_len, end_len_dataframe
 
-    # pylint: disable=no-self-use
     def _filter_column_spark(
         self,
         interactions: SparkDataFrame,
@@ -223,7 +221,6 @@ class InteractionEntriesFilter(_BaseFilter):
 
         return filtered_interactions, different_len, end_len_dataframe
 
-    # pylint: disable=no-self-use
     def _filter_column_polars(
         self,
         interactions: PolarsDataFrame,
@@ -234,8 +231,7 @@ class InteractionEntriesFilter(_BaseFilter):
         max_inter: Optional[int] = None,
     ) -> Tuple[PolarsDataFrame, int, int]:
         filtered_interactions = interactions.with_columns(
-            pl.col(non_agg_column).count().over(pl.col(agg_column))
-            .alias("count")
+            pl.col(non_agg_column).count().over(pl.col(agg_column)).alias("count")
         )
         if min_inter:
             filtered_interactions = filtered_interactions.filter(pl.col("count") >= min_inter)
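The polars branch is only reflowed; the window expression is unchanged. A self-contained sketch of what it computes, with illustrative column names:

import polars as pl

df = pl.DataFrame({"user_id": [1, 1, 2], "item_id": [10, 20, 10]})

# Per-group size via a window function, without collapsing rows
# (grouped by user_id here, counting item_id).
with_counts = df.with_columns(pl.col("item_id").count().over(pl.col("user_id")).alias("count"))
print(with_counts.filter(pl.col("count") >= 2))  # keeps only user 1's rows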
@@ -294,32 +290,20 @@ class MinCountFilter(_BaseFilter):
 
     def _filter_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
         filtered_interactions = interactions.copy(deep=True)
-        filtered_interactions["count"] = (
-            filtered_interactions
-            .groupby(self.groupby_column)[self.groupby_column]
-            .transform(len)
-        )
-        return (
-            filtered_interactions[filtered_interactions["count"] >= self.num_entries]
-            .drop(columns=["count"])
-        )
+        filtered_interactions["count"] = filtered_interactions.groupby(self.groupby_column)[
+            self.groupby_column
+        ].transform(len)
+        return filtered_interactions[filtered_interactions["count"] >= self.num_entries].drop(columns=["count"])
 
     def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
         filtered_interactions = interactions.clone()
         count_by_group = (
-            filtered_interactions
-            .group_by(self.groupby_column)
-            .agg(
-                pl.col(self.groupby_column).count().alias(f"{self.groupby_column}_temp_count")
-            )
-            .filter(
-                pl.col(f"{self.groupby_column}_temp_count") >= self.num_entries
-            )
+            filtered_interactions.group_by(self.groupby_column)
+            .agg(pl.col(self.groupby_column).count().alias(f"{self.groupby_column}_temp_count"))
+            .filter(pl.col(f"{self.groupby_column}_temp_count") >= self.num_entries)
         )
-        return (
-            filtered_interactions
-            .join(count_by_group, on=self.groupby_column)
-            .drop(f"{self.groupby_column}_temp_count")
-        )
+        return filtered_interactions.join(count_by_group, on=self.groupby_column).drop(
+            f"{self.groupby_column}_temp_count"
+        )
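A minimal pandas sketch of the reflowed _filter_pandas logic, on toy data with illustrative names:

import pandas as pd

df = pd.DataFrame({"user_id": [1, 1, 2], "item_id": [10, 20, 30]})
num_entries = 2

# Per-group size via groupby/transform, then keep groups with enough rows.
df["count"] = df.groupby("user_id")["user_id"].transform(len)
print(df[df["count"] >= num_entries].drop(columns=["count"]))  # keeps user 1 only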
@@ -423,7 +407,6 @@ class NumInteractionsFilter(_BaseFilter):
     <BLANKLINE>
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         num_interactions: int = 10,
@@ -480,14 +463,12 @@ class NumInteractionsFilter(_BaseFilter):
         ascending = [self.first] * len(sorting_columns)
 
         filtered_interactions["temp_rank"] = (
-            filtered_interactions
-            .sort_values(sorting_columns, ascending=ascending)
+            filtered_interactions.sort_values(sorting_columns, ascending=ascending)
             .groupby(self.query_column)
             .cumcount()
         )
-        return (
-            filtered_interactions[filtered_interactions["temp_rank"] < self.num_interactions]
-            .drop(columns=["temp_rank"])
-        )
+        return filtered_interactions[filtered_interactions["temp_rank"] < self.num_interactions].drop(
+            columns=["temp_rank"]
+        )
 
     def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
@@ -498,15 +479,10 @@ class NumInteractionsFilter(_BaseFilter):
         descending = not self.first
 
         return (
-            interactions
-            .sort(sorting_columns, descending=descending)
-            .with_columns(
-                pl.col(self.query_column)
-                .cumcount()
-                .over(self.query_column)
-                .alias("temp_rank")
-            )
-            .filter(pl.col("temp_rank") <= self.num_interactions).drop("temp_rank")
+            interactions.sort(sorting_columns, descending=descending)
+            .with_columns(pl.col(self.query_column).cumcount().over(self.query_column).alias("temp_rank"))
+            .filter(pl.col("temp_rank") <= self.num_interactions)
+            .drop("temp_rank")
         )
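Both branches are pure reflows of the same rank-within-group pattern. A pandas-only sketch of the idea (the sorting direction and first/last option are omitted):

import pandas as pd

df = pd.DataFrame(
    {
        "user_id": [1, 1, 1, 2],
        "timestamp": pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-01"]),
    }
)
num_interactions = 2

# Rank rows inside each user's history (cumcount is 0-based),
# then keep the first num_interactions of them.
df["temp_rank"] = df.sort_values("timestamp").groupby("user_id").cumcount()
print(df[df["temp_rank"] < num_interactions].drop(columns=["temp_rank"]))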
@@ -598,18 +574,13 @@ class EntityDaysFilter(_BaseFilter):
         if self.first:
             filtered_interactions = (
                 interactions.withColumn("min_date", sf.min(col(self.timestamp_column)).over(window))
-                .filter(
-                    col(self.timestamp_column)
-                    < col("min_date") + sf.expr(f"INTERVAL {self.days} days")
-                )
+                .filter(col(self.timestamp_column) < col("min_date") + sf.expr(f"INTERVAL {self.days} days"))
                 .drop("min_date")
             )
         else:
             filtered_interactions = (
                 interactions.withColumn("max_date", sf.max(col(self.timestamp_column)).over(window))
-                .filter(
-                    col(self.timestamp_column) > col("max_date") - sf.expr(f"INTERVAL {self.days} days")
-                )
+                .filter(col(self.timestamp_column) > col("max_date") - sf.expr(f"INTERVAL {self.days} days"))
                 .drop("max_date")
             )
         return filtered_interactions
@@ -618,57 +589,36 @@ class EntityDaysFilter(_BaseFilter):
         filtered_interactions = interactions.copy(deep=True)
 
         if self.first:
-            filtered_interactions["min_date"] = (
-                filtered_interactions
-                .groupby(self.entity_column)[self.timestamp_column]
-                .transform(min)
-            )
-            return (
-                filtered_interactions[
-                    (
-                        filtered_interactions[self.timestamp_column]
-                        - filtered_interactions["min_date"]
-                    ).dt.days < self.days
-                ]
-                .drop(columns=["min_date"])
-            )
-        filtered_interactions["max_date"] = (
-            filtered_interactions
-            .groupby(self.entity_column)[self.timestamp_column]
-            .transform(max)
-        )
-        return (
-            filtered_interactions[
-                (
-                    filtered_interactions["max_date"]
-                    - filtered_interactions[self.timestamp_column]
-                ).dt.days < self.days
-            ]
-            .drop(columns=["max_date"])
-        )
+            filtered_interactions["min_date"] = filtered_interactions.groupby(self.entity_column)[
+                self.timestamp_column
+            ].transform(min)
+            return filtered_interactions[
+                (filtered_interactions[self.timestamp_column] - filtered_interactions["min_date"]).dt.days < self.days
+            ].drop(columns=["min_date"])
+        filtered_interactions["max_date"] = filtered_interactions.groupby(self.entity_column)[
+            self.timestamp_column
+        ].transform(max)
+        return filtered_interactions[
+            (filtered_interactions["max_date"] - filtered_interactions[self.timestamp_column]).dt.days < self.days
+        ].drop(columns=["max_date"])
 
     def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
         if self.first:
             return (
-                interactions
-                .with_columns(
+                interactions.with_columns(
                     (
-                        pl.col(self.timestamp_column)
-                        .min().over(pl.col(self.entity_column)) + pl.duration(days=self.days)
-                    )
-                    .alias("min_date")
+                        pl.col(self.timestamp_column).min().over(pl.col(self.entity_column))
+                        + pl.duration(days=self.days)
+                    ).alias("min_date")
                 )
                 .filter(pl.col(self.timestamp_column) < pl.col("min_date"))
                 .drop("min_date")
             )
         return (
-            interactions
-            .with_columns(
+            interactions.with_columns(
                 (
-                    pl.col(self.timestamp_column)
-                    .max().over(pl.col(self.entity_column)) - pl.duration(days=self.days)
-                )
-                .alias("max_date")
+                    pl.col(self.timestamp_column).max().over(pl.col(self.entity_column)) - pl.duration(days=self.days)
+                ).alias("max_date")
             )
             .filter(pl.col(self.timestamp_column) > pl.col("max_date"))
             .drop("max_date")
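The polars branches still compute a per-entity threshold from a window aggregate plus pl.duration. A runnable sketch of the first=True case on toy data:

from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {
        "user_id": [1, 1, 2],
        "timestamp": [datetime(2020, 1, 1), datetime(2020, 1, 5), datetime(2020, 1, 4)],
    }
)
days = 3

# Threshold = per-entity minimum timestamp + N days; keep rows strictly before it.
print(
    df.with_columns(
        (pl.col("timestamp").min().over(pl.col("user_id")) + pl.duration(days=days)).alias("min_date")
    )
    .filter(pl.col("timestamp") < pl.col("min_date"))
    .drop("min_date")
)  # drops user 1's 2020-01-05 row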
@@ -749,15 +699,11 @@ class GlobalDaysFilter(_BaseFilter):
     def _filter_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
         if self.first:
             start_date = interactions.agg(sf.min(self.timestamp_column)).first()[0]
-            end_date = sf.lit(start_date).cast(TimestampType()) + sf.expr(
-                f"INTERVAL {self.days} days"
-            )
+            end_date = sf.lit(start_date).cast(TimestampType()) + sf.expr(f"INTERVAL {self.days} days")
             return interactions.filter(col(self.timestamp_column) < end_date)
 
         end_date = interactions.agg(sf.max(self.timestamp_column)).first()[0]
-        start_date = sf.lit(end_date).cast(TimestampType()) - sf.expr(
-            f"INTERVAL {self.days} days"
-        )
+        start_date = sf.lit(end_date).cast(TimestampType()) - sf.expr(f"INTERVAL {self.days} days")
         return interactions.filter(col(self.timestamp_column) > start_date)
 
     def _filter_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
@@ -765,33 +711,19 @@ class GlobalDaysFilter(_BaseFilter):
 
         if self.first:
             start_date = filtered_interactions[self.timestamp_column].min()
-            return (
-                filtered_interactions[
-                    (filtered_interactions[self.timestamp_column] - start_date).dt.days < self.days
-                ]
-            )
-        end_date = filtered_interactions[self.timestamp_column].max()
-        return (
-            filtered_interactions[
-                (end_date - filtered_interactions[self.timestamp_column]).dt.days < self.days
+            return filtered_interactions[
+                (filtered_interactions[self.timestamp_column] - start_date).dt.days < self.days
             ]
-        )
+        end_date = filtered_interactions[self.timestamp_column].max()
+        return filtered_interactions[(end_date - filtered_interactions[self.timestamp_column]).dt.days < self.days]
 
     def _filter_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
         if self.first:
-            return (
-                interactions
-                .filter(
-                    pl.col(self.timestamp_column)
-                    < (pl.col(self.timestamp_column).min() + pl.duration(days=self.days))
-                )
-            )
-        return (
-            interactions
-            .filter(
-                pl.col(self.timestamp_column)
-                > (pl.col(self.timestamp_column).max() - pl.duration(days=self.days))
+            return interactions.filter(
+                pl.col(self.timestamp_column) < (pl.col(self.timestamp_column).min() + pl.duration(days=self.days))
             )
-        )
+        return interactions.filter(
+            pl.col(self.timestamp_column) > (pl.col(self.timestamp_column).max() - pl.duration(days=self.days))
         )
 
 
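The logic is untouched: compare every timestamp against the global minimum (or maximum) shifted by self.days. The first=True pandas case as a toy sketch:

import pandas as pd

df = pd.DataFrame({"timestamp": pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-10"])})
days = 3

# Keep rows within `days` of the earliest timestamp in the whole log.
start_date = df["timestamp"].min()
print(df[(df["timestamp"] - start_date).dt.days < days])  # drops 2020-01-10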
@@ -823,7 +755,10 @@ class TimePeriodFilter(_BaseFilter):
     +-------+-------+------+-------------------+
     <BLANKLINE>
 
-    >>> TimePeriodFilter(start_date="2020-01-01 14:00:00", end_date=datetime(2020, 1, 3, 0, 0, 0)).transform(log_sp).show()
+    >>> TimePeriodFilter(
+    ...     start_date="2020-01-01 14:00:00",
+    ...     end_date=datetime(2020, 1, 3, 0, 0, 0)
+    ... ).transform(log_sp).show()
     +-------+-------+------+-------------------+
     |user_id|item_id|rating|          timestamp|
     +-------+-------+------+-------------------+
@@ -861,9 +796,7 @@ class TimePeriodFilter(_BaseFilter):
         if self.start_date is None:
             self.start_date = interactions.agg(sf.min(self.timestamp_column)).first()[0]
         if self.end_date is None:
-            self.end_date = interactions.agg(sf.max(self.timestamp_column)).first()[0] + timedelta(
-                seconds=1
-            )
+            self.end_date = interactions.agg(sf.max(self.timestamp_column)).first()[0] + timedelta(seconds=1)
 
         return interactions.filter(
             (col(self.timestamp_column) >= sf.lit(self.start_date))
@@ -874,9 +807,7 @@ class TimePeriodFilter(_BaseFilter):
         if self.start_date is None:
             self.start_date = interactions[self.timestamp_column].min()
         if self.end_date is None:
-            self.end_date = interactions[self.timestamp_column].max() + timedelta(
-                seconds=1
-            )
+            self.end_date = interactions[self.timestamp_column].max() + timedelta(seconds=1)
 
         return interactions[
             (interactions[self.timestamp_column] >= self.start_date)
@@ -887,14 +818,8 @@ class TimePeriodFilter(_BaseFilter):
         if self.start_date is None:
             self.start_date = interactions.select(self.timestamp_column).min()[0, 0]
         if self.end_date is None:
-            self.end_date = interactions.select(self.timestamp_column).max()[0, 0] + pl.duration(
-                seconds=1
-            )
+            self.end_date = interactions.select(self.timestamp_column).max()[0, 0] + pl.duration(seconds=1)
 
-        return (
-            interactions
-            .filter(
-                pl.col(self.timestamp_column)
-                .is_between(self.start_date, self.end_date, closed="left")
-            )
+        return interactions.filter(
+            pl.col(self.timestamp_column).is_between(self.start_date, self.end_date, closed="left")
         )
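The polars branch keeps the half-open [start_date, end_date) semantics via closed="left", which is also why the auto-derived end bound above gets one extra second. A toy check:

from datetime import datetime

import polars as pl

df = pl.DataFrame({"timestamp": [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)]})

# Inclusive start, exclusive end: the 2020-01-03 row is filtered out.
print(df.filter(pl.col("timestamp").is_between(datetime(2020, 1, 1), datetime(2020, 1, 3), closed="left")))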