replay-rec 0.19.0rc0-py3-none-any.whl → 0.20.0rc0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +6 -2
- replay/data/dataset.py +19 -18
- replay/data/dataset_utils/dataset_label_encoder.py +5 -4
- replay/data/nn/__init__.py +6 -6
- replay/data/nn/schema.py +9 -18
- replay/data/nn/sequence_tokenizer.py +54 -47
- replay/data/nn/sequential_dataset.py +16 -11
- replay/data/nn/torch_sequential_dataset.py +18 -16
- replay/data/nn/utils.py +3 -2
- replay/data/schema.py +3 -12
- replay/experimental/metrics/base_metric.py +6 -5
- replay/experimental/metrics/coverage.py +5 -5
- replay/experimental/metrics/experiment.py +2 -2
- replay/experimental/models/__init__.py +38 -1
- replay/experimental/models/admm_slim.py +59 -7
- replay/experimental/models/base_neighbour_rec.py +6 -10
- replay/experimental/models/base_rec.py +58 -12
- replay/experimental/models/base_torch_rec.py +2 -2
- replay/experimental/models/cql.py +6 -6
- replay/experimental/models/ddpg.py +47 -38
- replay/experimental/models/dt4rec/dt4rec.py +3 -3
- replay/experimental/models/dt4rec/utils.py +4 -5
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +5 -5
- replay/experimental/models/lightfm_wrap.py +4 -3
- replay/experimental/models/mult_vae.py +4 -4
- replay/experimental/models/neural_ts.py +13 -13
- replay/experimental/models/neuromf.py +4 -4
- replay/experimental/models/scala_als.py +14 -17
- replay/experimental/nn/data/schema_builder.py +4 -4
- replay/experimental/preprocessing/data_preparator.py +13 -13
- replay/experimental/preprocessing/padder.py +7 -7
- replay/experimental/preprocessing/sequence_generator.py +7 -7
- replay/experimental/scenarios/obp_wrapper/__init__.py +4 -4
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +5 -5
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +4 -4
- replay/experimental/scenarios/obp_wrapper/utils.py +3 -5
- replay/experimental/scenarios/two_stages/reranker.py +4 -4
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +18 -18
- replay/experimental/utils/session_handler.py +2 -2
- replay/metrics/base_metric.py +12 -11
- replay/metrics/categorical_diversity.py +8 -8
- replay/metrics/coverage.py +11 -15
- replay/metrics/experiment.py +6 -6
- replay/metrics/hitrate.py +1 -3
- replay/metrics/map.py +1 -3
- replay/metrics/mrr.py +1 -3
- replay/metrics/ndcg.py +1 -2
- replay/metrics/novelty.py +3 -3
- replay/metrics/offline_metrics.py +18 -18
- replay/metrics/precision.py +1 -3
- replay/metrics/recall.py +1 -3
- replay/metrics/rocauc.py +1 -3
- replay/metrics/surprisal.py +4 -4
- replay/metrics/torch_metrics_builder.py +13 -12
- replay/metrics/unexpectedness.py +2 -2
- replay/models/__init__.py +19 -0
- replay/models/als.py +2 -2
- replay/models/association_rules.py +5 -7
- replay/models/base_neighbour_rec.py +8 -10
- replay/models/base_rec.py +54 -302
- replay/models/cat_pop_rec.py +4 -2
- replay/models/common.py +69 -0
- replay/models/extensions/ann/ann_mixin.py +31 -25
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
- replay/models/extensions/ann/utils.py +4 -3
- replay/models/knn.py +18 -17
- replay/models/lin_ucb.py +3 -3
- replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
- replay/models/nn/sequential/bert4rec/dataset.py +3 -3
- replay/models/nn/sequential/bert4rec/lightning.py +3 -3
- replay/models/nn/sequential/bert4rec/model.py +2 -2
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +14 -14
- replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
- replay/models/nn/sequential/compiled/__init__.py +10 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +8 -6
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
- replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
- replay/models/nn/sequential/postprocessors/_base.py +2 -3
- replay/models/nn/sequential/postprocessors/postprocessors.py +10 -10
- replay/models/nn/sequential/sasrec/dataset.py +1 -1
- replay/models/nn/sequential/sasrec/lightning.py +3 -3
- replay/models/nn/sequential/sasrec/model.py +9 -9
- replay/models/optimization/__init__.py +14 -0
- replay/models/optimization/optuna_mixin.py +279 -0
- replay/{optimization → models/optimization}/optuna_objective.py +13 -15
- replay/models/slim.py +4 -6
- replay/models/ucb.py +2 -2
- replay/models/word2vec.py +9 -14
- replay/preprocessing/discretizer.py +9 -9
- replay/preprocessing/filters.py +4 -4
- replay/preprocessing/history_based_fp.py +7 -7
- replay/preprocessing/label_encoder.py +9 -8
- replay/scenarios/fallback.py +4 -3
- replay/splitters/base_splitter.py +3 -3
- replay/splitters/cold_user_random_splitter.py +17 -11
- replay/splitters/k_folds.py +4 -4
- replay/splitters/last_n_splitter.py +27 -20
- replay/splitters/new_users_splitter.py +4 -4
- replay/splitters/random_splitter.py +4 -4
- replay/splitters/ratio_splitter.py +10 -10
- replay/splitters/time_splitter.py +6 -6
- replay/splitters/two_stage_splitter.py +4 -4
- replay/utils/__init__.py +7 -2
- replay/utils/common.py +5 -3
- replay/utils/model_handler.py +11 -31
- replay/utils/session_handler.py +4 -4
- replay/utils/spark_utils.py +8 -7
- replay/utils/types.py +31 -19
- replay/utils/warnings.py +26 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info}/METADATA +58 -42
- replay_rec-0.20.0rc0.dist-info/RECORD +194 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info}/WHEEL +1 -1
- replay/optimization/__init__.py +0 -5
- replay_rec-0.19.0rc0.dist-info/RECORD +0 -191
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info/licenses}/LICENSE +0 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info/licenses}/NOTICE +0 -0
replay/__init__.py
CHANGED
@@ -1,3 +1,7 @@
-"""
+"""RecSys library"""
 
-…
+# NOTE: This ensures distutils monkey-patching is performed before any
+# functionality removed in Python 3.12 is used in downstream packages (like lightfm)
+import setuptools as _
+
+__version__ = "0.20.0.preview"
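For context on the new import: Python 3.12 removed `distutils` from the standard library (PEP 632), and modern `setuptools` ships a vendored replacement that it installs under the `distutils` name as an import side effect. Importing `setuptools` at the top of `replay/__init__.py` therefore has to happen before any dependency (such as `lightfm`) tries `import distutils`. A minimal sketch of the mechanism, assuming setuptools >= 60 on Python 3.12 (not replay code):

    # Importing setuptools first activates its distutils shim.
    import setuptools as _  # noqa: F401  -- imported only for its side effect

    # After the import above, a legacy `import distutils` in a downstream
    # package resolves to setuptools' vendored copy instead of failing:
    import distutils

    print(distutils.__file__)  # points inside setuptools/_distutils on 3.12+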
replay/data/dataset.py
CHANGED
@@ -5,8 +5,9 @@
 from __future__ import annotations
 
 import json
+from collections.abc import Iterable, Sequence
 from pathlib import Path
-from typing import Callable, …
+from typing import Callable, Optional, Union
 
 import numpy as np
 from pandas import read_parquet as pd_read_parquet
@@ -315,7 +316,7 @@ class Dataset:
         :returns: Loaded Dataset.
         """
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json"…
+        with open(base_path / "init_args.json") as file:
             dataset_dict = json.loads(file.read())
 
         if dataframe_type not in ["pandas", "spark", "polars", None]:
@@ -436,14 +437,14 @@ class Dataset:
         )
 
     def _get_feature_source_map(self):
-        self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
+        self._feature_source_map: dict[FeatureSource, DataFrameLike] = {
             FeatureSource.INTERACTIONS: self.interactions,
             FeatureSource.QUERY_FEATURES: self.query_features,
             FeatureSource.ITEM_FEATURES: self.item_features,
         }
 
     def _get_ids_source_map(self):
-        self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
+        self._ids_feature_map: dict[FeatureHint, DataFrameLike] = {
             FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
             FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
         }
@@ -499,10 +500,10 @@ class Dataset:
         )
         return FeatureSchema(features_list=features_list + filled_features)
 
-    def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+    def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
         features_list = list(feature_schema.all_features)
 
-        source_mapping: Dict[str, FeatureSource] = {}
+        source_mapping: dict[str, FeatureSource] = {}
         for source in FeatureSource:
             dataframe = self._feature_source_map[source]
             if dataframe is not None:
@@ -524,7 +525,7 @@ class Dataset:
         self._set_cardinality(features_list=features_list)
         return features_list
 
-    def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+    def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
         set_source_dataframe_columns = set(self._feature_source_map[source].columns)
         set_labeled_dataframe_columns = set(feature_schema.columns)
         unlabeled_columns = set_source_dataframe_columns - set_labeled_dataframe_columns
@@ -534,13 +535,13 @@ class Dataset:
         ]
         return unlabeled_features_list
 
-    def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+    def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
         unlabeled_columns = self._get_unlabeled_columns(source=source, feature_schema=feature_schema)
         self._set_features_source(feature_list=unlabeled_columns, source=source)
         self._set_cardinality(features_list=unlabeled_columns)
         return unlabeled_columns
 
-    def _set_features_source(self, feature_list: List[FeatureInfo], source: FeatureSource) -> None:
+    def _set_features_source(self, feature_list: list[FeatureInfo], source: FeatureSource) -> None:
         for feature in feature_list:
             feature._set_feature_source(source)
 
@@ -610,9 +611,9 @@ class Dataset:
         if self.is_pandas:
             try:
                 data[column] = data[column].astype(int)
-            except Exception:
+            except Exception as exc:
                 msg = f"IDs in {source.name}.{column} are not encoded. They are not int."
-                raise ValueError(msg)
+                raise ValueError(msg) from exc
 
         if self.is_pandas:
             is_int = np.issubdtype(dict(data.dtypes)[column], int)
@@ -775,10 +776,10 @@ def check_dataframes_types_equal(dataframe: DataFrameLike, other: DataFrameLike)
 
     :returns: True if dataframes have same type.
     """
-… (seven removed lines; content not captured)
+    return any(
+        [
+            isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame),
+            isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame),
+            isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame),
+        ]
+    )
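Two idioms recur in the hunks above and throughout this release: annotations move from the deprecated `typing.Dict`/`typing.List` aliases to the builtin generics standardized by PEP 585 (available from Python 3.9), and re-raised errors now chain the original exception with `raise ... from exc`. A minimal sketch of both, using a hypothetical `parse_ids` helper rather than replay's actual code:

    # Minimal sketch (hypothetical helper, not replay code).
    def parse_ids(raw: list[str]) -> dict[str, int]:  # was List[str] / Dict[str, int]
        result: dict[str, int] = {}
        for item in raw:
            try:
                result[item] = int(item)
            except ValueError as exc:
                msg = f"ID {item!r} is not encoded. It is not int."
                # `from exc` records the original error as the explicit __cause__
                # ("The above exception was the direct cause ..."), rather than
                # leaving it as implicit context, and satisfies linters like B904.
                raise ValueError(msg) from exc
        return result

    print(parse_ids(["1", "2"]))  # {'1': 1, '2': 2}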
replay/data/dataset_utils/dataset_label_encoder.py
CHANGED
@@ -6,7 +6,8 @@ Contains classes for encoding categorical data
 """
 
 import warnings
-from typing import …
+from collections.abc import Iterable, Iterator, Sequence
+from typing import Optional, Union
 
 from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
 from replay.preprocessing import LabelEncoder, LabelEncodingRule, SequenceEncodingRule
@@ -45,9 +46,9 @@ class DatasetLabelEncoder:
         """
         self._handle_unknown_rule = handle_unknown_rule
         self._default_value_rule = default_value_rule
-        self._encoding_rules: Dict[str, LabelEncodingRule] = {}
+        self._encoding_rules: dict[str, LabelEncodingRule] = {}
 
-        self._features_columns: Dict[Union[FeatureHint, FeatureSource], Sequence[str]] = {}
+        self._features_columns: dict[Union[FeatureHint, FeatureSource], Sequence[str]] = {}
 
     def fit(self, dataset: Dataset) -> "DatasetLabelEncoder":
         """
@@ -161,7 +162,7 @@ class DatasetLabelEncoder:
         """
         self._check_if_initialized()
 
-        columns_set: Set[str]
+        columns_set: set[str]
         columns_set = {columns} if isinstance(columns, str) else {*columns}
 
         def get_encoding_rules() -> Iterator[LabelEncodingRule]:
replay/data/nn/__init__.py
CHANGED
@@ -14,17 +14,17 @@ if TORCH_AVAILABLE:
     )
 
     __all__ = [
+        "DEFAULT_GROUND_TRUTH_PADDING_VALUE",
+        "DEFAULT_TRAIN_PADDING_VALUE",
         "MutableTensorMap",
+        "PandasSequentialDataset",
+        "PolarsSequentialDataset",
+        "SequenceTokenizer",
+        "SequentialDataset",
         "TensorFeatureInfo",
         "TensorFeatureSource",
         "TensorMap",
         "TensorSchema",
-        "SequenceTokenizer",
-        "PandasSequentialDataset",
-        "PolarsSequentialDataset",
-        "SequentialDataset",
-        "DEFAULT_GROUND_TRUTH_PADDING_VALUE",
-        "DEFAULT_TRAIN_PADDING_VALUE",
         "TorchSequentialBatch",
         "TorchSequentialDataset",
         "TorchSequentialValidationBatch",
replay/data/nn/schema.py
CHANGED
@@ -1,17 +1,8 @@
+from collections import OrderedDict
+from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, Sequence, ValuesView
 from typing import (
-    Dict,
-    ItemsView,
-    Iterable,
-    Iterator,
-    KeysView,
-    List,
-    Mapping,
     Optional,
-    OrderedDict,
-    Sequence,
-    Set,
     Union,
-    ValuesView,
 )
 
 import torch
@@ -20,7 +11,7 @@ from replay.data import FeatureHint, FeatureSource, FeatureType
 
 # Alias
 TensorMap = Mapping[str, torch.Tensor]
-MutableTensorMap = Dict[str, torch.Tensor]
+MutableTensorMap = dict[str, torch.Tensor]
 
 
 class TensorFeatureSource:
@@ -79,7 +70,7 @@ class TensorFeatureInfo:
         feature_type: FeatureType,
         is_seq: bool = False,
         feature_hint: Optional[FeatureHint] = None,
-        feature_sources: Optional[List[TensorFeatureSource]] = None,
+        feature_sources: Optional[list[TensorFeatureSource]] = None,
         cardinality: Optional[int] = None,
         padding_value: int = 0,
         embedding_dim: Optional[int] = None,
@@ -154,13 +145,13 @@ class TensorFeatureInfo:
         self._feature_hint = hint
 
     @property
-    def feature_sources(self) -> Optional[List[TensorFeatureSource]]:
+    def feature_sources(self) -> Optional[list[TensorFeatureSource]]:
         """
         :returns: List of sources feature came from.
         """
         return self._feature_sources
 
-    def _set_feature_sources(self, sources: List[TensorFeatureSource]) -> None:
+    def _set_feature_sources(self, sources: list[TensorFeatureSource]) -> None:
         self._feature_sources = sources
 
     @property
@@ -276,7 +267,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
 
         :returns: New tensor schema of given features.
         """
-        features: Set[TensorFeatureInfo] = set()
+        features: set[TensorFeatureInfo] = set()
         for feature_name in features_to_keep:
             features.add(self._tensor_schema[feature_name])
         return TensorSchema(list(features))
@@ -432,7 +423,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
             return None
         return rating_features.item().name
 
-    def _get_object_args(self) -> Dict:
+    def _get_object_args(self) -> dict:
         """
         Returns list of features represented as dictionaries.
         """
@@ -456,7 +447,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
         return features
 
     @classmethod
-    def _create_object_by_args(cls, args: Dict) -> "TensorSchema":
+    def _create_object_by_args(cls, args: dict) -> "TensorSchema":
         features_list = []
         for feature_data in args:
             feature_data["feature_sources"] = (
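The import reshuffle above is the other half of the typing modernization: since Python 3.9 the ABCs in `collections.abc` (and `collections.OrderedDict`) are subscriptable directly, and their `typing` aliases are deprecated under PEP 585. A minimal sketch with stand-in names rather than replay's schema types:

    # Minimal sketch of the collections.abc migration (PEP 585, Python >= 3.9).
    from collections import OrderedDict
    from collections.abc import Mapping, Sequence

    # An alias in the style of replay's TensorMap, built on the ABC itself
    # instead of the deprecated typing.Mapping:
    StrToInt = Mapping[str, int]

    def first(items: Sequence[int]) -> int:
        return items[0]

    ordered: OrderedDict[str, int] = OrderedDict(a=1, b=2)
    table: StrToInt = {"x": 1}
    print(first([10, 20]), table["x"], list(ordered))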
replay/data/nn/sequence_tokenizer.py
CHANGED
@@ -2,8 +2,9 @@ import abc
 import json
 import pickle
 import warnings
+from collections.abc import Sequence
 from pathlib import Path
-from typing import …
+from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
 
 import numpy as np
 import polars as pl
@@ -14,11 +15,11 @@ from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
 from replay.data.dataset_utils import DatasetLabelEncoder
 from replay.preprocessing import LabelEncoder, LabelEncodingRule
 from replay.preprocessing.label_encoder import HandleUnknownStrategies
-from replay.utils import …
+from replay.utils import deprecation_warning
 
-…
-from .…
-from .…
+if TYPE_CHECKING:
+    from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
+    from .sequential_dataset import SequentialDataset
 
 SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
 _T = TypeVar("_T")
@@ -34,7 +35,7 @@ class SequenceTokenizer:
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
        handle_unknown_rule: HandleUnknownStrategies = "error",
        default_value_rule: Optional[Union[int, str]] = None,
        allow_collect_to_master: bool = False,
@@ -77,7 +78,7 @@ class SequenceTokenizer:
         self,
         dataset: Dataset,
         tensor_features_to_keep: Optional[Sequence[str]] = None,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
         """
         :param dataset: input dataset to transform
         :param tensor_features_to_keep: specified feature names to transform
@@ -89,7 +90,7 @@ class SequenceTokenizer:
     def fit_transform(
         self,
         dataset: Dataset,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
         """
         :param dataset: input dataset to transform
         :returns: SequentialDataset
@@ -97,7 +98,7 @@ class SequenceTokenizer:
         return self.fit(dataset)._transform_unchecked(dataset)
 
     @property
-    def tensor_schema(self) -> TensorSchema:
+    def tensor_schema(self) -> "TensorSchema":
         """
         :returns: tensor schema
         """
@@ -149,7 +150,9 @@ class SequenceTokenizer:
         self,
         dataset: Dataset,
         tensor_features_to_keep: Optional[Sequence[str]] = None,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
+        from replay.data.nn.sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset
+
         schema = self._tensor_schema
         if tensor_features_to_keep is not None:
             schema = schema.subset(tensor_features_to_keep)
@@ -185,7 +188,9 @@ class SequenceTokenizer:
     def _group_dataset(
         self,
         dataset: Dataset,
-    ) -> Tuple[SequenceDataFrameLike, Optional[SequenceDataFrameLike], Optional[SequenceDataFrameLike]]:
+    ) -> tuple[SequenceDataFrameLike, Optional[SequenceDataFrameLike], Optional[SequenceDataFrameLike]]:
+        from replay.data.nn.utils import ensure_pandas, groupby_sequences
+
         grouped_interactions = groupby_sequences(
             events=dataset.interactions,
             groupby_col=dataset.feature_schema.query_id_column,
@@ -218,7 +223,7 @@ class SequenceTokenizer:
 
     def _make_sequence_features(
         self,
-        schema: TensorSchema,
+        schema: "TensorSchema",
         feature_schema: FeatureSchema,
         grouped_interactions: SequenceDataFrameLike,
         query_features: Optional[SequenceDataFrameLike],
@@ -242,7 +247,7 @@ class SequenceTokenizer:
     def _match_features_with_tensor_schema(
         cls,
         dataset: Dataset,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
     ) -> Dataset:
         feature_subset_filter = cls._get_features_filter_from_schema(
             tensor_schema,
@@ -261,16 +266,16 @@ class SequenceTokenizer:
     @classmethod
     def _get_features_filter_from_schema(
         cls,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
-    ) -> Set[str]:
+    ) -> set[str]:
         # We need only features, which related to tensor schema, otherwise feature should
         # be ignored for efficiency reasons. The code below does feature filtering, and
         # keeps features used as a source in tensor schema.
 
         # Query and item IDs are always needed
-        features_subset: List[str] = [
+        features_subset: list[str] = [
             query_id_column,
             item_id_column,
         ]
@@ -291,7 +296,7 @@ class SequenceTokenizer:
         return set(features_subset)
 
     @classmethod
-    def _check_tensor_schema(cls, tensor_schema: TensorSchema) -> None:
+    def _check_tensor_schema(cls, tensor_schema: "TensorSchema") -> None:
         # Check consistency of sequential features
         for tensor_feature in tensor_schema.all_features:
             feature_sources = tensor_feature.feature_sources
@@ -299,7 +304,7 @@ class SequenceTokenizer:
                 msg = "All tensor features must have sources defined"
                 raise ValueError(msg)
 
-            source_tables: List[FeatureSource] = [s.source for s in feature_sources]
+            source_tables: list[FeatureSource] = [s.source for s in feature_sources]
 
             unexpected_tables = list(filter(lambda x: not isinstance(x, FeatureSource), source_tables))
             if len(unexpected_tables) > 0:
@@ -319,11 +324,11 @@ class SequenceTokenizer:
     def _check_if_tensor_schema_matches_data(  # noqa: C901
         cls,
         dataset: Dataset,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         tensor_features_to_keep: Optional[Sequence[str]] = None,
     ) -> None:
         # Check if all source columns specified in tensor schema exist in provided data frames
-        sources_for_tensors: List[TensorFeatureSource] = []
+        sources_for_tensors: list["TensorFeatureSource"] = []
         for tensor_feature_name, tensor_feature in tensor_schema.items():
             if tensor_features_to_keep is not None and tensor_feature_name not in tensor_features_to_keep:
                 continue
@@ -413,9 +418,11 @@ class SequenceTokenizer:
 
         :returns: Loaded tokenizer object.
         """
+        from replay.data.nn import TensorSchema
+
         if not use_pickle:
             base_path = Path(path).with_suffix(".replay").resolve()
-            with open(base_path / "init_args.json"…
+            with open(base_path / "init_args.json") as file:
                 tokenizer_dict = json.loads(file.read())
 
             # load tensor_schema, tensor_features
@@ -500,7 +507,7 @@ class _BaseSequenceProcessor(Generic[_T]):
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         grouped_interactions: _T,
@@ -535,7 +542,7 @@ class _BaseSequenceProcessor(Generic[_T]):
             return self._process_num_feature(tensor_feature)
         assert False, "Unknown tensor feature type"
 
-    def _process_num_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
+    def _process_num_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:
         """
         Process numerical tensor feature depends on it source.
         """
@@ -548,7 +555,7 @@ class _BaseSequenceProcessor(Generic[_T]):
             return self._process_num_item_feature(tensor_feature)
         assert False, "Unknown tensor feature source table"
 
-    def _process_cat_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
+    def _process_cat_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:
         """
         Process categorical tensor feature depends on it source.
         """
@@ -562,27 +569,27 @@ class _BaseSequenceProcessor(Generic[_T]):
         assert False, "Unknown tensor feature source table"
 
     @abc.abstractmethod
-    def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
 
@@ -597,7 +604,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         grouped_interactions: PandasDataFrame,
@@ -619,7 +626,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         """
         :returns: processed Pandas DataFrame with all features from tensor schema.
         """
-        all_features: Dict[str, Union[np.ndarray, List[np.ndarray]]] = {}
+        all_features: dict[str, Union[np.ndarray, list[np.ndarray]]] = {}
         all_features[self._query_id_column] = self._grouped_interactions[self._query_id_column].values
 
         for tensor_feature_name in self._tensor_schema:
@@ -628,8 +635,8 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         return PandasDataFrame(all_features)
 
     def _process_num_interaction_feature(
-        self, tensor_feature: TensorFeatureInfo
-    ) -> Union[…
+        self, tensor_feature: "TensorFeatureInfo"
+    ) -> Union[list[np.ndarray], list[list]]:
         """
         Process numerical interaction feature.
 
@@ -650,7 +657,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
             values.append(np.array(sequence))
         return values
 
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[…
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> Union[list[np.ndarray], list[list]]:
         """
         Process numerical feature from item features dataset.
 
@@ -676,7 +683,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
 
         return values
 
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> list[np.ndarray]:
         """
         Process numerical feature from query features dataset.
 
@@ -687,8 +694,8 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         return self._process_cat_query_feature(tensor_feature)
 
     def _process_cat_interaction_feature(
-        self, tensor_feature: TensorFeatureInfo
-    ) -> Union[…
+        self, tensor_feature: "TensorFeatureInfo"
+    ) -> Union[list[np.ndarray], list[list]]:
         """
         Process categorical interaction feature.
 
@@ -709,7 +716,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
             values.append(np.array(sequence))
         return values
 
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> list[np.ndarray]:
         """
         Process categorical feature from query features dataset.
 
@@ -738,7 +745,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         ]
         return [np.array([query_feature[i]]).reshape(-1) for i in range(len(self._grouped_interactions))]
 
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[…
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> Union[list[np.ndarray], list[list]]:
         """
         Process categorical feature from item features dataset.
 
@@ -754,7 +761,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         assert source is not None
 
         item_feature = self._item_features[source.column]
-        values: List[np.ndarray] = []
+        values: list[np.ndarray] = []
 
         for item_id_sequence in self._grouped_interactions[self._item_id_column]:
             feature_sequence = item_feature.loc[item_id_sequence].values
@@ -784,7 +791,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
             data = data.join(self._process_feature(tensor_feature_name), on=self._query_id_column, how="left")
         return data
 
-    def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical interaction feature.
 
@@ -794,7 +801,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
         """
         return self._process_cat_interaction_feature(tensor_feature)
 
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical feature from query features dataset.
 
@@ -805,7 +812,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
         """
         return self._process_cat_query_feature(tensor_feature)
 
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical feature from item features dataset.
 
@@ -816,7 +823,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
         """
         return self._process_cat_item_feature(tensor_feature)
 
-    def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical interaction feature.
 
@@ -833,7 +840,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
             {source.column: tensor_feature.name}
         )
 
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical feature from query features dataset.
 
@@ -877,7 +884,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
             {source.column: tensor_feature.name}
         )
 
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical feature from item features dataset.
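A pattern worth noting in this file: the `TensorSchema`/`SequentialDataset` imports move under `if TYPE_CHECKING:`, the corresponding annotations become quoted strings, and the concrete classes are imported lazily inside the methods that need them. This is the standard way to keep static type checkers working while breaking an import cycle between modules. A minimal two-module sketch with illustrative names, not replay's actual modules:

    # b.py (hypothetical companion module; imports "a" at load time, creating a cycle):
    #     from a import Tokenizer
    #     class Result: ...
    #
    # a.py -- the TYPE_CHECKING pattern that breaks the cycle:
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:  # True only under mypy/pyright; never executed at runtime
        from b import Result


    class Tokenizer:
        def transform(self) -> "Result":  # quoted annotation: no runtime lookup
            from b import Result  # deferred import, runs once both modules exist

            return Result()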