replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/splitters/ratio_splitter.py CHANGED

@@ -1,16 +1,17 @@
-import polars as pl
 from typing import List, Optional, Tuple
 
-from .base_splitter import Splitter
+import polars as pl
+
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+from .base_splitter import Splitter
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql import Window
     from pyspark.sql.types import IntegerType
 
 
-# pylint: disable=too-few-public-methods, too-many-instance-attributes
 class RatioSplitter(Splitter):
     """
     Split interactions into train and test by ratio. Split is made for each user separately.
@@ -82,6 +83,7 @@ class RatioSplitter(Splitter):
     14 3 2 2020-01-05
     <BLANKLINE>
     """
+
     _init_arg_names = [
         "test_size",
         "divide_column",
@@ -96,7 +98,6 @@ class RatioSplitter(Splitter):
         "session_id_processing_strategy",
     ]
 
-    # pylint: disable=too-many-arguments
    def __init__(
         self,
         test_size: float,
@@ -160,7 +161,8 @@ class RatioSplitter(Splitter):
         self.min_interactions_per_group = min_interactions_per_group
         self.split_by_fractions = split_by_fractions
         if test_size < 0 or test_size > 1:
-            raise ValueError("test_size must between 0 and 1")
+            msg = "test_size must between 0 and 1"
+            raise ValueError(msg)
         self.test_size = test_size
 
     def _add_time_partition(self, interactions: DataFrameLike) -> DataFrameLike:
@@ -171,7 +173,8 @@ class RatioSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._add_time_partition_to_polars(interactions)
 
-        raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
 
     def _add_time_partition_to_pandas(self, interactions: PandasDataFrame) -> PandasDataFrame:
         res = interactions.copy(deep=True)
@@ -189,14 +192,8 @@ class RatioSplitter(Splitter):
         return res
 
     def _add_time_partition_to_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
-        res = (
-            interactions
-            .sort(self.timestamp_column)
-            .with_columns(
-                pl.cum_count(self.divide_column)
-                .over(self.divide_column)
-                .alias("row_num")
-            )
+        res = interactions.sort(self.timestamp_column).with_columns(
+            pl.cum_count(self.divide_column).over(self.divide_column).alias("row_num")
         )
 
         return res
@@ -262,8 +259,7 @@ class RatioSplitter(Splitter):
         self, interactions: PolarsDataFrame, train_size: float
     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         interactions = interactions.with_columns(
-            pl.count(self.timestamp_column).over(pl.col(self.divide_column))
-            .alias("count")
+            pl.count(self.timestamp_column).over(pl.col(self.divide_column)).alias("count")
         )
         if self.min_interactions_per_group is not None:
             interactions = interactions.with_columns(
@@ -274,18 +270,14 @@ class RatioSplitter(Splitter):
             )
         else:
             interactions = interactions.with_columns(
-                (pl.col("row_num") / pl.col("count")).round(self._precision)
-                .alias("frac")
+                (pl.col("row_num") / pl.col("count")).round(self._precision).alias("frac")
             )
 
-        interactions = interactions.with_columns(
-            (pl.col("frac") > train_size)
-            .alias("is_test")
-        )
+        interactions = interactions.with_columns((pl.col("frac") > train_size).alias("is_test"))
         if self.session_id_column:
             interactions = self._recalculate_with_session_id_column(interactions)
 
-        train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "frac", "is_test")
+        train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "frac", "is_test")
         test = interactions.filter(pl.col("is_test")).drop("row_num", "count", "frac", "is_test")
 
         return train, test
@@ -316,7 +308,7 @@ class RatioSplitter(Splitter):
             "train_size",
         ] = (
             interactions["train_size"] - 1
-        )
+        )
 
         interactions["is_test"] = interactions["row_num"] > interactions["train_size"]
         if self.session_id_column:
@@ -327,9 +319,7 @@ class RatioSplitter(Splitter):
 
         return train, test
 
-    def _partial_split_spark(
-        self, interactions: SparkDataFrame, ratio: float
-    ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+    def _partial_split_spark(self, interactions: SparkDataFrame, ratio: float) -> Tuple[SparkDataFrame, SparkDataFrame]:
         interactions = interactions.withColumn(
             "count", sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column))
         )
@@ -364,51 +354,37 @@ class RatioSplitter(Splitter):
         self, interactions: PolarsDataFrame, ratio: float
     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         interactions = interactions.with_columns(
-            pl.count(self.timestamp_column).over(self.divide_column)
-            .alias("count")
+            pl.count(self.timestamp_column).over(self.divide_column).alias("count")
         )
         if self.min_interactions_per_group is not None:
             interactions = interactions.with_columns(
-                pl.when(
-                    pl.col("count") >= self.min_interactions_per_group
-                )
-                .then(
-                    pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype)
-                )
+                pl.when(pl.col("count") >= self.min_interactions_per_group)
+                .then(pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype))
                 .otherwise(pl.col("count"))
                 .alias("train_size")
             )
         else:
-            interactions = (
-                interactions
-                .with_columns(
-                    (pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype))
-                    .alias("train_size")
-                )
-                .with_columns(
-                    pl.when(
-                        (pl.col("count") * ratio > 0) & (pl.col("count") * ratio < 1) & (pl.col("train_size") > 1)
-                    )
-                    .then(pl.col("train_size") - 1)
-                    .otherwise(pl.col("train_size"))
-                    .alias("train_size")
+            interactions = interactions.with_columns(
+                (pl.col("count") - (pl.col("count") * ratio).cast(interactions.get_column("count").dtype)).alias(
+                    "train_size"
                 )
+            ).with_columns(
+                pl.when((pl.col("count") * ratio > 0) & (pl.col("count") * ratio < 1) & (pl.col("train_size") > 1))
+                .then(pl.col("train_size") - 1)
+                .otherwise(pl.col("train_size"))
+                .alias("train_size")
             )
 
-        interactions = interactions.with_columns(
-            (pl.col("row_num") > pl.col("train_size"))
-            .alias("is_test")
-        )
+        interactions = interactions.with_columns((pl.col("row_num") > pl.col("train_size")).alias("is_test"))
 
         if self.session_id_column:
             interactions = self._recalculate_with_session_id_column(interactions)
 
-        train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "train_size", "is_test")
+        train = interactions.filter(~pl.col("is_test")).drop("row_num", "count", "train_size", "is_test")
         test = interactions.filter(pl.col("is_test")).drop("row_num", "count", "train_size", "is_test")
 
         return train, test
 
-    # pylint: disable=invalid-name
     def _core_split(self, interactions: DataFrameLike) -> List[DataFrameLike]:
         if self.split_by_fractions:
             return self._partial_split_fractions(interactions, self.test_size)
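The RatioSplitter changes above are almost entirely mechanical: `pylint` pragmas are dropped, long Polars expression chains are collapsed onto fewer lines, and bare `raise ValueError(...)` calls become the `msg = ...; raise ValueError(msg)` pattern, so behavior should be unchanged. For orientation, a minimal usage sketch follows; it assumes the splitters' usual `split()` entry point and RePlay's default column names, and the toy data is illustrative only, not verified against this exact version:

    import pandas as pd

    from replay.splitters import RatioSplitter

    # Toy interaction log using RePlay's default column names.
    interactions = pd.DataFrame(
        {
            "query_id": [1, 1, 1, 1, 2, 2],
            "item_id": [10, 11, 12, 13, 10, 14],
            "timestamp": pd.to_datetime(
                ["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04", "2020-01-01", "2020-01-02"]
            ),
        }
    )

    # Hold out roughly the last 25% of each user's history;
    # test_size outside [0, 1] raises ValueError, as the diff shows.
    splitter = RatioSplitter(test_size=0.25)
    train, test = splitter.split(interactions)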
replay/splitters/time_splitter.py CHANGED

@@ -3,21 +3,21 @@ from typing import List, Optional, Tuple, Union
 
 import polars as pl
 
-from .base_splitter import Splitter
 from replay.utils import (
     PYSPARK_AVAILABLE,
     DataFrameLike,
     PandasDataFrame,
-    SparkDataFrame,
     PolarsDataFrame,
+    SparkDataFrame,
 )
 
+from .base_splitter import Splitter
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql import Window
 
 
-# pylint: disable=too-few-public-methods
 class TimeSplitter(Splitter):
     """
     Split interactions by time.
@@ -85,6 +85,7 @@ class TimeSplitter(Splitter):
     14 3 2 2020-01-05
     <BLANKLINE>
     """
+
     _init_arg_names = [
         "time_threshold",
         "drop_cold_users",
@@ -97,10 +98,9 @@ class TimeSplitter(Splitter):
         "time_column_format",
     ]
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
-        time_threshold: Union[datetime, str,
+        time_threshold: Union[datetime, str, float],
         query_column: str = "query_id",
         drop_cold_users: bool = False,
         drop_cold_items: bool = False,
@@ -144,7 +144,8 @@ class TimeSplitter(Splitter):
         self._precision = 3
         self.time_column_format = time_column_format
         if isinstance(time_threshold, float) and (time_threshold < 0 or time_threshold > 1):
-            raise ValueError("time_threshold must be between 0 and 1")
+            msg = "time_threshold must be between 0 and 1"
+            raise ValueError(msg)
         self.time_threshold = time_threshold
 
     def _partial_split(
@@ -160,7 +161,8 @@ class TimeSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._partial_split_polars(interactions, threshold)
 
-        raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
 
     def _partial_split_pandas(
         self, interactions: PandasDataFrame, threshold: Union[datetime, str, int]
@@ -191,9 +193,7 @@ class TimeSplitter(Splitter):
             )
             test_start = int(dates.count() * (1 - threshold)) + 1
             test_start = (
-                dates.filter(sf.col("_row_number_by_ts") == test_start)
-                .select(self.timestamp_column)
-                .collect()[0][0]
+                dates.filter(sf.col("_row_number_by_ts") == test_start).select(self.timestamp_column).collect()[0][0]
             )
             res = interactions.withColumn("is_test", sf.col(self.timestamp_column) >= test_start)
         else:
@@ -212,20 +212,15 @@ class TimeSplitter(Splitter):
         if isinstance(threshold, float):
             test_start = int(len(interactions) * (1 - threshold)) + 1
 
-            res = (
-                interactions
-                .sort(self.timestamp_column)
-                .with_columns(
-                    (pl.col(self.timestamp_column).cum_count() >= test_start)
-                    .alias("is_test")
-                )
+            res = interactions.sort(self.timestamp_column).with_columns(
+                (pl.col(self.timestamp_column).cum_count() >= test_start).alias("is_test")
             )
         else:
             res = interactions.with_columns((pl.col(self.timestamp_column) >= threshold).alias("is_test"))
 
         if self.session_id_column:
             res = self._recalculate_with_session_id_column(res)
-        train = res.filter(~pl.col("is_test")).drop("is_test")
+        train = res.filter(~pl.col("is_test")).drop("is_test")
         test = res.filter("is_test").drop("is_test")
 
         return train, test
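TimeSplitter gets the same cleanup, plus one visible signature change: `time_threshold` is now annotated `Union[datetime, str, float]`, where a float is validated to lie in [0, 1] and interpreted as the share of interactions sent to test, while a datetime acts as a fixed cut-off (rows with `timestamp >= threshold` go to test). A hedged sketch of both modes, again assuming the `split()` entry point and default column names:

    from datetime import datetime

    import pandas as pd

    from replay.splitters import TimeSplitter

    interactions = pd.DataFrame(
        {
            "query_id": [1, 1, 1, 2, 2],
            "item_id": [10, 11, 12, 10, 13],
            "timestamp": pd.to_datetime(
                ["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-01", "2020-01-02"]
            ),
        }
    )

    # Fixed cut-off: everything on or after Jan 3 goes to test.
    train, test = TimeSplitter(time_threshold=datetime(2020, 1, 3)).split(interactions)

    # Fractional cut-off: roughly the last 20% of rows by timestamp go to test.
    # Floats outside [0, 1] raise ValueError ("time_threshold must be between 0 and 1").
    train, test = TimeSplitter(time_threshold=0.2).split(interactions)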
replay/splitters/two_stage_splitter.py CHANGED

@@ -1,18 +1,19 @@
 """
 This splitter split data by two columns.
 """
-from typing import Optional,
+from typing import Optional, Tuple
+
 import polars as pl
 
-from .base_splitter import Splitter, SplitterReturnType
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+from .base_splitter import Splitter, SplitterReturnType
+
 if PYSPARK_AVAILABLE:
     import pyspark.sql.functions as sf
     from pyspark.sql import Window
 
 
-# pylint: disable=too-few-public-methods
 class TwoStageSplitter(Splitter):
     """
     Split data by two columns.
@@ -73,11 +74,10 @@ class TwoStageSplitter(Splitter):
         "timestamp_column",
     ]
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
-        first_divide_size:
-        second_divide_size:
+        first_divide_size: float,
+        second_divide_size: float,
         first_divide_column: str = "query_id",
         second_divide_column: str = "item_id",
         shuffle=False,
@@ -147,17 +147,12 @@ class TwoStageSplitter(Splitter):
         else:
             value_error = True
         if value_error:
-            raise ValueError(
-                f"""
-                Invalid value for user_test_size: {self.first_divide_size}
-                """
-            )
+            msg = f"Invalid value for user_test_size: {self.first_divide_size}"
+            raise ValueError(msg)
         if isinstance(interactions, SparkDataFrame):
             test_users = (
                 all_values.withColumn("_rand", sf.rand(self.seed))
-                .withColumn(
-                    "_row_num", sf.row_number().over(Window.orderBy("_rand"))
-                )
+                .withColumn("_row_num", sf.row_number().over(Window.orderBy("_rand")))
                 .filter(f"_row_num <= {test_user_count}")
                 .drop("_rand", "_row_num")
             )
@@ -168,11 +163,9 @@ class TwoStageSplitter(Splitter):
 
         return test_users
 
-    def _split_proportion_spark(self, interactions: SparkDataFrame) ->
+    def _split_proportion_spark(self, interactions: SparkDataFrame) -> Tuple[SparkDataFrame, SparkDataFrame]:
         counts = interactions.groupBy(self.first_divide_column).count()
-        test_users = self._get_test_values(interactions).withColumn(
-            "is_test", sf.lit(True)
-        )
+        test_users = self._get_test_values(interactions).withColumn("is_test", sf.lit(True))
         if self.shuffle:
             res = self._add_random_partition_spark(
                 interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -202,10 +195,10 @@ class TwoStageSplitter(Splitter):
 
         return train, test
 
-    def _split_proportion_pandas(self, interactions: PandasDataFrame) ->
-        counts = interactions.groupby(self.first_divide_column).agg(
-            count=(self.first_divide_column, "count")
-        )
+    def _split_proportion_pandas(self, interactions: PandasDataFrame) -> Tuple[PandasDataFrame, PandasDataFrame]:
+        counts = (
+            interactions.groupby(self.first_divide_column).agg(count=(self.first_divide_column, "count")).reset_index()
+        )
         test_users = self._get_test_values(interactions)
         test_users["is_test"] = True
         if self.shuffle:
@@ -229,11 +222,9 @@ class TwoStageSplitter(Splitter):
 
         return train, test
 
-    def _split_proportion_polars(self, interactions: PolarsDataFrame) ->
+    def _split_proportion_polars(self, interactions: PolarsDataFrame) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
         counts = interactions.group_by(self.first_divide_column).count()
-        test_users = self._get_test_values(interactions).with_columns(
-            pl.lit(True).alias("is_test")
-        )
+        test_users = self._get_test_values(interactions).with_columns(pl.lit(True).alias("is_test"))
         if self.shuffle:
             res = self._add_random_partition_polars(
                 interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -245,18 +236,15 @@ class TwoStageSplitter(Splitter):
             )
 
         res = res.join(counts, on=self.first_divide_column, how="left")
-        res = res.with_columns(
-            (pl.col("_row_num") / pl.col("count"))
-            .alias("_frac")
-        )
+        res = res.with_columns((pl.col("_row_num") / pl.col("count")).alias("_frac"))
         res = res.fill_null(False)
 
-        train = res.filter(
-            (pl.col("_frac") > self.second_divide_size) | (~pl.col("is_test"))
-        ).drop("_rand", "_row_num", "count", "_frac", "is_test")
-        test = res.filter(
-            (pl.col("_frac") <= self.second_divide_size) & pl.col("is_test")
-        ).drop("_rand", "_row_num", "count", "_frac", "is_test")
+        train = res.filter((pl.col("_frac") > self.second_divide_size) | (~pl.col("is_test"))).drop(
+            "_rand", "_row_num", "count", "_frac", "is_test"
+        )
+        test = res.filter((pl.col("_frac") <= self.second_divide_size) & pl.col("is_test")).drop(
+            "_rand", "_row_num", "count", "_frac", "is_test"
+        )
 
         return train, test
@@ -274,12 +262,11 @@ class TwoStageSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._split_proportion_polars(interactions)
 
-        raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
 
     def _split_quantity_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
-        test_users = self._get_test_values(interactions).withColumn(
-            "is_test", sf.lit(True)
-        )
+        test_users = self._get_test_values(interactions).withColumn("is_test", sf.lit(True))
         if self.shuffle:
             res = self._add_random_partition_spark(
                 interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -328,9 +315,7 @@ class TwoStageSplitter(Splitter):
         return train, test
 
     def _split_quantity_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
-        test_users = self._get_test_values(interactions).with_columns(
-            pl.lit(True).alias("is_test")
-        )
+        test_users = self._get_test_values(interactions).with_columns(pl.lit(True).alias("is_test"))
         if self.shuffle:
             res = self._add_random_partition_polars(
                 interactions.join(test_users, how="left", on=self.first_divide_column)
@@ -342,12 +327,12 @@ class TwoStageSplitter(Splitter):
             )
 
         res = res.fill_null(False)
-        train = res.filter(
-            (pl.col("_row_num") > self.second_divide_size) | (~pl.col("is_test"))
-        ).drop("_row_num", "is_test")
-        test = res.filter(
-            (pl.col("_row_num") <= self.second_divide_size) & pl.col("is_test")
-        ).drop("_row_num", "is_test")
+        train = res.filter((pl.col("_row_num") > self.second_divide_size) | (~pl.col("is_test"))).drop(
+            "_row_num", "is_test"
+        )
+        test = res.filter((pl.col("_row_num") <= self.second_divide_size) & pl.col("is_test")).drop(
+            "_row_num", "is_test"
+        )
 
         return train, test
@@ -365,7 +350,8 @@ class TwoStageSplitter(Splitter):
         if isinstance(interactions, PolarsDataFrame):
             return self._split_quantity_polars(interactions)
 
-        raise NotImplementedError(f"{self} is not implemented for {type(interactions)}")
+        msg = f"{self} is not implemented for {type(interactions)}"
+        raise NotImplementedError(msg)
 
     def _core_split(self, interactions: DataFrameLike) -> SplitterReturnType:
         if 0 <= self.second_divide_size < 1.0:
@@ -373,11 +359,8 @@ class TwoStageSplitter(Splitter):
         elif self.second_divide_size >= 1 and isinstance(self.second_divide_size, int):
             train, test = self._split_quantity(interactions)
         else:
-            raise ValueError(
-                "`test_size` value must be [0, 1) or "
-                "a positive integer; "
-                f"test_size={self.second_divide_size}"
-            )
+            msg = f"`test_size` value must be [0, 1) or a positive integer; test_size={self.second_divide_size}"
+            raise ValueError(msg)
 
         return train, test
@@ -391,9 +374,7 @@ class TwoStageSplitter(Splitter):
         dataframe = dataframe.withColumn("_rand", sf.rand(self.seed))
         dataframe = dataframe.withColumn(
             "_row_num",
-            sf.row_number().over(
-                Window.partitionBy(self.first_divide_column).orderBy("_rand")
-            ),
+            sf.row_number().over(Window.partitionBy(self.first_divide_column).orderBy("_rand")),
         )
         return dataframe
@@ -404,14 +385,8 @@ class TwoStageSplitter(Splitter):
         return res
 
     def _add_random_partition_polars(self, dataframe: PolarsDataFrame) -> PolarsDataFrame:
-        res = (
-            dataframe
-            .sample(fraction=1, shuffle=True, seed=self.seed)
-            .with_columns(
-                pl.cum_count(self.first_divide_column)
-                .over(self.first_divide_column)
-                .alias("_row_num")
-            )
+        res = dataframe.sample(fraction=1, shuffle=True, seed=self.seed).with_columns(
+            pl.cum_count(self.first_divide_column).over(self.first_divide_column).alias("_row_num")
         )
         return res
@@ -431,11 +406,7 @@ class TwoStageSplitter(Splitter):
         """
         res = dataframe.withColumn(
             "_row_num",
-            sf.row_number().over(
-                Window.partitionBy(query_column).orderBy(
-                    sf.col(date_column).desc()
-                )
-            ),
+            sf.row_number().over(Window.partitionBy(query_column).orderBy(sf.col(date_column).desc())),
         )
         return res
@@ -456,13 +427,7 @@ class TwoStageSplitter(Splitter):
         query_column: str = "query_id",
         date_column: str = "timestamp",
     ) -> PolarsDataFrame:
-        res = (
-            dataframe
-            .sort(date_column, descending=True)
-            .with_columns(
-                pl.cum_count(query_column)
-                .over(query_column)
-                .alias("_row_num")
-            )
+        res = dataframe.sort(date_column, descending=True).with_columns(
+            pl.cum_count(query_column).over(query_column).alias("_row_num")
        )
         return res
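TwoStageSplitter follows the same pattern of reflows and extracted-message raises. Per `_core_split` above, `second_divide_size` in [0, 1) is treated as a per-user fraction while an integer >= 1 is an absolute held-out item count; anything else raises ValueError. A minimal sketch under the same assumptions as the earlier examples (the `interactions` frame and parameter values are illustrative):

    import pandas as pd

    from replay.splitters import TwoStageSplitter

    interactions = pd.DataFrame(
        {
            "query_id": [1, 1, 1, 2, 2, 3, 3],
            "item_id": [10, 11, 12, 10, 13, 11, 14],
            "timestamp": pd.date_range("2020-01-01", periods=7, freq="D"),
        }
    )

    # Stage 1: put 20% of users into the test group (first_divide_size);
    # stage 2: hold out 2 items per test user (integer second_divide_size).
    splitter = TwoStageSplitter(first_divide_size=0.2, second_divide_size=2, seed=7)
    train, test = splitter.split(interactions)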
replay/utils/__init__.py CHANGED
replay/utils/common.py ADDED
@@ -0,0 +1,65 @@
+import json
+from pathlib import Path
+from typing import Union
+
+from replay.splitters import (
+    ColdUserRandomSplitter,
+    KFolds,
+    LastNSplitter,
+    NewUsersSplitter,
+    RandomSplitter,
+    RatioSplitter,
+    TimeSplitter,
+    TwoStageSplitter,
+)
+from replay.utils import TORCH_AVAILABLE
+
+SavableObject = Union[
+    ColdUserRandomSplitter,
+    KFolds,
+    LastNSplitter,
+    NewUsersSplitter,
+    RandomSplitter,
+    RatioSplitter,
+    TimeSplitter,
+    TwoStageSplitter,
+]
+
+if TORCH_AVAILABLE:
+    from replay.data.nn import SequenceTokenizer
+
+    SavableObject = Union[
+        ColdUserRandomSplitter,
+        KFolds,
+        LastNSplitter,
+        NewUsersSplitter,
+        RandomSplitter,
+        RatioSplitter,
+        TimeSplitter,
+        TwoStageSplitter,
+        SequenceTokenizer,
+    ]
+
+
+def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
+    """
+    General function to save RePlay models, splitters and tokenizer.
+
+    :param path: Path to save the object.
+    """
+    obj.save(path)
+
+
+def load_from_replay(path: Union[str, Path]) -> SavableObject:
+    """
+    General function to load RePlay models, splitters and tokenizer.
+
+    :param path: Path to save the object.
+    """
+    path = Path(path).with_suffix(".replay").resolve()
+    with open(path / "init_args.json", "r") as file:
+        class_name = json.loads(file.read())["_class_name"]
+    obj_type = globals()[class_name]
+    obj = obj_type.load(path)
+
+    return obj
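`replay/utils/common.py` is the genuinely new module among the files shown here: `save_to_replay` simply delegates to the object's own `save`, while `load_from_replay` resolves `<path>.replay/init_args.json`, reads the stored `_class_name`, and dispatches to that class's `load`. A usage sketch, assuming the splitter's `save` produces the `.replay` directory layout that `load_from_replay` expects:

    from replay.splitters import TimeSplitter
    from replay.utils.common import load_from_replay, save_to_replay

    splitter = TimeSplitter(time_threshold=0.2)
    save_to_replay(splitter, "my_splitter")     # delegates to splitter.save("my_splitter")
    restored = load_from_replay("my_splitter")  # reads my_splitter.replay/init_args.json
    assert isinstance(restored, TimeSplitter)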