replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/splitters/cold_user_random_splitter.py
CHANGED
@@ -1,7 +1,7 @@
-from typing import Optional, Union
+from typing import Optional, Tuple
+
 import polars as pl
 
-from .base_splitter import Splitter, SplitterReturnType
 from replay.utils import (
     PYSPARK_AVAILABLE,
     DataFrameLike,
@@ -10,11 +10,12 @@ from replay.utils import (
     SparkDataFrame,
 )
 
+from .base_splitter import Splitter, SplitterReturnType
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
 
 
-# pylint: disable=too-few-public-methods, duplicate-code
 class ColdUserRandomSplitter(Splitter):
     """
     Test set consists of all actions of randomly chosen users.
@@ -28,7 +29,6 @@ class ColdUserRandomSplitter(Splitter):
         "item_column",
     ]
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         test_size: float,
@@ -52,14 +52,13 @@ class ColdUserRandomSplitter(Splitter):
         )
         self.seed = seed
         if test_size <= 0 or test_size >= 1:
-            raise ValueError("test_size must between 0 and 1")
+            msg = "test_size must between 0 and 1"
+            raise ValueError(msg)
         self.test_size = test_size
 
     def _core_split_pandas(
-        self,
-        interactions: PandasDataFrame,
-        threshold: float
-    ) -> Union[PandasDataFrame, PandasDataFrame]:
+        self, interactions: PandasDataFrame, threshold: float
+    ) -> Tuple[PandasDataFrame, PandasDataFrame]:
         users = PandasDataFrame(interactions[self.query_column].unique(), columns=[self.query_column])
         train_users = users.sample(frac=(1 - threshold), random_state=self.seed)
         train_users["is_test"] = False
@@ -74,19 +73,15 @@ class ColdUserRandomSplitter(Splitter):
         return train, test
 
     def _core_split_spark(
-        self,
-        interactions: SparkDataFrame,
-        threshold: float
-    ) -> Union[SparkDataFrame, SparkDataFrame]:
+        self, interactions: SparkDataFrame, threshold: float
+    ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         users = interactions.select(self.query_column).distinct()
         train_users, _ = users.randomSplit(
             [1 - threshold, threshold],
             seed=self.seed,
         )
         interactions = interactions.join(
-            train_users.withColumn("is_test", sf.lit(False)),
-            on=self.query_column,
-            how="left"
+            train_users.withColumn("is_test", sf.lit(False)), on=self.query_column, how="left"
         ).na.fill({"is_test": True})
 
         train = interactions.filter(~sf.col("is_test")).drop("is_test")
@@ -95,27 +90,18 @@ class ColdUserRandomSplitter(Splitter):
         return train, test
 
     def _core_split_polars(
-        self,
-        interactions: PolarsDataFrame,
-        threshold: float
-    ) -> Union[PolarsDataFrame, PolarsDataFrame]:
+        self, interactions: PolarsDataFrame, threshold: float
+    ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         train_users = (
-            interactions
-            .select(self.query_column)
+            interactions.select(self.query_column)
             .unique()
             .sample(fraction=(1 - threshold), seed=self.seed)
             .with_columns(pl.lit(False).alias("is_test"))
         )
 
-        interactions = (
-            interactions
-            .join(
-                train_users,
-                on=self.query_column, how="left")
-            .fill_null(True)
-        )
+        interactions = interactions.join(train_users, on=self.query_column, how="left").fill_null(True)
 
-        train = interactions.filter(~pl.col("is_test")).drop("is_test")
+        train = interactions.filter(~pl.col("is_test")).drop("is_test")
         test = interactions.filter(pl.col("is_test")).drop("is_test")
         return train, test
 
@@ -127,4 +113,5 @@ class ColdUserRandomSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._core_split_polars(interactions, self.test_size)
 
-
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
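For orientation, a minimal usage sketch of the splitter touched above. The import path and the public split() entry point are assumed from the package layout rather than shown in this diff; column names follow the defaults visible in the splitters ("query_id", "item_id").

import pandas as pd

from replay.splitters import ColdUserRandomSplitter  # import path assumed

# Toy interactions log using the splitter's default column names.
interactions = pd.DataFrame(
    {
        "query_id": [1, 1, 2, 2, 3, 3],
        "item_id": [10, 11, 10, 12, 11, 13],
        "timestamp": pd.to_datetime(
            ["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04", "2020-01-05", "2020-01-06"]
        ),
    }
)

# test_size must be strictly between 0 and 1, otherwise ValueError is raised.
splitter = ColdUserRandomSplitter(test_size=0.3, seed=42)
train, test = splitter.split(interactions)  # every interaction of the sampled "cold" users goes to test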
replay/splitters/k_folds.py
CHANGED
@@ -1,9 +1,11 @@
 from typing import Literal, Optional, Tuple
+
 import polars as pl
 
-from .base_splitter import Splitter, SplitterReturnType
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+from .base_splitter import Splitter, SplitterReturnType
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql import Window
@@ -11,11 +13,11 @@ if PYSPARK_AVAILABLE:
 StrategyName = Literal["query"]
 
 
-# pylint: disable=too-few-public-methods
 class KFolds(Splitter):
     """
     Splits interactions inside each query into folds at random.
     """
+
     _init_arg_names = [
         "n_folds",
         "strategy",
@@ -29,7 +31,6 @@ class KFolds(Splitter):
         "session_id_processing_strategy",
     ]
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         n_folds: Optional[int] = 5,
@@ -64,11 +65,12 @@ class KFolds(Splitter):
             item_column=item_column,
             timestamp_column=timestamp_column,
             session_id_column=session_id_column,
-            session_id_processing_strategy=session_id_processing_strategy
+            session_id_processing_strategy=session_id_processing_strategy,
         )
         self.n_folds = n_folds
         if strategy not in {"query"}:
-            raise ValueError(f"Wrong splitter parameter: {strategy}")
+            msg = f"Wrong splitter parameter: {strategy}"
+            raise ValueError(msg)
         self.strategy = strategy
         self.seed = seed
 
@@ -85,16 +87,10 @@ class KFolds(Splitter):
         dataframe = interactions.withColumn("_rand", sf.rand(self.seed))
         dataframe = dataframe.withColumn(
             "fold",
-            sf.row_number().over(
-                Window.partitionBy(self.query_column).orderBy("_rand")
-            )
-            % self.n_folds,
+            sf.row_number().over(Window.partitionBy(self.query_column).orderBy("_rand")) % self.n_folds,
         ).drop("_rand")
         for i in range(self.n_folds):
-            dataframe = dataframe.withColumn(
-                "is_test",
-                sf.when(sf.col("fold") == i, True).otherwise(False)
-            )
+            dataframe = dataframe.withColumn("is_test", sf.when(sf.col("fold") == i, True).otherwise(False))
             if self.session_id_column:
                 dataframe = self._recalculate_with_session_id_column(dataframe)
 
@@ -122,28 +118,21 @@ class KFolds(Splitter):
     def _query_split_polars(self, interactions: PolarsDataFrame) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         dataframe = interactions.sample(fraction=1, shuffle=True, seed=self.seed).sort(self.query_column)
         dataframe = dataframe.with_columns(
-            (pl.cum_count(self.query_column).over(self.query_column) % self.n_folds)
-            .alias("fold")
+            (pl.cum_count(self.query_column).over(self.query_column) % self.n_folds).alias("fold")
         )
         for i in range(self.n_folds):
             dataframe = dataframe.with_columns(
-                pl.when(
-                    pl.col("fold") == i
-                )
-                .then(True)
-                .otherwise(False)
-                .alias("is_test")
+                pl.when(pl.col("fold") == i).then(True).otherwise(False).alias("is_test")
             )
             if self.session_id_column:
                 dataframe = self._recalculate_with_session_id_column(dataframe)
 
-            train = dataframe.filter(~pl.col("is_test")).drop("is_test", "fold")
+            train = dataframe.filter(~pl.col("is_test")).drop("is_test", "fold")
             test = dataframe.filter(pl.col("is_test")).drop("is_test", "fold")
 
             test = self._drop_cold_items_and_users(train, test)
             yield train, test
 
-    # pylint: disable=inconsistent-return-statements
     def _core_split(self, interactions: DataFrameLike) -> SplitterReturnType:
         if self.strategy == "query":
             if isinstance(interactions, SparkDataFrame):
@@ -153,4 +142,5 @@ class KFolds(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._query_split_polars(interactions)
 
-
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
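A hedged sketch of how the KFolds splitter above might be driven. It assumes split() exposes the generator produced by _core_split, yielding one (train, test) pair per fold, and it reuses the toy interactions frame from the previous sketch.

from replay.splitters import KFolds  # import path assumed

# Only the "query" strategy is accepted; anything else raises ValueError.
splitter = KFolds(n_folds=3, strategy="query", seed=7)

# Assumed: split() yields one (train, test) pair per fold, as _core_split does.
for fold_index, (train, test) in enumerate(splitter.split(interactions)):
    print(fold_index, train.shape, test.shape)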
replay/splitters/last_n_splitter.py
CHANGED
@@ -4,9 +4,10 @@ import numpy as np
 import pandas as pd
 import polars as pl
 
-from .base_splitter import Splitter
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+from .base_splitter import Splitter
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql import Window
@@ -14,7 +15,6 @@ if PYSPARK_AVAILABLE:
 StrategyName = Literal["interactions", "timedelta"]
 
 
-# pylint: disable=too-few-public-methods
 class LastNSplitter(Splitter):
     """
     Split interactions by last N interactions/timedelta per user.
@@ -88,10 +88,11 @@ class LastNSplitter(Splitter):
     14 3 2 2020-01-05
     <BLANKLINE>
     """
+
     _init_arg_names = [
         "N",
         "divide_column",
-        "
+        "time_column_format",
         "strategy",
         "drop_cold_users",
         "drop_cold_items",
@@ -102,10 +103,9 @@ class LastNSplitter(Splitter):
         "session_id_processing_strategy",
     ]
 
-    # pylint: disable=invalid-name, too-many-arguments
     def __init__(
         self,
-        N: int,
+        N: int,  # noqa: N803
         divide_column: str = "query_id",
         time_column_format: str = "yyyy-MM-dd HH:mm:ss",
         strategy: StrategyName = "interactions",
@@ -147,7 +147,8 @@ class LastNSplitter(Splitter):
             default: ``test``.
         """
         if strategy not in ["interactions", "timedelta"]:
-            raise ValueError("strategy must be equal 'interactions' or 'timedelta'")
+            msg = "strategy must be equal 'interactions' or 'timedelta'"
+            raise ValueError(msg)
         super().__init__(
             drop_cold_users=drop_cold_users,
             drop_cold_items=drop_cold_items,
@@ -160,9 +161,9 @@ class LastNSplitter(Splitter):
         self.N = N
         self.strategy = strategy
         self.divide_column = divide_column
-        self.
+        self.time_column_format = None
         if self.strategy == "timedelta":
-            self.
+            self.time_column_format = time_column_format
 
     def _add_time_partition(self, interactions: DataFrameLike) -> DataFrameLike:
         if isinstance(interactions, SparkDataFrame):
@@ -172,7 +173,8 @@ class LastNSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._add_time_partition_to_polars(interactions)
 
-
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
 
     def _add_time_partition_to_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
         res = interactions.copy(deep=True)
@@ -191,8 +193,7 @@ class LastNSplitter(Splitter):
 
     def _add_time_partition_to_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
         res = interactions.sort(self.timestamp_column).with_columns(
-            pl.col(self.divide_column).cumcount().over(pl.col(self.divide_column))
-            .alias("row_num")
+            pl.col(self.divide_column).cumcount().over(pl.col(self.divide_column)).alias("row_num")
         )
 
         return res
@@ -205,7 +206,8 @@ class LastNSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._to_unix_timestamp_polars(interactions)
 
-
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
 
     def _to_unix_timestamp_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
         time_column_type = dict(interactions.dtypes)[self.timestamp_column]
@@ -221,7 +223,7 @@ class LastNSplitter(Splitter):
         time_column_type = dict(interactions.dtypes)[self.timestamp_column]
         if time_column_type == "date":
             interactions = interactions.withColumn(
-                self.timestamp_column, sf.unix_timestamp(self.timestamp_column, self.
+                self.timestamp_column, sf.unix_timestamp(self.timestamp_column, self.time_column_format)
             )
 
         return interactions
@@ -233,20 +235,19 @@ class LastNSplitter(Splitter):
 
         return interactions
 
-
-    def _partial_split_interactions(self, interactions: DataFrameLike, N: int) -> Tuple[DataFrameLike, DataFrameLike]:
+    def _partial_split_interactions(self, interactions: DataFrameLike, n: int) -> Tuple[DataFrameLike, DataFrameLike]:
         res = self._add_time_partition(interactions)
         if isinstance(interactions, SparkDataFrame):
-            return self._partial_split_interactions_spark(res, N)
+            return self._partial_split_interactions_spark(res, n)
         if isinstance(interactions, PandasDataFrame):
-            return self._partial_split_interactions_pandas(res, N)
-        return self._partial_split_interactions_polars(res, N)
+            return self._partial_split_interactions_pandas(res, n)
+        return self._partial_split_interactions_polars(res, n)
 
     def _partial_split_interactions_pandas(
-        self, interactions: PandasDataFrame, N: int
+        self, interactions: PandasDataFrame, n: int
    ) -> Tuple[PandasDataFrame, PandasDataFrame]:
         interactions["count"] = interactions.groupby(self.divide_column, sort=False)[self.divide_column].transform(len)
-        interactions["is_test"] = interactions["row_num"] > (interactions["count"] - float(N))
+        interactions["is_test"] = interactions["row_num"] > (interactions["count"] - float(n))
         if self.session_id_column:
             interactions = self._recalculate_with_session_id_column(interactions)
 
@@ -256,14 +257,14 @@ class LastNSplitter(Splitter):
         return train, test
 
     def _partial_split_interactions_spark(
-        self, interactions: SparkDataFrame, N: int
+        self, interactions: SparkDataFrame, n: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         interactions = interactions.withColumn(
             "count", sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column))
         )
         # float(n) - because DataFrame.filter is changing order
         # of sorted DataFrame to descending
-        interactions = interactions.withColumn("is_test", sf.col("row_num") > sf.col("count") - sf.lit(float(N)))
+        interactions = interactions.withColumn("is_test", sf.col("row_num") > sf.col("count") - sf.lit(float(n)))
         if self.session_id_column:
             interactions = self._recalculate_with_session_id_column(interactions)
 
@@ -273,27 +274,22 @@ class LastNSplitter(Splitter):
         return train, test
 
     def _partial_split_interactions_polars(
-        self, interactions: PolarsDataFrame, N: int
+        self, interactions: PolarsDataFrame, n: int
     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         interactions = interactions.with_columns(
-            pl.col(self.timestamp_column).count().over(self.divide_column)
-            .alias("count")
-        )
-        interactions = interactions.with_columns(
-            (pl.col("row_num") > (pl.col("count") - N))
-            .alias("is_test")
+            pl.col(self.timestamp_column).count().over(self.divide_column).alias("count")
         )
+        interactions = interactions.with_columns((pl.col("row_num") > (pl.col("count") - n)).alias("is_test"))
         if self.session_id_column:
             interactions = self._recalculate_with_session_id_column(interactions)
 
-        train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "is_test")
+        train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "is_test")
         test = interactions.filter(pl.col("is_test")).drop("row_num", "count", "is_test")
 
         return train, test
 
     def _partial_split_timedelta(
-        self,
-        interactions: DataFrameLike, timedelta: int
+        self, interactions: DataFrameLike, timedelta: int
     ) -> Tuple[DataFrameLike, DataFrameLike]:
         if isinstance(interactions, SparkDataFrame):
             return self._partial_split_timedelta_spark(interactions, timedelta)
@@ -341,22 +337,16 @@ class LastNSplitter(Splitter):
     def _partial_split_timedelta_polars(
         self, interactions: PolarsDataFrame, timedelta: int
     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
-        res = (
-
-
-                (pl.col(self.timestamp_column).max().over(self.divide_column) - pl.col(self.timestamp_column))
-                .alias("diff_timestamp")
+        res = interactions.with_columns(
+            (pl.col(self.timestamp_column).max().over(self.divide_column) - pl.col(self.timestamp_column)).alias(
+                "diff_timestamp"
             )
-
-                (pl.col("diff_timestamp") < timedelta)
-                .alias("is_test")
-            )
-        )
+        ).with_columns((pl.col("diff_timestamp") < timedelta).alias("is_test"))
 
         if self.session_id_column:
             res = self._recalculate_with_session_id_column(res)
 
-        train = res.filter(~pl.col("is_test")).drop("diff_timestamp", "is_test")
+        train = res.filter(~pl.col("is_test")).drop("diff_timestamp", "is_test")
         test = res.filter(pl.col("is_test")).drop("diff_timestamp", "is_test")
 
         return train, test
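A minimal sketch of the two strategies LastNSplitter exposes after this change. The import path and the split() call are assumptions, and interpreting the timedelta value as seconds is inferred from the unix-timestamp comparison in the code above; the toy interactions frame from the first sketch is reused.

from replay.splitters import LastNSplitter  # import path assumed

# Hold out the last 2 interactions of every user (divide_column defaults to "query_id").
splitter = LastNSplitter(N=2, strategy="interactions")
train, test = splitter.split(interactions)

# Or hold out everything within the last N seconds of each user's history.
timedelta_splitter = LastNSplitter(N=86400, strategy="timedelta")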
replay/splitters/new_users_splitter.py
CHANGED
@@ -1,15 +1,16 @@
-from typing import Optional, Union
+from typing import Optional, Tuple
+
 import polars as pl
 
-from .base_splitter import Splitter, SplitterReturnType
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+from .base_splitter import Splitter, SplitterReturnType
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql import Window
 
 
-# pylint: disable=too-few-public-methods, duplicate-code
 class NewUsersSplitter(Splitter):
     """
     Only new users will be assigned to test set.
@@ -63,7 +64,6 @@ class NewUsersSplitter(Splitter):
         "session_id_processing_strategy",
     ]
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         test_size: float,
@@ -91,24 +91,23 @@ class NewUsersSplitter(Splitter):
             item_column=item_column,
             timestamp_column=timestamp_column,
             session_id_column=session_id_column,
-            session_id_processing_strategy=session_id_processing_strategy
+            session_id_processing_strategy=session_id_processing_strategy,
         )
         if test_size < 0 or test_size > 1:
-            raise ValueError("test_size must between 0 and 1")
+            msg = "test_size must between 0 and 1"
+            raise ValueError(msg)
         self.test_size = test_size
 
     def _core_split_pandas(
-        self,
-        interactions: PandasDataFrame,
-        threshold: float
-    ) -> Union[PandasDataFrame, PandasDataFrame]:
-        start_date_by_user = interactions.groupby(self.query_column).agg(
-            _start_dt_by_user=(self.timestamp_column, "min")
-        ).reset_index()
+        self, interactions: PandasDataFrame, threshold: float
+    ) -> Tuple[PandasDataFrame, PandasDataFrame]:
+        start_date_by_user = (
+            interactions.groupby(self.query_column).agg(_start_dt_by_user=(self.timestamp_column, "min")).reset_index()
+        )
         test_start_date = (
-            start_date_by_user
-            .
-            .
+            start_date_by_user.groupby("_start_dt_by_user")
+            .agg(_num_users_by_start_date=(self.query_column, "count"))
+            .reset_index()
             .sort_values(by="_start_dt_by_user", ascending=False)
         )
         test_start_date["_cum_num_users_to_dt"] = test_start_date["_num_users_by_start_date"].cumsum()
@@ -120,9 +119,7 @@ class NewUsersSplitter(Splitter):
 
         train = interactions[interactions[self.timestamp_column] < test_start]
         test = interactions.merge(
-            start_date_by_user[start_date_by_user["_start_dt_by_user"] >= test_start],
-            how="inner",
-            on=self.query_column
+            start_date_by_user[start_date_by_user["_start_dt_by_user"] >= test_start], how="inner", on=self.query_column
         ).drop(columns=["_start_dt_by_user"])
 
         if self.session_id_column:
@@ -136,10 +133,8 @@ class NewUsersSplitter(Splitter):
         return train, test
 
     def _core_split_spark(
-        self,
-        interactions: SparkDataFrame,
-        threshold: float
-    ) -> Union[SparkDataFrame, SparkDataFrame]:
+        self, interactions: SparkDataFrame, threshold: float
+    ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         start_date_by_user = interactions.groupby(self.query_column).agg(
             sf.min(self.timestamp_column).alias("_start_dt_by_user")
         )
@@ -175,53 +170,33 @@ class NewUsersSplitter(Splitter):
         return train, test
 
     def _core_split_polars(
-        self,
-        interactions: PolarsDataFrame,
-        threshold: float
-    ) -> Union[PolarsDataFrame, PolarsDataFrame]:
-        start_date_by_user = (
-            interactions
-            .group_by(self.query_column).agg(
-                pl.col(self.timestamp_column).min()
-                .alias("_start_dt_by_user")
-            )
+        self, interactions: PolarsDataFrame, threshold: float
+    ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
+        start_date_by_user = interactions.group_by(self.query_column).agg(
+            pl.col(self.timestamp_column).min().alias("_start_dt_by_user")
         )
         test_start_date = (
-            start_date_by_user
-            .
-                pl.col(self.query_column).count()
-                .alias("_num_users_by_start_date")
-            )
+            start_date_by_user.group_by("_start_dt_by_user")
+            .agg(pl.col(self.query_column).count().alias("_num_users_by_start_date"))
             .sort("_start_dt_by_user", descending=True)
             .with_columns(
-                pl.col("_num_users_by_start_date").cum_sum()
-                .alias("cum_sum_users"),
+                pl.col("_num_users_by_start_date").cum_sum().alias("cum_sum_users"),
             )
-            .filter(
-                pl.col("cum_sum_users") >= pl.col("cum_sum_users").max() * threshold
-            )
-            ["_start_dt_by_user"]
+            .filter(pl.col("cum_sum_users") >= pl.col("cum_sum_users").max() * threshold)["_start_dt_by_user"]
            .max()
         )
 
         train = interactions.filter(pl.col(self.timestamp_column) < test_start_date)
         test = interactions.join(
-            start_date_by_user.filter(pl.col("_start_dt_by_user") >= test_start_date),
-            on=self.query_column,
-            how="inner"
+            start_date_by_user.filter(pl.col("_start_dt_by_user") >= test_start_date), on=self.query_column, how="inner"
        ).drop("_start_dt_by_user")
 
         if self.session_id_column:
             interactions = interactions.with_columns(
-                pl.when(
-                    pl.col(self.timestamp_column) < test_start_date
-                )
-                .then(False)
-                .otherwise(True)
-                .alias("is_test")
+                pl.when(pl.col(self.timestamp_column) < test_start_date).then(False).otherwise(True).alias("is_test")
             )
             interactions = self._recalculate_with_session_id_column(interactions)
-        train = interactions.filter(~pl.col("is_test")).drop("is_test")
+        train = interactions.filter(~pl.col("is_test")).drop("is_test")
         test = interactions.filter(pl.col("is_test")).drop("is_test")
 
         return train, test
@@ -234,4 +209,5 @@ class NewUsersSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._core_split_polars(interactions, self.test_size)
 
-
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
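A small sketch for NewUsersSplitter, again assuming the public split() entry point and reusing the toy interactions frame from the first sketch.

from replay.splitters import NewUsersSplitter  # import path assumed

# Roughly the newest test_size share of users (by first-interaction date) form the test set;
# a test_size outside [0, 1] raises ValueError.
splitter = NewUsersSplitter(test_size=0.2)
train, test = splitter.split(interactions)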
replay/splitters/random_splitter.py
CHANGED
@@ -1,10 +1,10 @@
-from typing import Optional, Union
+from typing import Optional, Tuple
 
-from .base_splitter import Splitter, SplitterReturnType
 from replay.utils import DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+from .base_splitter import Splitter, SplitterReturnType
+
 
-# pylint: disable=too-few-public-methods, duplicate-code
 class RandomSplitter(Splitter):
     """Assign records into train and test at random."""
 
@@ -17,7 +17,6 @@ class RandomSplitter(Splitter):
         "item_column",
     ]
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         test_size: float,
@@ -25,7 +24,7 @@ class RandomSplitter(Splitter):
         drop_cold_users: bool = False,
         seed: Optional[int] = None,
         query_column: str = "query_id",
-        item_column: str = "item_id"
+        item_column: str = "item_id",
     ):
         """
         :param test_size: test size 0 to 1
@@ -39,37 +38,30 @@ class RandomSplitter(Splitter):
             drop_cold_items=drop_cold_items,
             drop_cold_users=drop_cold_users,
             query_column=query_column,
-            item_column=item_column
+            item_column=item_column,
        )
         self.seed = seed
         if test_size < 0 or test_size > 1:
-            raise ValueError("test_size must between 0 and 1")
+            msg = "test_size must between 0 and 1"
+            raise ValueError(msg)
         self.test_size = test_size
 
     def _random_split_spark(
-        self,
-        interactions: SparkDataFrame,
-        threshold
-    ) -> Union[SparkDataFrame, SparkDataFrame]:
-        train, test = interactions.randomSplit(
-            [1 - threshold, threshold], self.seed
-        )
+        self, interactions: SparkDataFrame, threshold: float
+    ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+        train, test = interactions.randomSplit([1 - threshold, threshold], self.seed)
         return train, test
 
     def _random_split_pandas(
-        self,
-        interactions: PandasDataFrame,
-        threshold: float
-    ) -> Union[PandasDataFrame, PandasDataFrame]:
+        self, interactions: PandasDataFrame, threshold: float
+    ) -> Tuple[PandasDataFrame, PandasDataFrame]:
         train = interactions.sample(frac=(1 - threshold), random_state=self.seed)
         test = interactions.drop(train.index)
         return train, test
 
     def _random_split_polars(
-        self,
-        interactions: PolarsDataFrame,
-        threshold: float
-    ) -> Union[PolarsDataFrame, PolarsDataFrame]:
+        self, interactions: PolarsDataFrame, threshold: float
+    ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         train_size = int(len(interactions) * (1 - threshold)) + 1
         shuffled_interactions = interactions.sample(fraction=1, shuffle=True, seed=self.seed)
         train = shuffled_interactions[:train_size]
@@ -84,4 +76,5 @@ class RandomSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._random_split_polars(interactions, self.test_size)
 
-
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
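Finally, a sketch for RandomSplitter with the defaults visible in the diff (query_column="query_id", item_column="item_id"); the import path and the split() call are assumptions, and the toy interactions frame from the first sketch is reused.

from replay.splitters import RandomSplitter  # import path assumed

# Rows are assigned to train/test at random, independently of the user they belong to.
splitter = RandomSplitter(test_size=0.25, seed=42)
train, test = splitter.split(interactions)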