replay-rec 0.20.0-py3-none-any.whl → 0.20.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +10 -9
- replay/data/dataset_utils/dataset_label_encoder.py +5 -4
- replay/data/nn/schema.py +9 -18
- replay/data/nn/sequence_tokenizer.py +26 -18
- replay/data/nn/sequential_dataset.py +22 -18
- replay/data/nn/torch_sequential_dataset.py +17 -16
- replay/data/nn/utils.py +2 -1
- replay/data/schema.py +3 -12
- replay/metrics/base_metric.py +11 -10
- replay/metrics/categorical_diversity.py +8 -8
- replay/metrics/coverage.py +4 -4
- replay/metrics/experiment.py +3 -3
- replay/metrics/hitrate.py +1 -3
- replay/metrics/map.py +1 -3
- replay/metrics/mrr.py +1 -3
- replay/metrics/ndcg.py +1 -2
- replay/metrics/novelty.py +3 -3
- replay/metrics/offline_metrics.py +16 -16
- replay/metrics/precision.py +1 -3
- replay/metrics/recall.py +1 -3
- replay/metrics/rocauc.py +1 -3
- replay/metrics/surprisal.py +4 -4
- replay/metrics/torch_metrics_builder.py +13 -12
- replay/metrics/unexpectedness.py +2 -2
- replay/models/als.py +2 -2
- replay/models/association_rules.py +4 -3
- replay/models/base_neighbour_rec.py +3 -2
- replay/models/base_rec.py +11 -10
- replay/models/cat_pop_rec.py +2 -1
- replay/models/extensions/ann/ann_mixin.py +2 -1
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
- replay/models/lin_ucb.py +57 -11
- replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
- replay/models/nn/sequential/bert4rec/dataset.py +5 -18
- replay/models/nn/sequential/bert4rec/lightning.py +3 -3
- replay/models/nn/sequential/bert4rec/model.py +2 -2
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
- replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
- replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
- replay/models/nn/sequential/postprocessors/_base.py +2 -3
- replay/models/nn/sequential/postprocessors/postprocessors.py +11 -11
- replay/models/nn/sequential/sasrec/dataset.py +3 -16
- replay/models/nn/sequential/sasrec/lightning.py +3 -3
- replay/models/nn/sequential/sasrec/model.py +8 -8
- replay/models/slim.py +2 -2
- replay/models/ucb.py +2 -2
- replay/models/word2vec.py +3 -3
- replay/preprocessing/discretizer.py +8 -7
- replay/preprocessing/filters.py +4 -4
- replay/preprocessing/history_based_fp.py +6 -6
- replay/preprocessing/label_encoder.py +8 -7
- replay/scenarios/fallback.py +4 -3
- replay/splitters/base_splitter.py +3 -3
- replay/splitters/cold_user_random_splitter.py +4 -4
- replay/splitters/k_folds.py +4 -4
- replay/splitters/last_n_splitter.py +10 -10
- replay/splitters/new_users_splitter.py +4 -4
- replay/splitters/random_splitter.py +4 -4
- replay/splitters/ratio_splitter.py +10 -10
- replay/splitters/time_splitter.py +6 -6
- replay/splitters/two_stage_splitter.py +4 -4
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +1 -1
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +6 -5
- replay/utils/types.py +3 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/METADATA +7 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/RECORD +73 -74
- replay/utils/warnings.py +0 -26
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/NOTICE +0 -0
replay/models/nn/sequential/sasrec/dataset.py
CHANGED
@@ -10,7 +10,6 @@ from replay.data.nn import (
     TorchSequentialDataset,
     TorchSequentialValidationDataset,
 )
-from replay.utils import deprecation_warning
 
 
 class SasRecTrainingBatch(NamedTuple):
@@ -31,17 +30,13 @@ class SasRecTrainingDataset(TorchDataset):
     Dataset that generates samples to train SasRec-like model
     """
 
-    @deprecation_warning(
-        "`padding_value` parameter will be removed in future versions. "
-        "Instead, you should specify `padding_value` for each column in TensorSchema"
-    )
     def __init__(
         self,
         sequential: SequentialDataset,
         max_sequence_length: int,
         sequence_shift: int = 1,
         sliding_window_step: Optional[None] = None,
-        padding_value: int =
+        padding_value: Optional[int] = None,
         label_feature_name: Optional[str] = None,
     ) -> None:
         """
@@ -127,15 +122,11 @@ class SasRecPredictionDataset(TorchDataset):
     Dataset that generates samples to infer SasRec-like model
     """
 
-    @deprecation_warning(
-        "`padding_value` parameter will be removed in future versions. "
-        "Instead, you should specify `padding_value` for each column in TensorSchema"
-    )
     def __init__(
         self,
         sequential: SequentialDataset,
         max_sequence_length: int,
-        padding_value: int =
+        padding_value: Optional[int] = None,
     ) -> None:
         """
         :param sequential: Sequential dataset with data to make predictions at.
@@ -179,17 +170,13 @@ class SasRecValidationDataset(TorchDataset):
     Dataset that generates samples to infer and validate SasRec-like model
     """
 
-    @deprecation_warning(
-        "`padding_value` parameter will be removed in future versions. "
-        "Instead, you should specify `padding_value` for each column in TensorSchema"
-    )
     def __init__(
         self,
         sequential: SequentialDataset,
         ground_truth: SequentialDataset,
         train: SequentialDataset,
        max_sequence_length: int,
-        padding_value: int =
+        padding_value: Optional[int] = None,
        label_feature_name: Optional[str] = None,
    ):
        """
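With the `@deprecation_warning` decorator gone, `padding_value` is now an ordinary optional argument in all three SasRec dataset classes; per the removed warning text, padding is meant to be configured per column in the `TensorSchema` instead. A minimal sketch of the new call, assuming a `SequentialDataset` named `train_seq` built elsewhere and the usual export path of the class (both assumptions, not shown in this diff):

    from replay.models.nn.sequential.sasrec import SasRecTrainingDataset

    # train_seq is assumed to be a SequentialDataset whose TensorSchema columns
    # already carry their own padding_value, so the argument can simply be omitted.
    training_dataset = SasRecTrainingDataset(
        sequential=train_seq,
        max_sequence_length=50,
    )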
replay/models/nn/sequential/sasrec/lightning.py
CHANGED
@@ -1,5 +1,5 @@
 import math
-from typing import Any,
+from typing import Any, Literal, Optional, Union, cast
 
 import lightning
 import torch
@@ -341,7 +341,7 @@ class SasRec(lightning.LightningModule):
         positive_labels: torch.LongTensor,
         padding_mask: torch.BoolTensor,
         target_padding_mask: torch.BoolTensor,
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
         assert self._loss_sample_count is not None
         n_negative_samples = self._loss_sample_count
         positive_labels = cast(
@@ -428,7 +428,7 @@ class SasRec(lightning.LightningModule):
             msg = "Not supported loss_type"
             raise NotImplementedError(msg)
 
-    def get_all_embeddings(self) ->
+    def get_all_embeddings(self) -> dict[str, torch.nn.Embedding]:
         """
         :returns: copy of all embeddings as a dictionary.
         """
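`SasRec.get_all_embeddings` now declares its return type as `dict[str, torch.nn.Embedding]`. A hedged sketch of inspecting the result (only the return type comes from this diff; the dictionary keys are whatever the model defines):

    # model is assumed to be a trained SasRec lightning module.
    embeddings = model.get_all_embeddings()
    for name, layer in embeddings.items():
        # each value is a copied torch.nn.Embedding; .weight holds the lookup table
        print(name, tuple(layer.weight.shape))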
replay/models/nn/sequential/sasrec/model.py
CHANGED
@@ -1,6 +1,6 @@
 import abc
 import contextlib
-from typing import Any,
+from typing import Any, Optional, Union, cast
 
 import torch
 
@@ -212,7 +212,7 @@ class SasRecMasks:
         self,
         feature_tensor: TensorMap,
         padding_mask: torch.BoolTensor,
-    ) ->
+    ) -> tuple[torch.BoolTensor, torch.BoolTensor, TensorMap]:
         """
         :param feature_tensor: Batch of features.
         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
@@ -260,7 +260,7 @@ class BaseSasRecEmbeddings(abc.ABC):
     """
 
     @abc.abstractmethod
-    def get_all_embeddings(self) ->
+    def get_all_embeddings(self) -> dict[str, torch.Tensor]:
         """
         :returns: copy of all embeddings presented in a layer as a dict.
         """
@@ -366,7 +366,7 @@ class SasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
         # Last one is reserved for padding, so we remove it
         return self.item_emb.weight[:-1, :]
 
-    def get_all_embeddings(self) ->
+    def get_all_embeddings(self) -> dict[str, torch.Tensor]:
         """
         :returns: copy of all embeddings presented in this layer as a dict.
         """
@@ -579,7 +579,7 @@ class TiSasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
         self,
         feature_tensor: TensorMap,
         padding_mask: torch.BoolTensor,
-    ) ->
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         :param feature_tensor: Batch of features.
         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
@@ -628,7 +628,7 @@ class TiSasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
         # Last one is reserved for padding, so we remove it
         return self.item_emb.weight[:-1, :]
 
-    def get_all_embeddings(self) ->
+    def get_all_embeddings(self) -> dict[str, torch.Tensor]:
         """
         :returns: copy of all embeddings presented in this layer as a dict.
         """
@@ -674,7 +674,7 @@ class TiSasRecLayers(torch.nn.Module):
         seqs: torch.Tensor,
         attention_mask: torch.BoolTensor,
         padding_mask: torch.BoolTensor,
-        ti_embeddings:
+        ti_embeddings: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         """
@@ -734,7 +734,7 @@ class TiSasRecAttention(torch.nn.Module):
         keys: torch.LongTensor,
         time_mask: torch.LongTensor,
         attn_mask: torch.LongTensor,
-        ti_embeddings:
+        ti_embeddings: tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor],
         device: torch.device,
     ) -> torch.Tensor:
         """
replay/models/slim.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional
 
 import numpy as np
 import pandas as pd
@@ -21,7 +21,7 @@ class SLIM(NeighbourRec):
     """`SLIM: Sparse Linear Methods for Top-N Recommender Systems
     <http://glaros.dtc.umn.edu/gkhome/fetch/papers/SLIM2011icdm.pdf>`_"""
 
-    def _get_ann_infer_params(self) ->
+    def _get_ann_infer_params(self) -> dict[str, Any]:
         return {
             "features_col": None,
         }
replay/models/ucb.py
CHANGED
@@ -1,5 +1,5 @@
 import math
-from typing import Any,
+from typing import Any, Optional
 
 from replay.data.dataset import Dataset
 from replay.metrics import NDCG, Metric
@@ -103,7 +103,7 @@ class UCB(NonPersonalizedRecommender):
         self,
         train_dataset: Dataset,  # noqa: ARG002
         test_dataset: Dataset,  # noqa: ARG002
-        param_borders: Optional[
+        param_borders: Optional[dict[str, list[Any]]] = None,  # noqa: ARG002
         criterion: Metric = NDCG,  # noqa: ARG002
         k: int = 10,  # noqa: ARG002
         budget: int = 10,  # noqa: ARG002
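The signature above now spells out `param_borders` as `Optional[dict[str, list[Any]]]`: a mapping from a hyperparameter name to its search borders (note every argument carries `# noqa: ARG002`, i.e. UCB itself ignores them). For recommenders that do run the search, the dictionary would look roughly like the sketch below; the key is illustrative, not taken from the diff:

    # Hypothetical param_borders in the dict[str, list[Any]] shape shown above:
    # hyperparameter name -> [lower, upper] search border.
    param_borders = {"coef": [0.1, 2.0]}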
replay/models/word2vec.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional
 
 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -24,7 +24,7 @@ class Word2VecRec(ANNMixin, Recommender, ItemVectorModel):
     Trains word2vec model where items are treated as words and queries as sentences.
     """
 
-    def _get_ann_infer_params(self) ->
+    def _get_ann_infer_params(self) -> dict[str, Any]:
         self.index_builder.index_params.dim = self.rank
         return {
             "features_col": "query_vector",
@@ -36,7 +36,7 @@ class Word2VecRec(ANNMixin, Recommender, ItemVectorModel):
         query_vectors = query_vectors.select(self.query_column, vector_to_array("query_vector").alias("query_vector"))
         return query_vectors
 
-    def _configure_index_builder(self, interactions: SparkDataFrame) ->
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> dict[str, Any]:
         item_vectors = self._get_item_vectors()
         item_vectors = item_vectors.select(self.item_column, vector_to_array("item_vector").alias("item_vector"))
 
replay/preprocessing/discretizer.py
CHANGED
@@ -2,8 +2,9 @@ import abc
 import json
 import os
 import warnings
+from collections.abc import Sequence
 from pathlib import Path
-from typing import
+from typing import Literal
 
 import numpy as np
 import polars as pl
@@ -114,7 +115,7 @@ class GreedyDiscretizingRule(BaseDiscretizingRule):
         max_bin: int,
         total_cnt: int,
         min_data_in_bin: int,
-    ) ->
+    ) -> list[float]:
         """
         Computes bound for bins.
 
@@ -363,7 +364,7 @@ Set 'keep' or 'skip' for processing NaN."
     @classmethod
     def load(cls, path: str) -> "GreedyDiscretizingRule":
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json"
+        with open(base_path / "init_args.json") as file:
            discretizer_rule_dict = json.loads(file.read())
 
        discretizer_rule = cls(**discretizer_rule_dict["init_args"])
@@ -590,7 +591,7 @@ Set 'keep' or 'skip' for processing NaN."
     @classmethod
     def load(cls, path: str) -> "QuantileDiscretizingRule":
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json"
+        with open(base_path / "init_args.json") as file:
            discretizer_rule_dict = json.loads(file.read())
 
        discretizer_rule = cls(**discretizer_rule_dict["init_args"])
@@ -655,7 +656,7 @@ class Discretizer:
         """
         return self.fit(df).transform(df)
 
-    def set_handle_invalid(self, handle_invalid_rules:
+    def set_handle_invalid(self, handle_invalid_rules: dict[str, HandleInvalidStrategies]) -> None:
         """
         Modify handle_invalid strategy on already fitted Discretizer.
 
@@ -704,13 +705,13 @@ class Discretizer:
     @classmethod
     def load(cls, path: str) -> "Discretizer":
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json"
+        with open(base_path / "init_args.json") as file:
            discretizer_dict = json.loads(file.read())
         rules = []
         for root, dirs, files in os.walk(str(base_path) + "/rules/"):
             for d in dirs:
                 if d.split(".")[0] in discretizer_dict["rule_names"]:
-                    with open(root + d + "/init_args.json"
+                    with open(root + d + "/init_args.json") as file:
                        discretizer_rule_dict = json.loads(file.read())
                    rules.append(globals()[discretizer_rule_dict["_class_name"]].load(root + d))
 
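`Discretizer.set_handle_invalid` now takes `dict[str, HandleInvalidStrategies]`, a per-column choice of how to treat NaN values on an already fitted discretizer; the surrounding docstrings mention 'keep' and 'skip' as strategies. A hedged usage sketch (the column name is illustrative):

    # discretizer is assumed to be an already fitted Discretizer instance.
    discretizer.set_handle_invalid({"rating": "keep"})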
replay/preprocessing/filters.py
CHANGED
@@ -4,7 +4,7 @@ Select or remove data by some criteria
 
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
-from typing import Callable, Literal, Optional,
+from typing import Callable, Literal, Optional, Union
 from uuid import uuid4
 
 import numpy as np
@@ -182,7 +182,7 @@ class InteractionEntriesFilter(_BaseFilter):
         non_agg_column: str,
         min_inter: Optional[int] = None,
         max_inter: Optional[int] = None,
-    ) ->
+    ) -> tuple[PandasDataFrame, int, int]:
         filtered_interactions = interactions.copy(deep=True)
 
         filtered_interactions["count"] = filtered_interactions.groupby(agg_column, sort=False)[
@@ -207,7 +207,7 @@ class InteractionEntriesFilter(_BaseFilter):
         non_agg_column: str,
         min_inter: Optional[int] = None,
         max_inter: Optional[int] = None,
-    ) ->
+    ) -> tuple[SparkDataFrame, int, int]:
         filtered_interactions = interactions.withColumn(
             "count", sf.count(non_agg_column).over(Window.partitionBy(agg_column))
         )
@@ -233,7 +233,7 @@ class InteractionEntriesFilter(_BaseFilter):
         non_agg_column: str,
         min_inter: Optional[int] = None,
         max_inter: Optional[int] = None,
-    ) ->
+    ) -> tuple[PolarsDataFrame, int, int]:
         filtered_interactions = interactions.with_columns(
             pl.col(non_agg_column).count().over(pl.col(agg_column)).alias("count")
         )
replay/preprocessing/history_based_fp.py
CHANGED
@@ -9,7 +9,7 @@ Contains classes for users' and items' features generation based on interactions
 """
 
 from datetime import datetime
-from typing import
+from typing import Optional
 
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
@@ -64,7 +64,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
     user_log_features: Optional[SparkDataFrame] = None
     item_log_features: Optional[SparkDataFrame] = None
 
-    def _create_log_aggregates(self, agg_col: str = "user_idx") ->
+    def _create_log_aggregates(self, agg_col: str = "user_idx") -> list:
         """
         Create features based on relevance type
         (binary or not) and whether timestamp is present.
@@ -289,12 +289,12 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):
     If user features are provided, item features will be generated and vice versa.
     """
 
-    conditional_pop_dict: Optional[
+    conditional_pop_dict: Optional[dict[str, SparkDataFrame]]
     entity_name: str
 
     def __init__(
         self,
-        cat_features_list:
+        cat_features_list: list,
     ):
         """
         :param cat_features_list: List of columns with categorical features to use
@@ -397,8 +397,8 @@ class HistoryBasedFeaturesProcessor:
         self,
         use_log_features: bool = True,
         use_conditional_popularity: bool = True,
-        user_cat_features_list: Optional[
-        item_cat_features_list: Optional[
+        user_cat_features_list: Optional[list] = None,
+        item_cat_features_list: Optional[list] = None,
     ):
         """
         :param use_log_features: if add statistical log-based features
replay/preprocessing/label_encoder.py
CHANGED
@@ -10,8 +10,9 @@ import abc
 import json
 import os
 import warnings
+from collections.abc import Mapping, Sequence
 from pathlib import Path
-from typing import
+from typing import Literal, Optional, Union
 
 import polars as pl
 
@@ -162,7 +163,7 @@ class LabelEncodingRule(BaseLabelEncodingRule):
     def _make_inverse_mapping(self) -> Mapping:
         return {val: key for key, val in self.get_mapping().items()}
 
-    def _make_inverse_mapping_list(self) ->
+    def _make_inverse_mapping_list(self) -> list:
         inverse_mapping_list = [0 for _ in range(len(self.get_mapping()))]
         for k, value in self.get_mapping().items():
             inverse_mapping_list[value] = k
@@ -543,7 +544,7 @@ Convert type to string, integer, or float."
     @classmethod
     def load(cls, path: str) -> "LabelEncodingRule":
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json"
+        with open(base_path / "init_args.json") as file:
            encoder_rule_dict = json.loads(file.read())
 
        string_column_type = encoder_rule_dict["fitted_args"]["column_type"]
@@ -901,7 +902,7 @@ class LabelEncoder:
         """
         return self.fit(df).transform(df)
 
-    def set_handle_unknowns(self, handle_unknown_rules:
+    def set_handle_unknowns(self, handle_unknown_rules: dict[str, HandleUnknownStrategies]) -> None:
         """
         Modify handle unknown strategy on already fitted encoder.
 
@@ -923,7 +924,7 @@ class LabelEncoder:
         rule = list(filter(lambda x: x.column == column, self.rules))
         rule[0].set_handle_unknown(handle_unknown)
 
-    def set_default_values(self, default_value_rules:
+    def set_default_values(self, default_value_rules: dict[str, Optional[Union[int, str]]]) -> None:
         """
         Modify handle unknown strategy on already fitted encoder.
         Default value that will fill the unknown labels
@@ -974,13 +975,13 @@ class LabelEncoder:
     @classmethod
     def load(cls, path: str) -> "LabelEncoder":
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json"
+        with open(base_path / "init_args.json") as file:
            encoder_dict = json.loads(file.read())
         rules = []
         for root, dirs, files in os.walk(str(base_path) + "/rules/"):
             for d in dirs:
                 if d.split(".")[0] in encoder_dict["rule_names"]:
-                    with open(root + d + "/init_args.json"
+                    with open(root + d + "/init_args.json") as file:
                        encoder_rule_dict = json.loads(file.read())
                    rules.append(globals()[encoder_rule_dict["_class_name"]].load(root + d))
 
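Both `LabelEncoder` setters are now fully typed: `set_handle_unknowns` takes a per-column `HandleUnknownStrategies` mapping and `set_default_values` a per-column default of `int`, `str`, or `None`. A hedged sketch of combining them on a fitted encoder (the column name and the 'use_default_value' strategy string are assumptions based on RePlay's documented behaviour, not on this diff):

    # encoder is assumed to be an already fitted LabelEncoder.
    encoder.set_handle_unknowns({"item_id": "use_default_value"})
    encoder.set_default_values({"item_id": None})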
replay/scenarios/fallback.py
CHANGED
@@ -1,4 +1,5 @@
-from
+from collections.abc import Iterable
+from typing import Any, Optional, Union
 
 from replay.data import Dataset
 from replay.metrics import NDCG, Metric
@@ -125,12 +126,12 @@ class Fallback(BaseRecommender):
         self,
         train_dataset: Dataset,
         test_dataset: Dataset,
-        param_borders: Optional[
+        param_borders: Optional[dict[str, dict[str, list[Any]]]] = None,
         criterion: Metric = NDCG,
         k: int = 10,
         budget: int = 10,
         new_study: bool = True,
-    ) ->
+    ) -> tuple[dict[str, Any]]:
         """
         Searches best parameters with optuna.
 
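In `Fallback`, `param_borders` is now typed as a dict of dicts: the outer key selects a model inside the scenario, the inner dict holds that model's own hyperparameter borders for the optuna search. A hedged example of the shape only (the outer keys and the inner parameter name are illustrative, not taken from the diff):

    # Hypothetical borders for the two models wrapped by a Fallback scenario.
    param_borders = {
        "main": {"rank": [8, 128]},
        "fallback": {},
    }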
replay/splitters/base_splitter.py
CHANGED
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional
+from typing import Optional
 
 import polars as pl
 
@@ -20,7 +20,7 @@ if PYSPARK_AVAILABLE:
     )
 
 
-SplitterReturnType =
+SplitterReturnType = tuple[DataFrameLike, DataFrameLike]
 
 
 class Splitter(ABC):
@@ -90,7 +90,7 @@ class Splitter(ABC):
         Method for loading splitter from `.replay` directory.
         """
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json"
+        with open(base_path / "init_args.json") as file:
            splitter_dict = json.loads(file.read())
         splitter = cls(**splitter_dict["init_args"])
 
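`SplitterReturnType` is now an explicit `tuple[DataFrameLike, DataFrameLike]` alias, so a splitter's result can be unpacked directly. A minimal sketch, assuming a configured splitter and an interactions frame prepared elsewhere:

    # splitter is any configured Splitter subclass; interactions is a supported DataFrameLike.
    train, test = splitter.split(interactions)  # SplitterReturnType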
replay/splitters/cold_user_random_splitter.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional
 
 import polars as pl
 
@@ -62,7 +62,7 @@ class ColdUserRandomSplitter(Splitter):
 
     def _core_split_pandas(
         self, interactions: PandasDataFrame, threshold: float
-    ) ->
+    ) -> tuple[PandasDataFrame, PandasDataFrame]:
         users = PandasDataFrame(interactions[self.query_column].unique(), columns=[self.query_column])
         train_users = users.sample(frac=(1 - threshold), random_state=self.seed)
         train_users["is_test"] = False
@@ -78,7 +78,7 @@ class ColdUserRandomSplitter(Splitter):
 
     def _core_split_spark(
         self, interactions: SparkDataFrame, threshold: float
-    ) ->
+    ) -> tuple[SparkDataFrame, SparkDataFrame]:
         users = interactions.select(self.query_column).distinct()
         train_users, _ = users.randomSplit(
             [1 - threshold, threshold],
@@ -97,7 +97,7 @@ class ColdUserRandomSplitter(Splitter):
 
     def _core_split_polars(
         self, interactions: PolarsDataFrame, threshold: float
-    ) ->
+    ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
         train_users = (
             interactions.select(self.query_column)
             .unique(maintain_order=True)
replay/splitters/k_folds.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Literal, Optional
+from typing import Literal, Optional
 
 import polars as pl
 
@@ -83,7 +83,7 @@ class KFolds(Splitter):
         """
         return self._core_split(interactions)
 
-    def _query_split_spark(self, interactions: SparkDataFrame) ->
+    def _query_split_spark(self, interactions: SparkDataFrame) -> tuple[SparkDataFrame, SparkDataFrame]:
         dataframe = interactions.withColumn("_rand", sf.rand(self.seed))
         dataframe = dataframe.withColumn(
             "fold",
@@ -100,7 +100,7 @@ class KFolds(Splitter):
             test = self._drop_cold_items_and_users(train, test)
             yield train, test
 
-    def _query_split_pandas(self, interactions: PandasDataFrame) ->
+    def _query_split_pandas(self, interactions: PandasDataFrame) -> tuple[PandasDataFrame, PandasDataFrame]:
         dataframe = interactions.sample(frac=1, random_state=self.seed).sort_values(self.query_column)
         dataframe["fold"] = (dataframe.groupby(self.query_column, sort=False).cumcount() + 1) % self.n_folds
         for i in range(self.n_folds):
@@ -115,7 +115,7 @@ class KFolds(Splitter):
             test = self._drop_cold_items_and_users(train, test)
             yield train, test
 
-    def _query_split_polars(self, interactions: PolarsDataFrame) ->
+    def _query_split_polars(self, interactions: PolarsDataFrame) -> tuple[PolarsDataFrame, PolarsDataFrame]:
         dataframe = interactions.sample(fraction=1, shuffle=True, seed=self.seed).sort(self.query_column)
         dataframe = dataframe.with_columns(
             (pl.cum_count(self.query_column).over(self.query_column) % self.n_folds).alias("fold")
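The three `_query_split_*` generators in `KFolds` now advertise typed (train, test) pairs per fold. A hedged sketch of consuming them (the import path and the `n_folds`/`seed` argument names are inferred from the `self.n_folds` / `self.seed` attributes above, so treat them as assumptions):

    from replay.splitters import KFolds

    # interactions is assumed to be a pandas, Polars, or Spark interactions frame.
    splitter = KFolds(n_folds=3, seed=42)
    for train, test in splitter.split(interactions):
        ...  # fit and evaluate on each fold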
replay/splitters/last_n_splitter.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Literal, Optional
 
 import numpy as np
 import pandas as pd
@@ -240,7 +240,7 @@ class LastNSplitter(Splitter):
 
         return interactions
 
-    def _partial_split_interactions(self, interactions: DataFrameLike, n: int) ->
+    def _partial_split_interactions(self, interactions: DataFrameLike, n: int) -> tuple[DataFrameLike, DataFrameLike]:
         res = self._add_time_partition(interactions)
         if isinstance(interactions, SparkDataFrame):
             return self._partial_split_interactions_spark(res, n)
@@ -250,7 +250,7 @@ class LastNSplitter(Splitter):
 
     def _partial_split_interactions_pandas(
         self, interactions: PandasDataFrame, n: int
-    ) ->
+    ) -> tuple[PandasDataFrame, PandasDataFrame]:
         interactions["count"] = interactions.groupby(self.divide_column, sort=False)[self.divide_column].transform(len)
         interactions["is_test"] = interactions["row_num"] > (interactions["count"] - float(n))
         if self.session_id_column:
@@ -263,7 +263,7 @@ class LastNSplitter(Splitter):
 
     def _partial_split_interactions_spark(
         self, interactions: SparkDataFrame, n: int
-    ) ->
+    ) -> tuple[SparkDataFrame, SparkDataFrame]:
         interactions = interactions.withColumn(
             "count",
             sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
@@ -281,7 +281,7 @@ class LastNSplitter(Splitter):
 
     def _partial_split_interactions_polars(
         self, interactions: PolarsDataFrame, n: int
-    ) ->
+    ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
         interactions = interactions.with_columns(
             pl.col(self.timestamp_column).count().over(self.divide_column).alias("count")
         )
@@ -296,7 +296,7 @@ class LastNSplitter(Splitter):
 
     def _partial_split_timedelta(
         self, interactions: DataFrameLike, timedelta: int
-    ) ->
+    ) -> tuple[DataFrameLike, DataFrameLike]:
         if isinstance(interactions, SparkDataFrame):
             return self._partial_split_timedelta_spark(interactions, timedelta)
         if isinstance(interactions, PandasDataFrame):
@@ -305,7 +305,7 @@ class LastNSplitter(Splitter):
 
     def _partial_split_timedelta_pandas(
         self, interactions: PandasDataFrame, timedelta: int
-    ) ->
+    ) -> tuple[PandasDataFrame, PandasDataFrame]:
         res = interactions.copy(deep=True)
         res["diff_timestamp"] = (
             res.groupby(self.divide_column)[self.timestamp_column].transform(max) - res[self.timestamp_column]
@@ -321,7 +321,7 @@ class LastNSplitter(Splitter):
 
     def _partial_split_timedelta_spark(
         self, interactions: SparkDataFrame, timedelta: int
-    ) ->
+    ) -> tuple[SparkDataFrame, SparkDataFrame]:
         inter_with_max_time = interactions.withColumn(
             "max_timestamp",
             sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
@@ -343,7 +343,7 @@ class LastNSplitter(Splitter):
 
     def _partial_split_timedelta_polars(
         self, interactions: PolarsDataFrame, timedelta: int
-    ) ->
+    ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
         res = interactions.with_columns(
             (pl.col(self.timestamp_column).max().over(self.divide_column) - pl.col(self.timestamp_column)).alias(
                 "diff_timestamp"
@@ -358,7 +358,7 @@ class LastNSplitter(Splitter):
 
         return train, test
 
-    def _core_split(self, interactions: DataFrameLike) ->
+    def _core_split(self, interactions: DataFrameLike) -> list[DataFrameLike]:
         if self.strategy == "timedelta":
             interactions = self._to_unix_timestamp(interactions)
         train, test = getattr(self, "_partial_split_" + self.strategy)(interactions, self.N)
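`LastNSplitter._core_split` dispatches on `self.strategy` via `getattr(self, "_partial_split_" + self.strategy)`, so the supported strategies match the helper names above: 'interactions' and 'timedelta'. A hedged usage sketch; the constructor argument names are inferred from the attributes used in the diff and should be treated as assumptions:

    from replay.splitters import LastNSplitter

    # Hold out each user's last two interactions; interactions is prepared elsewhere.
    splitter = LastNSplitter(N=2, divide_column="user_id", strategy="interactions")
    train, test = splitter.split(interactions)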
replay/splitters/new_users_splitter.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional
 
 import polars as pl
 
@@ -100,7 +100,7 @@ class NewUsersSplitter(Splitter):
 
     def _core_split_pandas(
         self, interactions: PandasDataFrame, threshold: float
-    ) ->
+    ) -> tuple[PandasDataFrame, PandasDataFrame]:
         start_date_by_user = (
             interactions.groupby(self.query_column).agg(_start_dt_by_user=(self.timestamp_column, "min")).reset_index()
         )
@@ -134,7 +134,7 @@ class NewUsersSplitter(Splitter):
 
     def _core_split_spark(
         self, interactions: SparkDataFrame, threshold: float
-    ) ->
+    ) -> tuple[SparkDataFrame, SparkDataFrame]:
         start_date_by_user = interactions.groupby(self.query_column).agg(
             sf.min(self.timestamp_column).alias("_start_dt_by_user")
         )
@@ -171,7 +171,7 @@ class NewUsersSplitter(Splitter):
 
     def _core_split_polars(
         self, interactions: PolarsDataFrame, threshold: float
-    ) ->
+    ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
         start_date_by_user = interactions.group_by(self.query_column).agg(
             pl.col(self.timestamp_column).min().alias("_start_dt_by_user")
         )