replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -61
- replay/experimental/metrics/base_metric.py +0 -661
- replay/experimental/metrics/coverage.py +0 -117
- replay/experimental/metrics/experiment.py +0 -200
- replay/experimental/metrics/hitrate.py +0 -27
- replay/experimental/metrics/map.py +0 -31
- replay/experimental/metrics/mrr.py +0 -19
- replay/experimental/metrics/ncis_precision.py +0 -32
- replay/experimental/metrics/ndcg.py +0 -50
- replay/experimental/metrics/precision.py +0 -23
- replay/experimental/metrics/recall.py +0 -26
- replay/experimental/metrics/rocauc.py +0 -50
- replay/experimental/metrics/surprisal.py +0 -102
- replay/experimental/metrics/unexpectedness.py +0 -74
- replay/experimental/models/__init__.py +0 -10
- replay/experimental/models/admm_slim.py +0 -216
- replay/experimental/models/base_neighbour_rec.py +0 -222
- replay/experimental/models/base_rec.py +0 -1361
- replay/experimental/models/base_torch_rec.py +0 -247
- replay/experimental/models/cql.py +0 -468
- replay/experimental/models/ddpg.py +0 -1007
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -193
- replay/experimental/models/dt4rec/gpt1.py +0 -411
- replay/experimental/models/dt4rec/trainer.py +0 -128
- replay/experimental/models/dt4rec/utils.py +0 -274
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
- replay/experimental/models/implicit_wrap.py +0 -138
- replay/experimental/models/lightfm_wrap.py +0 -327
- replay/experimental/models/mult_vae.py +0 -374
- replay/experimental/models/neuromf.py +0 -462
- replay/experimental/models/scala_als.py +0 -311
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -58
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -929
- replay/experimental/preprocessing/padder.py +0 -231
- replay/experimental/preprocessing/sequence_generator.py +0 -218
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
- replay/experimental/scenarios/two_stages/reranker.py +0 -116
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -213
- replay/experimental/utils/session_handler.py +0 -47
- replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
- replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
|
@@ -1,4 +1,7 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import pickle
|
|
3
|
+
import warnings
|
|
4
|
+
from pathlib import Path
|
|
2
5
|
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
|
3
6
|
|
|
4
7
|
import numpy as np
|
|
@@ -6,14 +9,15 @@ import polars as pl
|
|
|
6
9
|
from pandas import DataFrame as PandasDataFrame
|
|
7
10
|
from polars import DataFrame as PolarsDataFrame
|
|
8
11
|
|
|
9
|
-
from replay.data import Dataset, FeatureSchema, FeatureSource
|
|
12
|
+
from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
|
|
10
13
|
from replay.data.dataset_utils import DatasetLabelEncoder
|
|
11
|
-
from .
|
|
12
|
-
from .sequential_dataset import PandasSequentialDataset, SequentialDataset, PolarsSequentialDataset
|
|
13
|
-
from .utils import ensure_pandas, groupby_sequences
|
|
14
|
-
from replay.preprocessing import LabelEncoder
|
|
14
|
+
from replay.preprocessing import LabelEncoder, LabelEncodingRule
|
|
15
15
|
from replay.preprocessing.label_encoder import HandleUnknownStrategies
|
|
16
|
+
from replay.utils.model_handler import deprecation_warning
|
|
16
17
|
|
|
18
|
+
from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
|
|
19
|
+
from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
|
|
20
|
+
from .utils import ensure_pandas, groupby_sequences
|
|
17
21
|
|
|
18
22
|
SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
|
|
19
23
|
|
|
@@ -33,7 +37,7 @@ class SequenceTokenizer:
|
|
|
33
37
|
"""
|
|
34
38
|
:param tensor_schema: tensor schema of tensor features
|
|
35
39
|
:param handle_unknown_rule: handle unknown labels rule for LabelEncoder,
|
|
36
|
-
values are in ('error', 'use_default_value').
|
|
40
|
+
values are in ('error', 'use_default_value', 'drop').
|
|
37
41
|
Default: `error`
|
|
38
42
|
:param default_value: Default value that will fill the unknown labels after transform.
|
|
39
43
|
When the parameter handle_unknown is set to ``use_default_value``,
|
|
@@ -60,6 +64,7 @@ class SequenceTokenizer:
|
|
|
60
64
|
:returns: fitted SequenceTokenizer
|
|
61
65
|
"""
|
|
62
66
|
self._check_if_tensor_schema_matches_data(dataset, self._tensor_schema)
|
|
67
|
+
self._assign_tensor_features_cardinality(dataset)
|
|
63
68
|
self._encoder.fit(dataset)
|
|
64
69
|
return self
|
|
65
70
|
|
|
@@ -84,7 +89,6 @@ class SequenceTokenizer:
|
|
|
84
89
|
:param dataset: input dataset to transform
|
|
85
90
|
:returns: SequentialDataset
|
|
86
91
|
"""
|
|
87
|
-
# pylint: disable=protected-access
|
|
88
92
|
return self.fit(dataset)._transform_unchecked(dataset)
|
|
89
93
|
|
|
90
94
|
@property
|
|
@@ -161,10 +165,7 @@ class SequenceTokenizer:
|
|
|
161
165
|
|
|
162
166
|
assert self._tensor_schema.item_id_feature_name
|
|
163
167
|
|
|
164
|
-
if is_polars
|
|
165
|
-
dataset_type = PolarsSequentialDataset
|
|
166
|
-
else:
|
|
167
|
-
dataset_type = PandasSequentialDataset
|
|
168
|
+
dataset_type = PolarsSequentialDataset if is_polars else PandasSequentialDataset
|
|
168
169
|
|
|
169
170
|
return dataset_type(
|
|
170
171
|
tensor_schema=schema,
|
|
@@ -191,7 +192,7 @@ class SequenceTokenizer:
|
|
|
191
192
|
return (
|
|
192
193
|
grouped_interactions.sort(dataset.feature_schema.query_id_column),
|
|
193
194
|
dataset.query_features,
|
|
194
|
-
dataset.item_features
|
|
195
|
+
dataset.item_features,
|
|
195
196
|
)
|
|
196
197
|
|
|
197
198
|
# We sort by QUERY_ID to make sure order is deterministic
|
|
@@ -211,7 +212,6 @@ class SequenceTokenizer:
|
|
|
211
212
|
|
|
212
213
|
return grouped_interactions_pd, query_features_pd, item_features_pd
|
|
213
214
|
|
|
214
|
-
# pylint: disable=too-many-arguments
|
|
215
215
|
def _make_sequence_features(
|
|
216
216
|
self,
|
|
217
217
|
schema: TensorSchema,
|
|
@@ -298,24 +298,27 @@ class SequenceTokenizer:
|
|
|
298
298
|
for tensor_feature in tensor_schema.all_features:
|
|
299
299
|
feature_sources = tensor_feature.feature_sources
|
|
300
300
|
if not feature_sources:
|
|
301
|
-
|
|
301
|
+
msg = "All tensor features must have sources defined"
|
|
302
|
+
raise ValueError(msg)
|
|
302
303
|
|
|
303
304
|
source_tables: List[FeatureSource] = [s.source for s in feature_sources]
|
|
304
305
|
|
|
305
306
|
unexpected_tables = list(filter(lambda x: not isinstance(x, FeatureSource), source_tables))
|
|
306
307
|
if len(unexpected_tables) > 0:
|
|
307
|
-
|
|
308
|
+
msg = f"Found unexpected source tables: {unexpected_tables}"
|
|
309
|
+
raise ValueError(msg)
|
|
308
310
|
|
|
309
311
|
if not tensor_feature.is_seq:
|
|
310
312
|
if FeatureSource.INTERACTIONS in source_tables:
|
|
311
|
-
|
|
313
|
+
msg = "Interaction features must be treated as sequential"
|
|
314
|
+
raise ValueError(msg)
|
|
312
315
|
|
|
313
316
|
if FeatureSource.ITEM_FEATURES in source_tables:
|
|
314
|
-
|
|
317
|
+
msg = "Item features must be treated as sequential"
|
|
318
|
+
raise ValueError(msg)
|
|
315
319
|
|
|
316
|
-
# pylint: disable=too-many-branches
|
|
317
320
|
@classmethod
|
|
318
|
-
def _check_if_tensor_schema_matches_data(
|
|
321
|
+
def _check_if_tensor_schema_matches_data( # noqa: C901
|
|
319
322
|
cls,
|
|
320
323
|
dataset: Dataset,
|
|
321
324
|
tensor_schema: TensorSchema,
|
|
@@ -324,77 +327,205 @@ class SequenceTokenizer:
|
|
|
324
327
|
# Check if all source columns specified in tensor schema exist in provided data frames
|
|
325
328
|
sources_for_tensors: List[TensorFeatureSource] = []
|
|
326
329
|
for tensor_feature_name, tensor_feature in tensor_schema.items():
|
|
327
|
-
if
|
|
330
|
+
if tensor_features_to_keep is not None and tensor_feature_name not in tensor_features_to_keep:
|
|
328
331
|
continue
|
|
329
332
|
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
sources_for_tensors += feature_sources
|
|
333
|
+
if tensor_feature.feature_sources:
|
|
334
|
+
sources_for_tensors += tensor_feature.feature_sources
|
|
333
335
|
|
|
334
336
|
query_id_column = dataset.feature_schema.query_id_column
|
|
335
337
|
item_id_column = dataset.feature_schema.item_id_column
|
|
336
338
|
|
|
337
|
-
interaction_feature_columns =
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
339
|
+
interaction_feature_columns = {
|
|
340
|
+
*dataset.feature_schema.interaction_features.columns,
|
|
341
|
+
query_id_column,
|
|
342
|
+
item_id_column,
|
|
343
|
+
}
|
|
344
|
+
query_feature_columns = {*dataset.feature_schema.query_features.columns, query_id_column}
|
|
345
|
+
item_feature_columns = {*dataset.feature_schema.item_features.columns, item_id_column}
|
|
342
346
|
|
|
343
347
|
for feature_source in sources_for_tensors:
|
|
344
348
|
assert feature_source is not None
|
|
345
349
|
if feature_source.source == FeatureSource.INTERACTIONS:
|
|
346
350
|
if feature_source.column not in interaction_feature_columns:
|
|
347
|
-
|
|
351
|
+
msg = f"Expected column '{feature_source.column}' in dataset"
|
|
352
|
+
raise ValueError(msg)
|
|
348
353
|
elif feature_source.source == FeatureSource.QUERY_FEATURES:
|
|
349
354
|
if dataset.query_features is None:
|
|
350
|
-
|
|
355
|
+
msg = f"Expected column '{feature_source.column}', but query features are not specified"
|
|
356
|
+
raise ValueError(msg)
|
|
351
357
|
if feature_source.column not in query_feature_columns:
|
|
352
|
-
|
|
358
|
+
msg = f"Expected column '{feature_source.column}' in query features data frame"
|
|
359
|
+
raise ValueError(msg)
|
|
353
360
|
elif feature_source.source == FeatureSource.ITEM_FEATURES:
|
|
354
361
|
if dataset.item_features is None:
|
|
355
|
-
|
|
362
|
+
msg = f"Expected column '{feature_source.column}', but item features are not specified"
|
|
363
|
+
raise ValueError(msg)
|
|
356
364
|
if feature_source.column not in item_feature_columns:
|
|
357
|
-
|
|
365
|
+
msg = f"Expected column '{feature_source.column}' in item features data frame"
|
|
366
|
+
raise ValueError(msg)
|
|
358
367
|
else:
|
|
359
|
-
|
|
368
|
+
msg = f"Found unexpected table '{feature_source.source}' in tensor schema"
|
|
369
|
+
raise ValueError(msg)
|
|
360
370
|
|
|
361
371
|
# Check if user ID and item ID columns are consistent with tensor schema
|
|
362
372
|
if tensor_schema.query_id_feature_name is not None:
|
|
363
373
|
tensor_feature = tensor_schema.query_id_features.item()
|
|
364
374
|
assert tensor_feature.feature_source
|
|
365
375
|
if tensor_feature.feature_source.column != dataset.feature_schema.query_id_column:
|
|
366
|
-
|
|
376
|
+
msg = "Tensor schema query ID source colum does not match query ID in data frame"
|
|
377
|
+
raise ValueError(msg)
|
|
367
378
|
|
|
368
379
|
if tensor_schema.item_id_feature_name is None:
|
|
369
|
-
|
|
380
|
+
msg = "Tensor schema must have item id feature defined"
|
|
381
|
+
raise ValueError(msg)
|
|
370
382
|
|
|
371
383
|
tensor_feature = tensor_schema.item_id_features.item()
|
|
372
384
|
assert tensor_feature.feature_source
|
|
373
385
|
if tensor_feature.feature_source.column != dataset.feature_schema.item_id_column:
|
|
374
|
-
|
|
386
|
+
msg = "Tensor schema item ID source colum does not match item ID in data frame"
|
|
387
|
+
raise ValueError(msg)
|
|
388
|
+
|
|
389
|
+
def _assign_tensor_features_cardinality(self, dataset: Dataset) -> None:
|
|
390
|
+
for tensor_feature in self._tensor_schema.categorical_features.all_features:
|
|
391
|
+
dataset_feature = dataset.feature_schema[tensor_feature.feature_source.column]
|
|
392
|
+
if tensor_feature.cardinality is not None:
|
|
393
|
+
warnings.warn(
|
|
394
|
+
f"The specified cardinality of {tensor_feature.name} "
|
|
395
|
+
f"will be replaced by {dataset_feature.column} from Dataset"
|
|
396
|
+
)
|
|
397
|
+
if dataset_feature.feature_type != FeatureType.CATEGORICAL:
|
|
398
|
+
error_msg = (
|
|
399
|
+
f"TensorFeatureInfo {tensor_feature.name} "
|
|
400
|
+
f"and FeatureInfo {dataset_feature.column} must be the same FeatureType"
|
|
401
|
+
)
|
|
402
|
+
raise RuntimeError(error_msg)
|
|
403
|
+
tensor_feature._set_cardinality(dataset_feature.cardinality)
|
|
375
404
|
|
|
376
405
|
@classmethod
|
|
377
|
-
|
|
406
|
+
@deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
|
|
407
|
+
def load(cls, path: str, use_pickle: bool = False) -> "SequenceTokenizer":
|
|
378
408
|
"""
|
|
379
409
|
Load tokenizer object from the given path.
|
|
380
410
|
|
|
381
411
|
:param path: Path to load the tokenizer.
|
|
412
|
+
:param use_pickle: If `False` - tokenizer will be loaded from `.replay` directory.
|
|
413
|
+
If `True` - tokenizer will be loaded with pickle.
|
|
414
|
+
Default: `False`.
|
|
382
415
|
|
|
383
416
|
:returns: Loaded tokenizer object.
|
|
384
417
|
"""
|
|
385
|
-
|
|
386
|
-
|
|
418
|
+
if not use_pickle:
|
|
419
|
+
base_path = Path(path).with_suffix(".replay").resolve()
|
|
420
|
+
with open(base_path / "init_args.json", "r") as file:
|
|
421
|
+
tokenizer_dict = json.loads(file.read())
|
|
422
|
+
|
|
423
|
+
# load tensor_schema, tensor_features
|
|
424
|
+
tensor_schema_data = tokenizer_dict["init_args"]["tensor_schema"]
|
|
425
|
+
features_list = []
|
|
426
|
+
for feature_data in tensor_schema_data:
|
|
427
|
+
feature_data["feature_sources"] = [
|
|
428
|
+
TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
|
|
429
|
+
for x in feature_data["feature_sources"]
|
|
430
|
+
]
|
|
431
|
+
f_type = feature_data["feature_type"]
|
|
432
|
+
f_hint = feature_data["feature_hint"]
|
|
433
|
+
feature_data["feature_type"] = FeatureType[f_type] if f_type else None
|
|
434
|
+
feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
|
|
435
|
+
features_list.append(TensorFeatureInfo(**feature_data))
|
|
436
|
+
tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema(features_list)
|
|
437
|
+
|
|
438
|
+
# Load encoder columns and rules
|
|
439
|
+
types = list(FeatureHint) + list(FeatureSource)
|
|
440
|
+
map_types = {x.name: x for x in types}
|
|
441
|
+
encoder_features_columns = {
|
|
442
|
+
map_types[key]: value for key, value in tokenizer_dict["encoder"]["features_columns"].items()
|
|
443
|
+
}
|
|
444
|
+
|
|
445
|
+
rules_dict = tokenizer_dict["encoder"]["encoding_rules"]
|
|
446
|
+
for rule in rules_dict:
|
|
447
|
+
rule_data = rules_dict[rule]
|
|
448
|
+
if rule_data["mapping"] and rule_data["is_int"]:
|
|
449
|
+
rule_data["mapping"] = {int(key): value for key, value in rule_data["mapping"].items()}
|
|
450
|
+
del rule_data["is_int"]
|
|
451
|
+
|
|
452
|
+
tokenizer_dict["encoder"]["encoding_rules"][rule] = LabelEncodingRule(**rule_data)
|
|
453
|
+
|
|
454
|
+
# Init tokenizer
|
|
455
|
+
tokenizer = cls(**tokenizer_dict["init_args"])
|
|
456
|
+
tokenizer._encoder._features_columns = encoder_features_columns
|
|
457
|
+
tokenizer._encoder._encoding_rules = tokenizer_dict["encoder"]["encoding_rules"]
|
|
458
|
+
else:
|
|
459
|
+
with open(path, "rb") as file:
|
|
460
|
+
tokenizer = pickle.load(file)
|
|
387
461
|
|
|
388
462
|
return tokenizer
|
|
389
463
|
|
|
390
|
-
|
|
464
|
+
@deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
|
|
465
|
+
def save(self, path: str, use_pickle: bool = False) -> None:
|
|
391
466
|
"""
|
|
392
467
|
Save the tokenizer to the given path.
|
|
393
468
|
|
|
394
469
|
:param path: Path to save the tokenizer.
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
470
|
+
:param use_pickle: If `False` - tokenizer will be saved in `.replay` directory.
|
|
471
|
+
If `True` - tokenizer will be saved with pickle.
|
|
472
|
+
Default: `False`.
|
|
473
|
+
"""
|
|
474
|
+
if not use_pickle:
|
|
475
|
+
tokenizer_dict = {}
|
|
476
|
+
tokenizer_dict["_class_name"] = self.__class__.__name__
|
|
477
|
+
tokenizer_dict["init_args"] = {
|
|
478
|
+
"allow_collect_to_master": self._allow_collect_to_master,
|
|
479
|
+
"handle_unknown_rule": self._encoder._handle_unknown_rule,
|
|
480
|
+
"default_value_rule": self._encoder._default_value_rule,
|
|
481
|
+
"tensor_schema": [],
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
# save tensor schema
|
|
485
|
+
for feature in list(self._tensor_schema.values()):
|
|
486
|
+
tokenizer_dict["init_args"]["tensor_schema"].append(
|
|
487
|
+
{
|
|
488
|
+
"name": feature.name,
|
|
489
|
+
"feature_type": feature.feature_type.name,
|
|
490
|
+
"is_seq": feature.is_seq,
|
|
491
|
+
"feature_hint": feature.feature_hint.name if feature.feature_hint else None,
|
|
492
|
+
"feature_sources": [
|
|
493
|
+
{"source": x.source.name, "column": x.column, "index": x.index}
|
|
494
|
+
for x in feature.feature_sources
|
|
495
|
+
]
|
|
496
|
+
if feature.feature_sources
|
|
497
|
+
else None,
|
|
498
|
+
"cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
|
|
499
|
+
"embedding_dim": feature.embedding_dim
|
|
500
|
+
if feature.feature_type == FeatureType.CATEGORICAL
|
|
501
|
+
else None,
|
|
502
|
+
"tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
|
|
503
|
+
}
|
|
504
|
+
)
|
|
505
|
+
|
|
506
|
+
# save DatasetLabelEncoder
|
|
507
|
+
tokenizer_dict["encoder"] = {
|
|
508
|
+
"features_columns": {key.name: value for key, value in self._encoder._features_columns.items()},
|
|
509
|
+
"encoding_rules": {
|
|
510
|
+
key: {
|
|
511
|
+
"column": value.column,
|
|
512
|
+
"mapping": value._mapping,
|
|
513
|
+
"handle_unknown": value._handle_unknown,
|
|
514
|
+
"default_value": value._default_value,
|
|
515
|
+
"is_int": isinstance(next(iter(value._mapping.keys())), int),
|
|
516
|
+
}
|
|
517
|
+
for key, value in self._encoder._encoding_rules.items()
|
|
518
|
+
},
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
base_path = Path(path).with_suffix(".replay").resolve()
|
|
522
|
+
base_path.mkdir(parents=True, exist_ok=True)
|
|
523
|
+
|
|
524
|
+
with open(base_path / "init_args.json", "w+") as file:
|
|
525
|
+
json.dump(tokenizer_dict, file)
|
|
526
|
+
else:
|
|
527
|
+
with open(path, "wb") as file:
|
|
528
|
+
pickle.dump(self, file)
|
|
398
529
|
|
|
399
530
|
|
|
400
531
|
class _SequenceProcessor:
|
|
@@ -409,7 +540,6 @@ class _SequenceProcessor:
|
|
|
409
540
|
with passing all tensor features one by one.
|
|
410
541
|
"""
|
|
411
542
|
|
|
412
|
-
# pylint: disable=too-many-arguments
|
|
413
543
|
def __init__(
|
|
414
544
|
self,
|
|
415
545
|
tensor_schema: TensorSchema,
|
|
@@ -462,13 +592,9 @@ class _SequenceProcessor:
|
|
|
462
592
|
for tensor_feature_name in self._tensor_schema:
|
|
463
593
|
tensor_feature = self._tensor_schema[tensor_feature_name]
|
|
464
594
|
if tensor_feature.is_cat:
|
|
465
|
-
data = data.join(
|
|
466
|
-
self._process_cat_feature(tensor_feature), on=self._query_id_column, how="left"
|
|
467
|
-
)
|
|
595
|
+
data = data.join(self._process_cat_feature(tensor_feature), on=self._query_id_column, how="left")
|
|
468
596
|
elif tensor_feature.is_num:
|
|
469
|
-
data = data.join(
|
|
470
|
-
self._process_num_feature(tensor_feature), on=self._query_id_column, how="left"
|
|
471
|
-
)
|
|
597
|
+
data = data.join(self._process_num_feature(tensor_feature), on=self._query_id_column, how="left")
|
|
472
598
|
else:
|
|
473
599
|
assert False, "Unknown tensor feature type"
|
|
474
600
|
return data
|
|
@@ -490,38 +616,40 @@ class _SequenceProcessor:
|
|
|
490
616
|
def get_sequence(user, source, data):
|
|
491
617
|
if source.source == FeatureSource.INTERACTIONS:
|
|
492
618
|
return np.array(
|
|
493
|
-
self._grouped_interactions
|
|
494
|
-
.
|
|
495
|
-
dtype=np.float32
|
|
619
|
+
self._grouped_interactions.filter(pl.col(self._query_id_column) == user)[source.column][0],
|
|
620
|
+
dtype=np.float32,
|
|
496
621
|
).tolist()
|
|
497
622
|
elif source.source == FeatureSource.ITEM_FEATURES:
|
|
498
623
|
return (
|
|
499
624
|
pl.DataFrame({self._item_id_column: data})
|
|
500
625
|
.join(self._item_features, on=self._item_id_column, how="left")
|
|
501
|
-
.select(source.column)
|
|
626
|
+
.select(source.column)
|
|
627
|
+
.to_numpy()
|
|
628
|
+
.reshape(-1)
|
|
629
|
+
.tolist()
|
|
502
630
|
)
|
|
503
631
|
else:
|
|
504
632
|
assert False, "Unknown tensor feature source table"
|
|
633
|
+
|
|
505
634
|
result = (
|
|
506
|
-
self._grouped_interactions
|
|
507
|
-
|
|
508
|
-
.map_rows(
|
|
509
|
-
lambda x:
|
|
510
|
-
(
|
|
511
|
-
x[0],
|
|
512
|
-
[get_sequence(x[0], source, x[1])
|
|
513
|
-
for source in tensor_feature.feature_sources]
|
|
514
|
-
)
|
|
635
|
+
self._grouped_interactions.select(self._query_id_column, self._item_id_column).map_rows(
|
|
636
|
+
lambda x: (x[0], [get_sequence(x[0], source, x[1]) for source in tensor_feature.feature_sources])
|
|
515
637
|
)
|
|
516
638
|
).rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})
|
|
517
639
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
640
|
+
if tensor_feature.feature_hint == FeatureHint.TIMESTAMP:
|
|
641
|
+
reshape_size = -1
|
|
642
|
+
else:
|
|
643
|
+
reshape_size = (-1, len(tensor_feature.feature_sources))
|
|
644
|
+
|
|
645
|
+
return pl.DataFrame(
|
|
646
|
+
{
|
|
647
|
+
self._query_id_column: result[self._query_id_column].to_list(),
|
|
648
|
+
tensor_feature.name: [
|
|
649
|
+
np.array(x).reshape(reshape_size).tolist() for x in result[tensor_feature.name].to_list()
|
|
650
|
+
],
|
|
651
|
+
}
|
|
652
|
+
)
|
|
525
653
|
|
|
526
654
|
def _process_num_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
|
|
527
655
|
"""
|
|
@@ -554,7 +682,10 @@ class _SequenceProcessor:
|
|
|
554
682
|
else:
|
|
555
683
|
assert False, "Unknown tensor feature source table"
|
|
556
684
|
all_seqs = np.array(all_features_for_user, dtype=np.float32)
|
|
557
|
-
|
|
685
|
+
if tensor_feature.feature_hint == FeatureHint.TIMESTAMP:
|
|
686
|
+
all_seqs = all_seqs.reshape(-1)
|
|
687
|
+
else:
|
|
688
|
+
all_seqs = all_seqs.reshape(-1, (len(tensor_feature.feature_sources)))
|
|
558
689
|
values.append(all_seqs)
|
|
559
690
|
return values
|
|
560
691
|
|
|
@@ -572,9 +703,9 @@ class _SequenceProcessor:
|
|
|
572
703
|
assert source is not None
|
|
573
704
|
|
|
574
705
|
if self._is_polars:
|
|
575
|
-
return self._grouped_interactions.select(
|
|
576
|
-
|
|
577
|
-
)
|
|
706
|
+
return self._grouped_interactions.select(self._query_id_column, source.column).rename(
|
|
707
|
+
{source.column: tensor_feature.name}
|
|
708
|
+
)
|
|
578
709
|
|
|
579
710
|
return [np.array(sequence, dtype=np.int64) for sequence in self._grouped_interactions[source.column]]
|
|
580
711
|
|
|
@@ -603,9 +734,9 @@ class _SequenceProcessor:
|
|
|
603
734
|
result = self._query_features
|
|
604
735
|
repeat_value = 1
|
|
605
736
|
|
|
606
|
-
return result.select(
|
|
607
|
-
|
|
608
|
-
)
|
|
737
|
+
return result.select(self._query_id_column, pl.col(source.column).repeat_by(repeat_value)).rename(
|
|
738
|
+
{source.column: tensor_feature.name}
|
|
739
|
+
)
|
|
609
740
|
|
|
610
741
|
query_feature = self._query_features[source.column].values
|
|
611
742
|
if tensor_feature.is_seq:
|
|
@@ -632,20 +763,19 @@ class _SequenceProcessor:
|
|
|
632
763
|
|
|
633
764
|
if self._is_polars:
|
|
634
765
|
return (
|
|
635
|
-
self._grouped_interactions
|
|
636
|
-
.select(self._query_id_column, self._item_id_column)
|
|
766
|
+
self._grouped_interactions.select(self._query_id_column, self._item_id_column)
|
|
637
767
|
.map_rows(
|
|
638
|
-
lambda x:
|
|
639
|
-
(
|
|
768
|
+
lambda x: (
|
|
640
769
|
x[0],
|
|
641
770
|
pl.DataFrame({self._item_id_column: x[1]})
|
|
642
771
|
.join(self._item_features, on=self._item_id_column, how="left")
|
|
643
|
-
.select(source.column)
|
|
772
|
+
.select(source.column)
|
|
773
|
+
.to_numpy()
|
|
774
|
+
.reshape(-1)
|
|
775
|
+
.tolist(),
|
|
644
776
|
)
|
|
645
|
-
)
|
|
646
|
-
|
|
647
|
-
"column_1": tensor_feature.name
|
|
648
|
-
})
|
|
777
|
+
)
|
|
778
|
+
.rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})
|
|
649
779
|
)
|
|
650
780
|
|
|
651
781
|
item_feature = self._item_features[source.column]
|
|
@@ -6,11 +6,9 @@ import polars as pl
|
|
|
6
6
|
from pandas import DataFrame as PandasDataFrame
|
|
7
7
|
from polars import DataFrame as PolarsDataFrame
|
|
8
8
|
|
|
9
|
-
from replay.data.schema import FeatureType
|
|
10
9
|
from .schema import TensorSchema
|
|
11
10
|
|
|
12
11
|
|
|
13
|
-
# pylint: disable=missing-function-docstring
|
|
14
12
|
class SequentialDataset(abc.ABC):
|
|
15
13
|
"""
|
|
16
14
|
Abstract base class for sequential dataset
|
|
@@ -132,19 +130,9 @@ class PandasSequentialDataset(SequentialDataset):
|
|
|
132
130
|
|
|
133
131
|
self._sequences = sequences
|
|
134
132
|
|
|
135
|
-
for feature in tensor_schema.all_features:
|
|
136
|
-
if feature.feature_type == FeatureType.CATEGORICAL:
|
|
137
|
-
# pylint: disable=protected-access
|
|
138
|
-
feature._set_cardinality_callback(self.cardinality_callback)
|
|
139
|
-
|
|
140
133
|
def __len__(self) -> int:
|
|
141
134
|
return len(self._sequences)
|
|
142
135
|
|
|
143
|
-
def cardinality_callback(self, column: str) -> int:
|
|
144
|
-
if self._query_id_column == column:
|
|
145
|
-
return self._sequences.index.nunique()
|
|
146
|
-
return len({x for seq in self._sequences[column] for x in seq})
|
|
147
|
-
|
|
148
136
|
def get_query_id(self, index: int) -> int:
|
|
149
137
|
return self._sequences.index[index]
|
|
150
138
|
|
|
@@ -181,12 +169,12 @@ class PandasSequentialDataset(SequentialDataset):
|
|
|
181
169
|
|
|
182
170
|
@classmethod
|
|
183
171
|
def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PandasDataFrame) -> None:
|
|
184
|
-
for tensor_feature_name in tensor_schema
|
|
172
|
+
for tensor_feature_name in tensor_schema:
|
|
185
173
|
if tensor_feature_name not in data:
|
|
186
|
-
|
|
174
|
+
msg = "Tensor schema does not match with provided data frame"
|
|
175
|
+
raise ValueError(msg)
|
|
187
176
|
|
|
188
177
|
|
|
189
|
-
# pylint:disable=super-init-not-called
|
|
190
178
|
class PolarsSequentialDataset(PandasSequentialDataset):
|
|
191
179
|
"""
|
|
192
180
|
Sequential dataset that stores sequences in PolarsDataFrame format.
|
|
@@ -215,11 +203,6 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
215
203
|
if self._sequences.index.name != query_id_column:
|
|
216
204
|
self._sequences = self._sequences.set_index(query_id_column)
|
|
217
205
|
|
|
218
|
-
for feature in tensor_schema.all_features:
|
|
219
|
-
if feature.feature_type == FeatureType.CATEGORICAL:
|
|
220
|
-
# pylint: disable=protected-access
|
|
221
|
-
feature._set_cardinality_callback(self.cardinality_callback)
|
|
222
|
-
|
|
223
206
|
def filter_by_query_id(self, query_ids_to_keep: np.ndarray) -> "PolarsSequentialDataset":
|
|
224
207
|
filtered_sequences = self._sequences.loc[query_ids_to_keep]
|
|
225
208
|
if filtered_sequences.index.name == self._query_id_column:
|
|
@@ -233,6 +216,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
233
216
|
|
|
234
217
|
@classmethod
|
|
235
218
|
def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PolarsDataFrame) -> None:
|
|
236
|
-
for tensor_feature_name in tensor_schema
|
|
219
|
+
for tensor_feature_name in tensor_schema:
|
|
237
220
|
if tensor_feature_name not in data:
|
|
238
|
-
|
|
221
|
+
msg = "Tensor schema does not match with provided data frame"
|
|
222
|
+
raise ValueError(msg)
|
|
@@ -14,6 +14,7 @@ class TorchSequentialBatch(NamedTuple):
|
|
|
14
14
|
"""
|
|
15
15
|
Batch of TorchSequentialDataset
|
|
16
16
|
"""
|
|
17
|
+
|
|
17
18
|
query_id: torch.LongTensor
|
|
18
19
|
padding_mask: torch.BoolTensor
|
|
19
20
|
features: TensorMap
|
|
@@ -88,7 +89,7 @@ class TorchSequentialDataset(TorchDataset):
|
|
|
88
89
|
) -> torch.Tensor:
|
|
89
90
|
sequence = self._sequential.get_sequence(sequence_index, feature.name)
|
|
90
91
|
if feature.is_seq:
|
|
91
|
-
sequence = sequence[sequence_offset : sequence_offset + self._max_sequence_length]
|
|
92
|
+
sequence = sequence[sequence_offset : sequence_offset + self._max_sequence_length]
|
|
92
93
|
|
|
93
94
|
tensor_dtype = self._get_tensor_dtype(feature)
|
|
94
95
|
tensor_sequence = torch.tensor(sequence, dtype=tensor_dtype)
|
|
@@ -109,14 +110,15 @@ class TorchSequentialDataset(TorchDataset):
|
|
|
109
110
|
elif len(sequence.shape) == 2:
|
|
110
111
|
padded_sequence_shape = (self._max_sequence_length, sequence.shape[1])
|
|
111
112
|
else:
|
|
112
|
-
|
|
113
|
+
msg = f"Unsupported shape for sequence: {len(sequence.shape)}"
|
|
114
|
+
raise ValueError(msg)
|
|
113
115
|
|
|
114
116
|
padded_sequence = torch.full(
|
|
115
117
|
padded_sequence_shape,
|
|
116
118
|
self._padding_value,
|
|
117
119
|
dtype=sequence.dtype,
|
|
118
120
|
)
|
|
119
|
-
padded_sequence[-len(sequence) :].copy_(sequence)
|
|
121
|
+
padded_sequence[-len(sequence) :].copy_(sequence)
|
|
120
122
|
return padded_sequence
|
|
121
123
|
|
|
122
124
|
def _get_tensor_dtype(self, feature: TensorFeatureInfo) -> torch.dtype:
|
|
@@ -151,6 +153,7 @@ class TorchSequentialValidationBatch(NamedTuple):
|
|
|
151
153
|
"""
|
|
152
154
|
Batch of TorchSequentialValidationDataset
|
|
153
155
|
"""
|
|
156
|
+
|
|
154
157
|
query_id: torch.LongTensor
|
|
155
158
|
padding_mask: torch.BoolTensor
|
|
156
159
|
features: TensorMap
|
|
@@ -167,7 +170,6 @@ class TorchSequentialValidationDataset(TorchDataset):
|
|
|
167
170
|
Torch dataset for sequential recommender models that additionally stores ground truth
|
|
168
171
|
"""
|
|
169
172
|
|
|
170
|
-
# pylint: disable=too-many-arguments
|
|
171
173
|
def __init__(
|
|
172
174
|
self,
|
|
173
175
|
sequential: SequentialDataset,
|
|
@@ -195,19 +197,24 @@ class TorchSequentialValidationDataset(TorchDataset):
|
|
|
195
197
|
|
|
196
198
|
if label_feature_name:
|
|
197
199
|
if label_feature_name not in ground_truth.schema:
|
|
198
|
-
|
|
200
|
+
msg = "Label feature name not found in ground truth schema"
|
|
201
|
+
raise ValueError(msg)
|
|
199
202
|
|
|
200
203
|
if label_feature_name not in train.schema:
|
|
201
|
-
|
|
204
|
+
msg = "Label feature name not found in train schema"
|
|
205
|
+
raise ValueError(msg)
|
|
202
206
|
|
|
203
207
|
if not ground_truth.schema[label_feature_name].is_cat:
|
|
204
|
-
|
|
208
|
+
msg = "Label feature must be categorical"
|
|
209
|
+
raise ValueError(msg)
|
|
205
210
|
|
|
206
211
|
if not ground_truth.schema[label_feature_name].is_seq:
|
|
207
|
-
|
|
212
|
+
msg = "Label feature must be sequential"
|
|
213
|
+
raise ValueError(msg)
|
|
208
214
|
|
|
209
215
|
if len(np.intersect1d(sequential.get_all_query_ids(), ground_truth.get_all_query_ids())) == 0:
|
|
210
|
-
|
|
216
|
+
msg = "Sequential data and ground truth must contain the same query IDs"
|
|
217
|
+
raise ValueError(msg)
|
|
211
218
|
|
|
212
219
|
self._ground_truth = ground_truth
|
|
213
220
|
self._train = train
|
|
@@ -271,7 +278,9 @@ class TorchSequentialValidationDataset(TorchDataset):
|
|
|
271
278
|
ground_truth_item_feature = ground_truth_schema.item_id_features.item()
|
|
272
279
|
|
|
273
280
|
if sequential_item_feature.name != ground_truth_item_feature.name:
|
|
274
|
-
|
|
281
|
+
msg = "Schema mismatch: item feature name does not match ground truth"
|
|
282
|
+
raise ValueError(msg)
|
|
275
283
|
|
|
276
284
|
if sequential_item_feature.cardinality != ground_truth_item_feature.cardinality:
|
|
277
|
-
|
|
285
|
+
msg = "Schema mismatch: item feature cardinality does not match ground truth"
|
|
286
|
+
raise ValueError(msg)
|