replay-rec 0.19.0rc0-py3-none-any.whl → 0.20.0rc0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. replay/__init__.py +6 -2
  2. replay/data/dataset.py +19 -18
  3. replay/data/dataset_utils/dataset_label_encoder.py +5 -4
  4. replay/data/nn/__init__.py +6 -6
  5. replay/data/nn/schema.py +9 -18
  6. replay/data/nn/sequence_tokenizer.py +54 -47
  7. replay/data/nn/sequential_dataset.py +16 -11
  8. replay/data/nn/torch_sequential_dataset.py +18 -16
  9. replay/data/nn/utils.py +3 -2
  10. replay/data/schema.py +3 -12
  11. replay/experimental/metrics/base_metric.py +6 -5
  12. replay/experimental/metrics/coverage.py +5 -5
  13. replay/experimental/metrics/experiment.py +2 -2
  14. replay/experimental/models/__init__.py +38 -1
  15. replay/experimental/models/admm_slim.py +59 -7
  16. replay/experimental/models/base_neighbour_rec.py +6 -10
  17. replay/experimental/models/base_rec.py +58 -12
  18. replay/experimental/models/base_torch_rec.py +2 -2
  19. replay/experimental/models/cql.py +6 -6
  20. replay/experimental/models/ddpg.py +47 -38
  21. replay/experimental/models/dt4rec/dt4rec.py +3 -3
  22. replay/experimental/models/dt4rec/utils.py +4 -5
  23. replay/experimental/models/extensions/spark_custom_models/als_extension.py +5 -5
  24. replay/experimental/models/lightfm_wrap.py +4 -3
  25. replay/experimental/models/mult_vae.py +4 -4
  26. replay/experimental/models/neural_ts.py +13 -13
  27. replay/experimental/models/neuromf.py +4 -4
  28. replay/experimental/models/scala_als.py +14 -17
  29. replay/experimental/nn/data/schema_builder.py +4 -4
  30. replay/experimental/preprocessing/data_preparator.py +13 -13
  31. replay/experimental/preprocessing/padder.py +7 -7
  32. replay/experimental/preprocessing/sequence_generator.py +7 -7
  33. replay/experimental/scenarios/obp_wrapper/__init__.py +4 -4
  34. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +5 -5
  35. replay/experimental/scenarios/obp_wrapper/replay_offline.py +4 -4
  36. replay/experimental/scenarios/obp_wrapper/utils.py +3 -5
  37. replay/experimental/scenarios/two_stages/reranker.py +4 -4
  38. replay/experimental/scenarios/two_stages/two_stages_scenario.py +18 -18
  39. replay/experimental/utils/session_handler.py +2 -2
  40. replay/metrics/base_metric.py +12 -11
  41. replay/metrics/categorical_diversity.py +8 -8
  42. replay/metrics/coverage.py +11 -15
  43. replay/metrics/experiment.py +6 -6
  44. replay/metrics/hitrate.py +1 -3
  45. replay/metrics/map.py +1 -3
  46. replay/metrics/mrr.py +1 -3
  47. replay/metrics/ndcg.py +1 -2
  48. replay/metrics/novelty.py +3 -3
  49. replay/metrics/offline_metrics.py +18 -18
  50. replay/metrics/precision.py +1 -3
  51. replay/metrics/recall.py +1 -3
  52. replay/metrics/rocauc.py +1 -3
  53. replay/metrics/surprisal.py +4 -4
  54. replay/metrics/torch_metrics_builder.py +13 -12
  55. replay/metrics/unexpectedness.py +2 -2
  56. replay/models/__init__.py +19 -0
  57. replay/models/als.py +2 -2
  58. replay/models/association_rules.py +5 -7
  59. replay/models/base_neighbour_rec.py +8 -10
  60. replay/models/base_rec.py +54 -302
  61. replay/models/cat_pop_rec.py +4 -2
  62. replay/models/common.py +69 -0
  63. replay/models/extensions/ann/ann_mixin.py +31 -25
  64. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
  65. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
  66. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
  67. replay/models/extensions/ann/utils.py +4 -3
  68. replay/models/knn.py +18 -17
  69. replay/models/lin_ucb.py +3 -3
  70. replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
  71. replay/models/nn/sequential/bert4rec/dataset.py +3 -3
  72. replay/models/nn/sequential/bert4rec/lightning.py +3 -3
  73. replay/models/nn/sequential/bert4rec/model.py +2 -2
  74. replay/models/nn/sequential/callbacks/prediction_callbacks.py +14 -14
  75. replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
  76. replay/models/nn/sequential/compiled/__init__.py +10 -0
  77. replay/models/nn/sequential/compiled/base_compiled_model.py +8 -6
  78. replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
  79. replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
  80. replay/models/nn/sequential/postprocessors/_base.py +2 -3
  81. replay/models/nn/sequential/postprocessors/postprocessors.py +10 -10
  82. replay/models/nn/sequential/sasrec/dataset.py +1 -1
  83. replay/models/nn/sequential/sasrec/lightning.py +3 -3
  84. replay/models/nn/sequential/sasrec/model.py +9 -9
  85. replay/models/optimization/__init__.py +14 -0
  86. replay/models/optimization/optuna_mixin.py +279 -0
  87. replay/{optimization → models/optimization}/optuna_objective.py +13 -15
  88. replay/models/slim.py +4 -6
  89. replay/models/ucb.py +2 -2
  90. replay/models/word2vec.py +9 -14
  91. replay/preprocessing/discretizer.py +9 -9
  92. replay/preprocessing/filters.py +4 -4
  93. replay/preprocessing/history_based_fp.py +7 -7
  94. replay/preprocessing/label_encoder.py +9 -8
  95. replay/scenarios/fallback.py +4 -3
  96. replay/splitters/base_splitter.py +3 -3
  97. replay/splitters/cold_user_random_splitter.py +17 -11
  98. replay/splitters/k_folds.py +4 -4
  99. replay/splitters/last_n_splitter.py +27 -20
  100. replay/splitters/new_users_splitter.py +4 -4
  101. replay/splitters/random_splitter.py +4 -4
  102. replay/splitters/ratio_splitter.py +10 -10
  103. replay/splitters/time_splitter.py +6 -6
  104. replay/splitters/two_stage_splitter.py +4 -4
  105. replay/utils/__init__.py +7 -2
  106. replay/utils/common.py +5 -3
  107. replay/utils/model_handler.py +11 -31
  108. replay/utils/session_handler.py +4 -4
  109. replay/utils/spark_utils.py +8 -7
  110. replay/utils/types.py +31 -19
  111. replay/utils/warnings.py +26 -0
  112. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info}/METADATA +58 -42
  113. replay_rec-0.20.0rc0.dist-info/RECORD +194 -0
  114. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info}/WHEEL +1 -1
  115. replay/optimization/__init__.py +0 -5
  116. replay_rec-0.19.0rc0.dist-info/RECORD +0 -191
  117. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info/licenses}/LICENSE +0 -0
  118. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0rc0.dist-info/licenses}/NOTICE +0 -0
replay/__init__.py CHANGED
@@ -1,3 +1,7 @@
-""" RecSys library """
+"""RecSys library"""
 
-__version__ = "0.19.0.preview"
+# NOTE: This ensures distutils monkey-patching is performed before any
+# functionality removed in Python 3.12 is used in downstream packages (like lightfm)
+import setuptools as _
+
+__version__ = "0.20.0.preview"
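
The shim above works because importing setuptools registers a meta-path finder that serves "distutils" from setuptools' bundled copy; on Python 3.12, where distutils is gone from the standard library, a later "import distutils" in a downstream package then still resolves. A minimal sketch of the same pattern (the final import target is illustrative):

    # Import setuptools only for its side effect: it installs a finder that
    # makes "import distutils" resolve to setuptools._distutils on 3.12+.
    import setuptools as _  # noqa: F401

    # A package that still imports distutils internally (lightfm is the one
    # named in the NOTE above) can now be imported without ModuleNotFoundError.
    import lightfm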
replay/data/dataset.py CHANGED
@@ -5,8 +5,9 @@
 from __future__ import annotations
 
 import json
+from collections.abc import Iterable, Sequence
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 from pandas import read_parquet as pd_read_parquet
@@ -315,7 +316,7 @@ class Dataset:
         :returns: Loaded Dataset.
         """
         base_path = Path(path).with_suffix(".replay").resolve()
-        with open(base_path / "init_args.json", "r") as file:
+        with open(base_path / "init_args.json") as file:
             dataset_dict = json.loads(file.read())
 
         if dataframe_type not in ["pandas", "spark", "polars", None]:
@@ -436,14 +437,14 @@ class Dataset:
         )
 
     def _get_feature_source_map(self):
-        self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
+        self._feature_source_map: dict[FeatureSource, DataFrameLike] = {
            FeatureSource.INTERACTIONS: self.interactions,
            FeatureSource.QUERY_FEATURES: self.query_features,
            FeatureSource.ITEM_FEATURES: self.item_features,
        }
 
     def _get_ids_source_map(self):
-        self._ids_feature_map: Dict[FeatureHint, DataFrameLike] = {
+        self._ids_feature_map: dict[FeatureHint, DataFrameLike] = {
            FeatureHint.QUERY_ID: self.query_features if self.query_features is not None else self.interactions,
            FeatureHint.ITEM_ID: self.item_features if self.item_features is not None else self.interactions,
        }
@@ -499,10 +500,10 @@ class Dataset:
         )
         return FeatureSchema(features_list=features_list + filled_features)
 
-    def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+    def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
         features_list = list(feature_schema.all_features)
 
-        source_mapping: Dict[str, FeatureSource] = {}
+        source_mapping: dict[str, FeatureSource] = {}
         for source in FeatureSource:
             dataframe = self._feature_source_map[source]
             if dataframe is not None:
@@ -524,7 +525,7 @@ class Dataset:
         self._set_cardinality(features_list=features_list)
         return features_list
 
-    def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+    def _get_unlabeled_columns(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
         set_source_dataframe_columns = set(self._feature_source_map[source].columns)
         set_labeled_dataframe_columns = set(feature_schema.columns)
         unlabeled_columns = set_source_dataframe_columns - set_labeled_dataframe_columns
@@ -534,13 +535,13 @@ class Dataset:
         ]
         return unlabeled_features_list
 
-    def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) -> List[FeatureInfo]:
+    def _fill_unlabeled_features(self, source: FeatureSource, feature_schema: FeatureSchema) -> list[FeatureInfo]:
         unlabeled_columns = self._get_unlabeled_columns(source=source, feature_schema=feature_schema)
         self._set_features_source(feature_list=unlabeled_columns, source=source)
         self._set_cardinality(features_list=unlabeled_columns)
         return unlabeled_columns
 
-    def _set_features_source(self, feature_list: List[FeatureInfo], source: FeatureSource) -> None:
+    def _set_features_source(self, feature_list: list[FeatureInfo], source: FeatureSource) -> None:
         for feature in feature_list:
             feature._set_feature_source(source)
 
@@ -610,9 +611,9 @@ class Dataset:
         if self.is_pandas:
             try:
                 data[column] = data[column].astype(int)
-            except Exception:
+            except Exception as exc:
                 msg = f"IDs in {source.name}.{column} are not encoded. They are not int."
-                raise ValueError(msg)
+                raise ValueError(msg) from exc
 
         if self.is_pandas:
             is_int = np.issubdtype(dict(data.dtypes)[column], int)
@@ -775,10 +776,10 @@ def check_dataframes_types_equal(dataframe: DataFrameLike, other: DataFrameLike)
 
     :returns: True if dataframes have same type.
     """
-    if isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame):
-        return True
-    if isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame):
-        return True
-    if isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame):
-        return True
-    return False
+    return any(
+        [
+            isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame),
+            isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame),
+            isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame),
+        ]
+    )
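
Two idioms recur in this file's changes: check_dataframes_types_equal collapses three if/return pairs into a single any(...) over the pairwise type checks, and error handling adopts exception chaining with "raise ... from". Chaining keeps the original exception attached as __cause__, so the traceback still shows what actually failed inside the try block. A minimal sketch of the chaining pattern (the function name is illustrative):

    def parse_id(value: str) -> int:
        try:
            return int(value)
        except ValueError as exc:
            msg = f"ID {value!r} is not encoded. It is not an int."
            # "from exc" preserves the original error as __cause__ instead of
            # discarding it, unlike a bare "raise ValueError(msg)".
            raise ValueError(msg) from exc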
replay/data/dataset_utils/dataset_label_encoder.py CHANGED
@@ -6,7 +6,8 @@ Contains classes for encoding categorical data
 """
 
 import warnings
-from typing import Dict, Iterable, Iterator, Optional, Sequence, Set, Union
+from collections.abc import Iterable, Iterator, Sequence
+from typing import Optional, Union
 
 from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
 from replay.preprocessing import LabelEncoder, LabelEncodingRule, SequenceEncodingRule
@@ -45,9 +46,9 @@ class DatasetLabelEncoder:
         """
         self._handle_unknown_rule = handle_unknown_rule
         self._default_value_rule = default_value_rule
-        self._encoding_rules: Dict[str, LabelEncodingRule] = {}
+        self._encoding_rules: dict[str, LabelEncodingRule] = {}
 
-        self._features_columns: Dict[Union[FeatureHint, FeatureSource], Sequence[str]] = {}
+        self._features_columns: dict[Union[FeatureHint, FeatureSource], Sequence[str]] = {}
 
     def fit(self, dataset: Dataset) -> "DatasetLabelEncoder":
         """
@@ -161,7 +162,7 @@ class DatasetLabelEncoder:
         """
         self._check_if_initialized()
 
-        columns_set: Set[str]
+        columns_set: set[str]
         columns_set = {columns} if isinstance(columns, str) else {*columns}
 
         def get_encoding_rules() -> Iterator[LabelEncodingRule]:
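
The typing change above is the pattern repeated across nearly every file in this release: abstract containers (Iterable, Iterator, Sequence, ...) now come from collections.abc, and the deprecated typing.Dict/List/Set/Tuple aliases give way to the builtin generics standardized by PEP 585, available since Python 3.9. A before/after sketch with an illustrative function:

    # Before (deprecated aliases):
    #     from typing import Dict, Iterable, List
    #     def tally(xs: Iterable[str]) -> Dict[str, List[int]]: ...

    # After, in the style adopted by 0.20.0rc0:
    from collections.abc import Iterable

    def tally(xs: Iterable[str]) -> dict[str, list[int]]:
        # Builtin types are subscriptable directly on Python 3.9+.
        out: dict[str, list[int]] = {}
        for i, x in enumerate(xs):
            out.setdefault(x, []).append(i)
        return out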
replay/data/nn/__init__.py CHANGED
@@ -14,17 +14,17 @@ if TORCH_AVAILABLE:
     )
 
 __all__ = [
+    "DEFAULT_GROUND_TRUTH_PADDING_VALUE",
+    "DEFAULT_TRAIN_PADDING_VALUE",
     "MutableTensorMap",
+    "PandasSequentialDataset",
+    "PolarsSequentialDataset",
+    "SequenceTokenizer",
+    "SequentialDataset",
     "TensorFeatureInfo",
     "TensorFeatureSource",
     "TensorMap",
     "TensorSchema",
-    "SequenceTokenizer",
-    "PandasSequentialDataset",
-    "PolarsSequentialDataset",
-    "SequentialDataset",
-    "DEFAULT_GROUND_TRUTH_PADDING_VALUE",
-    "DEFAULT_TRAIN_PADDING_VALUE",
     "TorchSequentialBatch",
     "TorchSequentialDataset",
     "TorchSequentialValidationBatch",
replay/data/nn/schema.py CHANGED
@@ -1,17 +1,8 @@
+from collections import OrderedDict
+from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, Sequence, ValuesView
 from typing import (
-    Dict,
-    ItemsView,
-    Iterable,
-    Iterator,
-    KeysView,
-    List,
-    Mapping,
     Optional,
-    OrderedDict,
-    Sequence,
-    Set,
     Union,
-    ValuesView,
 )
 
 import torch
@@ -20,7 +11,7 @@ from replay.data import FeatureHint, FeatureSource, FeatureType
 
 # Alias
 TensorMap = Mapping[str, torch.Tensor]
-MutableTensorMap = Dict[str, torch.Tensor]
+MutableTensorMap = dict[str, torch.Tensor]
 
 
 class TensorFeatureSource:
@@ -79,7 +70,7 @@ class TensorFeatureInfo:
         feature_type: FeatureType,
         is_seq: bool = False,
         feature_hint: Optional[FeatureHint] = None,
-        feature_sources: Optional[List[TensorFeatureSource]] = None,
+        feature_sources: Optional[list[TensorFeatureSource]] = None,
         cardinality: Optional[int] = None,
         padding_value: int = 0,
         embedding_dim: Optional[int] = None,
@@ -154,13 +145,13 @@ class TensorFeatureInfo:
         self._feature_hint = hint
 
     @property
-    def feature_sources(self) -> Optional[List[TensorFeatureSource]]:
+    def feature_sources(self) -> Optional[list[TensorFeatureSource]]:
         """
         :returns: List of sources feature came from.
         """
         return self._feature_sources
 
-    def _set_feature_sources(self, sources: List[TensorFeatureSource]) -> None:
+    def _set_feature_sources(self, sources: list[TensorFeatureSource]) -> None:
         self._feature_sources = sources
 
     @property
@@ -276,7 +267,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
 
         :returns: New tensor schema of given features.
         """
-        features: Set[TensorFeatureInfo] = set()
+        features: set[TensorFeatureInfo] = set()
         for feature_name in features_to_keep:
             features.add(self._tensor_schema[feature_name])
         return TensorSchema(list(features))
@@ -432,7 +423,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
             return None
         return rating_features.item().name
 
-    def _get_object_args(self) -> Dict:
+    def _get_object_args(self) -> dict:
         """
         Returns list of features represented as dictionaries.
         """
@@ -456,7 +447,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
         return features
 
     @classmethod
-    def _create_object_by_args(cls, args: Dict) -> "TensorSchema":
+    def _create_object_by_args(cls, args: dict) -> "TensorSchema":
         features_list = []
         for feature_data in args:
             feature_data["feature_sources"] = (
replay/data/nn/sequence_tokenizer.py CHANGED
@@ -2,8 +2,9 @@ import abc
 import json
 import pickle
 import warnings
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Dict, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
 
 import numpy as np
 import polars as pl
@@ -14,11 +15,11 @@ from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, Feat
 from replay.data.dataset_utils import DatasetLabelEncoder
 from replay.preprocessing import LabelEncoder, LabelEncodingRule
 from replay.preprocessing.label_encoder import HandleUnknownStrategies
-from replay.utils.model_handler import deprecation_warning
+from replay.utils import deprecation_warning
 
-from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
-from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
-from .utils import ensure_pandas, groupby_sequences
+if TYPE_CHECKING:
+    from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
+    from .sequential_dataset import SequentialDataset
 
 SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
 _T = TypeVar("_T")
@@ -34,7 +35,7 @@ class SequenceTokenizer:
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         handle_unknown_rule: HandleUnknownStrategies = "error",
         default_value_rule: Optional[Union[int, str]] = None,
         allow_collect_to_master: bool = False,
@@ -77,7 +78,7 @@ class SequenceTokenizer:
         self,
         dataset: Dataset,
         tensor_features_to_keep: Optional[Sequence[str]] = None,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
         """
         :param dataset: input dataset to transform
         :param tensor_features_to_keep: specified feature names to transform
@@ -89,7 +90,7 @@ class SequenceTokenizer:
     def fit_transform(
         self,
         dataset: Dataset,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
         """
         :param dataset: input dataset to transform
         :returns: SequentialDataset
@@ -97,7 +98,7 @@ class SequenceTokenizer:
         return self.fit(dataset)._transform_unchecked(dataset)
 
     @property
-    def tensor_schema(self) -> TensorSchema:
+    def tensor_schema(self) -> "TensorSchema":
         """
         :returns: tensor schema
         """
@@ -149,7 +150,9 @@ class SequenceTokenizer:
         self,
         dataset: Dataset,
         tensor_features_to_keep: Optional[Sequence[str]] = None,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
+        from replay.data.nn.sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset
+
         schema = self._tensor_schema
         if tensor_features_to_keep is not None:
             schema = schema.subset(tensor_features_to_keep)
@@ -185,7 +188,9 @@ class SequenceTokenizer:
     def _group_dataset(
         self,
         dataset: Dataset,
-    ) -> Tuple[SequenceDataFrameLike, Optional[SequenceDataFrameLike], Optional[SequenceDataFrameLike]]:
+    ) -> tuple[SequenceDataFrameLike, Optional[SequenceDataFrameLike], Optional[SequenceDataFrameLike]]:
+        from replay.data.nn.utils import ensure_pandas, groupby_sequences
+
         grouped_interactions = groupby_sequences(
             events=dataset.interactions,
             groupby_col=dataset.feature_schema.query_id_column,
@@ -218,7 +223,7 @@ class SequenceTokenizer:
 
     def _make_sequence_features(
         self,
-        schema: TensorSchema,
+        schema: "TensorSchema",
         feature_schema: FeatureSchema,
         grouped_interactions: SequenceDataFrameLike,
         query_features: Optional[SequenceDataFrameLike],
@@ -242,7 +247,7 @@ class SequenceTokenizer:
     def _match_features_with_tensor_schema(
         cls,
         dataset: Dataset,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
     ) -> Dataset:
         feature_subset_filter = cls._get_features_filter_from_schema(
             tensor_schema,
@@ -261,16 +266,16 @@ class SequenceTokenizer:
     @classmethod
     def _get_features_filter_from_schema(
         cls,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
-    ) -> Set[str]:
+    ) -> set[str]:
         # We need only features, which related to tensor schema, otherwise feature should
         # be ignored for efficiency reasons. The code below does feature filtering, and
         # keeps features used as a source in tensor schema.
 
         # Query and item IDs are always needed
-        features_subset: List[str] = [
+        features_subset: list[str] = [
             query_id_column,
             item_id_column,
         ]
@@ -291,7 +296,7 @@ class SequenceTokenizer:
         return set(features_subset)
 
     @classmethod
-    def _check_tensor_schema(cls, tensor_schema: TensorSchema) -> None:
+    def _check_tensor_schema(cls, tensor_schema: "TensorSchema") -> None:
         # Check consistency of sequential features
         for tensor_feature in tensor_schema.all_features:
             feature_sources = tensor_feature.feature_sources
@@ -299,7 +304,7 @@ class SequenceTokenizer:
                 msg = "All tensor features must have sources defined"
                 raise ValueError(msg)
 
-            source_tables: List[FeatureSource] = [s.source for s in feature_sources]
+            source_tables: list[FeatureSource] = [s.source for s in feature_sources]
 
             unexpected_tables = list(filter(lambda x: not isinstance(x, FeatureSource), source_tables))
             if len(unexpected_tables) > 0:
@@ -319,11 +324,11 @@ class SequenceTokenizer:
     def _check_if_tensor_schema_matches_data(  # noqa: C901
         cls,
         dataset: Dataset,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         tensor_features_to_keep: Optional[Sequence[str]] = None,
     ) -> None:
         # Check if all source columns specified in tensor schema exist in provided data frames
-        sources_for_tensors: List[TensorFeatureSource] = []
+        sources_for_tensors: list["TensorFeatureSource"] = []
         for tensor_feature_name, tensor_feature in tensor_schema.items():
             if tensor_features_to_keep is not None and tensor_feature_name not in tensor_features_to_keep:
                 continue
@@ -413,9 +418,11 @@ class SequenceTokenizer:
 
         :returns: Loaded tokenizer object.
         """
+        from replay.data.nn import TensorSchema
+
         if not use_pickle:
             base_path = Path(path).with_suffix(".replay").resolve()
-            with open(base_path / "init_args.json", "r") as file:
+            with open(base_path / "init_args.json") as file:
                 tokenizer_dict = json.loads(file.read())
 
         # load tensor_schema, tensor_features
@@ -500,7 +507,7 @@ class _BaseSequenceProcessor(Generic[_T]):
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         grouped_interactions: _T,
@@ -535,7 +542,7 @@ class _BaseSequenceProcessor(Generic[_T]):
             return self._process_num_feature(tensor_feature)
         assert False, "Unknown tensor feature type"
 
-    def _process_num_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
+    def _process_num_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:
         """
         Process numerical tensor feature depends on it source.
         """
@@ -548,7 +555,7 @@ class _BaseSequenceProcessor(Generic[_T]):
             return self._process_num_item_feature(tensor_feature)
         assert False, "Unknown tensor feature source table"
 
-    def _process_cat_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
+    def _process_cat_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:
         """
         Process categorical tensor feature depends on it source.
         """
@@ -562,27 +569,27 @@ class _BaseSequenceProcessor(Generic[_T]):
         assert False, "Unknown tensor feature source table"
 
     @abc.abstractmethod
-    def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
 
@@ -597,7 +604,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         grouped_interactions: PandasDataFrame,
@@ -619,7 +626,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         """
         :returns: processed Pandas DataFrame with all features from tensor schema.
         """
-        all_features: Dict[str, Union[np.ndarray, List[np.ndarray]]] = {}
+        all_features: dict[str, Union[np.ndarray, list[np.ndarray]]] = {}
         all_features[self._query_id_column] = self._grouped_interactions[self._query_id_column].values
 
         for tensor_feature_name in self._tensor_schema:
@@ -628,8 +635,8 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         return PandasDataFrame(all_features)
 
     def _process_num_interaction_feature(
-        self, tensor_feature: TensorFeatureInfo
-    ) -> Union[List[np.ndarray], List[List]]:
+        self, tensor_feature: "TensorFeatureInfo"
+    ) -> Union[list[np.ndarray], list[list]]:
         """
         Process numerical interaction feature.
 
@@ -650,7 +657,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
             values.append(np.array(sequence))
         return values
 
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[List[np.ndarray], List[List]]:
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> Union[list[np.ndarray], list[list]]:
         """
         Process numerical feature from item features dataset.
 
@@ -676,7 +683,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
 
         return values
 
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> list[np.ndarray]:
         """
         Process numerical feature from query features dataset.
 
@@ -687,8 +694,8 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         return self._process_cat_query_feature(tensor_feature)
 
     def _process_cat_interaction_feature(
-        self, tensor_feature: TensorFeatureInfo
-    ) -> Union[List[np.ndarray], List[List]]:
+        self, tensor_feature: "TensorFeatureInfo"
+    ) -> Union[list[np.ndarray], list[list]]:
         """
         Process categorical interaction feature.
 
@@ -709,7 +716,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
             values.append(np.array(sequence))
         return values
 
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> list[np.ndarray]:
         """
         Process categorical feature from query features dataset.
 
@@ -738,7 +745,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         ]
         return [np.array([query_feature[i]]).reshape(-1) for i in range(len(self._grouped_interactions))]
 
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[List[np.ndarray], List[List]]:
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> Union[list[np.ndarray], list[list]]:
         """
         Process categorical feature from item features dataset.
 
@@ -754,7 +761,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
         assert source is not None
 
         item_feature = self._item_features[source.column]
-        values: List[np.ndarray] = []
+        values: list[np.ndarray] = []
 
         for item_id_sequence in self._grouped_interactions[self._item_id_column]:
             feature_sequence = item_feature.loc[item_id_sequence].values
@@ -784,7 +791,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
             data = data.join(self._process_feature(tensor_feature_name), on=self._query_id_column, how="left")
         return data
 
-    def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical interaction feature.
 
@@ -794,7 +801,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
         """
         return self._process_cat_interaction_feature(tensor_feature)
 
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
        """
        Process numerical feature from query features dataset.
 
@@ -805,7 +812,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
         """
         return self._process_cat_query_feature(tensor_feature)
 
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical feature from item features dataset.
 
@@ -816,7 +823,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
         """
         return self._process_cat_item_feature(tensor_feature)
 
-    def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical interaction feature.
 
@@ -833,7 +840,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
             {source.column: tensor_feature.name}
         )
 
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical feature from query features dataset.
 
@@ -877,7 +884,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
             {source.column: tensor_feature.name}
         )
 
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical feature from item features dataset.