replay-rec 0.19.0-py3-none-any.whl → 0.20.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. replay/__init__.py +6 -2
  2. replay/data/dataset.py +9 -9
  3. replay/data/nn/__init__.py +6 -6
  4. replay/data/nn/sequence_tokenizer.py +44 -38
  5. replay/data/nn/sequential_dataset.py +13 -8
  6. replay/data/nn/torch_sequential_dataset.py +14 -13
  7. replay/data/nn/utils.py +1 -1
  8. replay/metrics/base_metric.py +1 -1
  9. replay/metrics/coverage.py +7 -11
  10. replay/metrics/experiment.py +3 -3
  11. replay/metrics/offline_metrics.py +2 -2
  12. replay/models/__init__.py +19 -0
  13. replay/models/association_rules.py +1 -4
  14. replay/models/base_neighbour_rec.py +6 -9
  15. replay/models/base_rec.py +44 -293
  16. replay/models/cat_pop_rec.py +2 -1
  17. replay/models/common.py +69 -0
  18. replay/models/extensions/ann/ann_mixin.py +30 -25
  19. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
  20. replay/models/extensions/ann/utils.py +4 -3
  21. replay/models/knn.py +18 -17
  22. replay/models/nn/sequential/bert4rec/dataset.py +1 -1
  23. replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
  24. replay/models/nn/sequential/compiled/__init__.py +10 -0
  25. replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
  26. replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
  27. replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
  28. replay/models/nn/sequential/sasrec/dataset.py +1 -1
  29. replay/models/nn/sequential/sasrec/model.py +1 -1
  30. replay/models/optimization/__init__.py +14 -0
  31. replay/models/optimization/optuna_mixin.py +279 -0
  32. replay/{optimization → models/optimization}/optuna_objective.py +13 -15
  33. replay/models/slim.py +2 -4
  34. replay/models/word2vec.py +7 -12
  35. replay/preprocessing/discretizer.py +1 -2
  36. replay/preprocessing/history_based_fp.py +1 -1
  37. replay/preprocessing/label_encoder.py +1 -1
  38. replay/splitters/cold_user_random_splitter.py +13 -7
  39. replay/splitters/last_n_splitter.py +17 -10
  40. replay/utils/__init__.py +6 -2
  41. replay/utils/common.py +4 -2
  42. replay/utils/model_handler.py +11 -31
  43. replay/utils/session_handler.py +2 -2
  44. replay/utils/spark_utils.py +2 -2
  45. replay/utils/types.py +28 -18
  46. replay/utils/warnings.py +26 -0
  47. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -32
  48. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/RECORD +51 -47
  49. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
  50. replay_rec-0.20.0.dist-info/licenses/NOTICE +41 -0
  51. replay/optimization/__init__.py +0 -5
  52. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
replay/__init__.py CHANGED
@@ -1,3 +1,7 @@
-""" RecSys library """
+"""RecSys library"""
 
-__version__ = "0.19.0"
+# NOTE: This ensures distutils monkey-patching is performed before any
+# functionality removed in Python 3.12 is used in downstream packages (like lightfm)
+import setuptools as _
+
+__version__ = "0.20.0"
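
Note: the new `import setuptools as _` relies on setuptools installing its distutils shim at import time. A minimal sketch of the import-order effect on Python 3.12+ (illustrative; not part of the package):

    # Python 3.12 removed distutils from the standard library; importing setuptools
    # first installs its shim, so a later `import distutils` still resolves.
    import setuptools as _  # noqa: F401  -- registers the distutils meta-path shim

    import distutils  # would raise ModuleNotFoundError on 3.12+ without the line above

    print(distutils.__name__)  # "distutils", now served from setuptools' bundled copy
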
replay/data/dataset.py CHANGED
@@ -610,9 +610,9 @@ class Dataset:
         if self.is_pandas:
             try:
                 data[column] = data[column].astype(int)
-            except Exception:
+            except Exception as exc:
                 msg = f"IDs in {source.name}.{column} are not encoded. They are not int."
-                raise ValueError(msg)
+                raise ValueError(msg) from exc
 
         if self.is_pandas:
             is_int = np.issubdtype(dict(data.dtypes)[column], int)
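
Note: `raise ... from exc` keeps the original failure attached as `__cause__`, so both errors appear in the traceback instead of the first being swallowed. A standalone sketch:

    try:
        try:
            int("not-an-id")  # the underlying conversion failure
        except Exception as exc:
            raise ValueError("IDs are not encoded. They are not int.") from exc
    except ValueError as err:
        # The original exception stays attached for debugging.
        print(repr(err.__cause__))  # ValueError("invalid literal for int() ...")
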
@@ -775,10 +775,10 @@ def check_dataframes_types_equal(dataframe: DataFrameLike, other: DataFrameLike)
 
     :returns: True if dataframes have same type.
     """
-    if isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame):
-        return True
-    if isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame):
-        return True
-    if isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame):
-        return True
-    return False
+    return any(
+        [
+            isinstance(dataframe, PandasDataFrame) and isinstance(other, PandasDataFrame),
+            isinstance(dataframe, SparkDataFrame) and isinstance(other, SparkDataFrame),
+            isinstance(dataframe, PolarsDataFrame) and isinstance(other, PolarsDataFrame),
+        ]
+    )
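
Note: behavior is unchanged — the helper returns True only when both frames use the same backend. A usage sketch, assuming the function remains importable from `replay.data.dataset`:

    import pandas as pd
    import polars as pl

    from replay.data.dataset import check_dataframes_types_equal

    print(check_dataframes_types_equal(pd.DataFrame(), pd.DataFrame()))  # True
    print(check_dataframes_types_equal(pd.DataFrame(), pl.DataFrame()))  # False
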
replay/data/nn/__init__.py CHANGED
@@ -14,17 +14,17 @@ if TORCH_AVAILABLE:
     )
 
 __all__ = [
+    "DEFAULT_GROUND_TRUTH_PADDING_VALUE",
+    "DEFAULT_TRAIN_PADDING_VALUE",
     "MutableTensorMap",
+    "PandasSequentialDataset",
+    "PolarsSequentialDataset",
+    "SequenceTokenizer",
+    "SequentialDataset",
     "TensorFeatureInfo",
     "TensorFeatureSource",
     "TensorMap",
     "TensorSchema",
-    "SequenceTokenizer",
-    "PandasSequentialDataset",
-    "PolarsSequentialDataset",
-    "SequentialDataset",
-    "DEFAULT_GROUND_TRUTH_PADDING_VALUE",
-    "DEFAULT_TRAIN_PADDING_VALUE",
     "TorchSequentialBatch",
     "TorchSequentialDataset",
     "TorchSequentialValidationBatch",
replay/data/nn/sequence_tokenizer.py CHANGED
@@ -3,7 +3,7 @@ import json
 import pickle
 import warnings
 from pathlib import Path
-from typing import Dict, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
 
 import numpy as np
 import polars as pl
@@ -14,11 +14,11 @@ from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, Feat
 from replay.data.dataset_utils import DatasetLabelEncoder
 from replay.preprocessing import LabelEncoder, LabelEncodingRule
 from replay.preprocessing.label_encoder import HandleUnknownStrategies
-from replay.utils.model_handler import deprecation_warning
+from replay.utils import deprecation_warning
 
-from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
-from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
-from .utils import ensure_pandas, groupby_sequences
+if TYPE_CHECKING:
+    from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
+    from .sequential_dataset import SequentialDataset
 
 SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]
 _T = TypeVar("_T")
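
Note: this release repeatedly applies one pattern — imports needed only for annotations move behind `TYPE_CHECKING`, and the real imports are deferred into the methods that need them (see the `_transform_unchecked` and `load` hunks below), avoiding circular imports at runtime. A generic sketch of the pattern, not the library's actual code:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Resolved by type checkers only; never executed at runtime.
        from replay.data.nn.schema import TensorSchema

    def tensor_feature_names(schema: "TensorSchema") -> list:
        # Runtime import deferred until the function runs, breaking the cycle.
        from replay.data.nn.schema import TensorSchema  # noqa: F401

        return list(schema)  # a TensorSchema iterates its feature names
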
@@ -34,7 +34,7 @@ class SequenceTokenizer:
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         handle_unknown_rule: HandleUnknownStrategies = "error",
         default_value_rule: Optional[Union[int, str]] = None,
         allow_collect_to_master: bool = False,
@@ -77,7 +77,7 @@ class SequenceTokenizer:
         self,
         dataset: Dataset,
         tensor_features_to_keep: Optional[Sequence[str]] = None,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
         """
         :param dataset: input dataset to transform
         :param tensor_features_to_keep: specified feature names to transform
@@ -89,7 +89,7 @@ class SequenceTokenizer:
     def fit_transform(
         self,
         dataset: Dataset,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
         """
         :param dataset: input dataset to transform
         :returns: SequentialDataset
@@ -97,7 +97,7 @@ class SequenceTokenizer:
         return self.fit(dataset)._transform_unchecked(dataset)
 
     @property
-    def tensor_schema(self) -> TensorSchema:
+    def tensor_schema(self) -> "TensorSchema":
         """
         :returns: tensor schema
         """
@@ -149,7 +149,9 @@
         self,
         dataset: Dataset,
         tensor_features_to_keep: Optional[Sequence[str]] = None,
-    ) -> SequentialDataset:
+    ) -> "SequentialDataset":
+        from replay.data.nn.sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset
+
         schema = self._tensor_schema
         if tensor_features_to_keep is not None:
             schema = schema.subset(tensor_features_to_keep)
@@ -186,6 +188,8 @@
         self,
         dataset: Dataset,
     ) -> Tuple[SequenceDataFrameLike, Optional[SequenceDataFrameLike], Optional[SequenceDataFrameLike]]:
+        from replay.data.nn.utils import ensure_pandas, groupby_sequences
+
         grouped_interactions = groupby_sequences(
             events=dataset.interactions,
             groupby_col=dataset.feature_schema.query_id_column,
@@ -218,7 +222,7 @@
 
     def _make_sequence_features(
         self,
-        schema: TensorSchema,
+        schema: "TensorSchema",
         feature_schema: FeatureSchema,
         grouped_interactions: SequenceDataFrameLike,
         query_features: Optional[SequenceDataFrameLike],
@@ -242,7 +246,7 @@
     def _match_features_with_tensor_schema(
         cls,
         dataset: Dataset,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
     ) -> Dataset:
         feature_subset_filter = cls._get_features_filter_from_schema(
             tensor_schema,
@@ -261,7 +265,7 @@
     @classmethod
     def _get_features_filter_from_schema(
         cls,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
     ) -> Set[str]:
@@ -291,7 +295,7 @@
         return set(features_subset)
 
     @classmethod
-    def _check_tensor_schema(cls, tensor_schema: TensorSchema) -> None:
+    def _check_tensor_schema(cls, tensor_schema: "TensorSchema") -> None:
         # Check consistency of sequential features
         for tensor_feature in tensor_schema.all_features:
             feature_sources = tensor_feature.feature_sources
@@ -319,11 +323,11 @@
     def _check_if_tensor_schema_matches_data(  # noqa: C901
         cls,
         dataset: Dataset,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         tensor_features_to_keep: Optional[Sequence[str]] = None,
     ) -> None:
         # Check if all source columns specified in tensor schema exist in provided data frames
-        sources_for_tensors: List[TensorFeatureSource] = []
+        sources_for_tensors: List["TensorFeatureSource"] = []
         for tensor_feature_name, tensor_feature in tensor_schema.items():
             if tensor_features_to_keep is not None and tensor_feature_name not in tensor_features_to_keep:
                 continue
@@ -413,6 +417,8 @@
 
         :returns: Loaded tokenizer object.
         """
+        from replay.data.nn import TensorSchema
+
         if not use_pickle:
             base_path = Path(path).with_suffix(".replay").resolve()
             with open(base_path / "init_args.json", "r") as file:
@@ -500,7 +506,7 @@
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         grouped_interactions: _T,
@@ -535,7 +541,7 @@
             return self._process_num_feature(tensor_feature)
         assert False, "Unknown tensor feature type"
 
-    def _process_num_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
+    def _process_num_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:
         """
         Process numerical tensor feature depends on it source.
         """
@@ -548,7 +554,7 @@
             return self._process_num_item_feature(tensor_feature)
         assert False, "Unknown tensor feature source table"
 
-    def _process_cat_feature(self, tensor_feature: TensorFeatureInfo) -> _T:
+    def _process_cat_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:
         """
         Process categorical tensor feature depends on it source.
         """
@@ -562,27 +568,27 @@
         assert False, "Unknown tensor feature source table"
 
     @abc.abstractmethod
-    def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
     @abc.abstractmethod
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> _T:  # pragma: no cover
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> _T:  # pragma: no cover
         pass
 
 
@@ -597,7 +603,7 @@ class _PandasSequenceProcessor(_BaseSequenceProcessor[PandasDataFrame]):
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         grouped_interactions: PandasDataFrame,
@@ -628,7 +634,7 @@
         return PandasDataFrame(all_features)
 
     def _process_num_interaction_feature(
-        self, tensor_feature: TensorFeatureInfo
+        self, tensor_feature: "TensorFeatureInfo"
     ) -> Union[List[np.ndarray], List[List]]:
         """
         Process numerical interaction feature.
@@ -650,7 +656,7 @@
             values.append(np.array(sequence))
         return values
 
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[List[np.ndarray], List[List]]:
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> Union[List[np.ndarray], List[List]]:
         """
         Process numerical feature from item features dataset.
 
@@ -676,7 +682,7 @@
 
         return values
 
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> List[np.ndarray]:
         """
         Process numerical feature from query features dataset.
 
@@ -687,7 +693,7 @@
         return self._process_cat_query_feature(tensor_feature)
 
     def _process_cat_interaction_feature(
-        self, tensor_feature: TensorFeatureInfo
+        self, tensor_feature: "TensorFeatureInfo"
     ) -> Union[List[np.ndarray], List[List]]:
         """
         Process categorical interaction feature.
@@ -709,7 +715,7 @@
             values.append(np.array(sequence))
         return values
 
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> List[np.ndarray]:
         """
         Process categorical feature from query features dataset.
 
@@ -738,7 +744,7 @@
         ]
         return [np.array([query_feature[i]]).reshape(-1) for i in range(len(self._grouped_interactions))]
 
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> Union[List[np.ndarray], List[List]]:
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> Union[List[np.ndarray], List[List]]:
         """
         Process categorical feature from item features dataset.
 
@@ -784,7 +790,7 @@ class _PolarsSequenceProcessor(_BaseSequenceProcessor[PolarsDataFrame]):
             data = data.join(self._process_feature(tensor_feature_name), on=self._query_id_column, how="left")
         return data
 
-    def _process_num_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical interaction feature.
 
@@ -794,7 +800,7 @@
         """
         return self._process_cat_interaction_feature(tensor_feature)
 
-    def _process_num_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_query_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical feature from query features dataset.
 
@@ -805,7 +811,7 @@
         """
         return self._process_cat_query_feature(tensor_feature)
 
-    def _process_num_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_num_item_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process numerical feature from item features dataset.
 
@@ -816,7 +822,7 @@
         """
        return self._process_cat_item_feature(tensor_feature)
 
-    def _process_cat_interaction_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_interaction_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical interaction feature.
 
@@ -833,7 +839,7 @@
             {source.column: tensor_feature.name}
         )
 
-    def _process_cat_query_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_query_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical feature from query features dataset.
 
@@ -877,7 +883,7 @@
             {source.column: tensor_feature.name}
         )
 
-    def _process_cat_item_feature(self, tensor_feature: TensorFeatureInfo) -> PolarsDataFrame:
+    def _process_cat_item_feature(self, tensor_feature: "TensorFeatureInfo") -> PolarsDataFrame:
         """
         Process categorical feature from item features dataset.
 
replay/data/nn/sequential_dataset.py CHANGED
@@ -1,7 +1,7 @@
 import abc
 import json
 from pathlib import Path
-from typing import Tuple, Union
+from typing import TYPE_CHECKING, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -9,7 +9,8 @@ import polars as pl
 from pandas import DataFrame as PandasDataFrame
 from polars import DataFrame as PolarsDataFrame
 
-from .schema import TensorSchema
+if TYPE_CHECKING:
+    from .schema import TensorSchema
 
 
 class SequentialDataset(abc.ABC):
@@ -81,7 +82,7 @@ class SequentialDataset(abc.ABC):
 
     @property
     @abc.abstractmethod
-    def schema(self) -> TensorSchema:  # pragma: no cover
+    def schema(self) -> "TensorSchema":  # pragma: no cover
         """
         :returns: List of tensor features.
         """
@@ -128,7 +129,7 @@ class PandasSequentialDataset(SequentialDataset):
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         sequences: PandasDataFrame,
@@ -184,11 +185,11 @@
         )
 
     @property
-    def schema(self) -> TensorSchema:
+    def schema(self) -> "TensorSchema":
         return self._tensor_schema
 
     @classmethod
-    def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PandasDataFrame) -> None:
+    def _check_if_schema_matches_data(cls, tensor_schema: "TensorSchema", data: PandasDataFrame) -> None:
         for tensor_feature_name in tensor_schema:
             if tensor_feature_name not in data:
                 msg = "Tensor schema does not match with provided data frame"
@@ -199,6 +200,8 @@
         """
         Method for loading PandasSequentialDataset object from `.replay` directory.
         """
+        from replay.data.nn import TensorSchema
+
         base_path = Path(path).with_suffix(".replay").resolve()
         with open(base_path / "init_args.json", "r") as file:
             sequential_dict = json.loads(file.read())
@@ -221,7 +224,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):
 
     def __init__(
         self,
-        tensor_schema: TensorSchema,
+        tensor_schema: "TensorSchema",
         query_id_column: str,
         item_id_column: str,
         sequences: PolarsDataFrame,
@@ -270,7 +273,7 @@
         return pl.from_dict(df.to_dict("list"))
 
     @classmethod
-    def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PolarsDataFrame) -> None:
+    def _check_if_schema_matches_data(cls, tensor_schema: "TensorSchema", data: PolarsDataFrame) -> None:
         for tensor_feature_name in tensor_schema:
             if tensor_feature_name not in data:
                 msg = "Tensor schema does not match with provided data frame"
@@ -281,6 +284,8 @@
         """
         Method for loading PandasSequentialDataset object from `.replay` directory.
         """
+        from replay.data.nn import TensorSchema
+
         base_path = Path(path).with_suffix(".replay").resolve()
         with open(base_path / "init_args.json", "r") as file:
             sequential_dict = json.loads(file.read())
replay/data/nn/torch_sequential_dataset.py CHANGED
@@ -1,13 +1,14 @@
-from typing import Generator, NamedTuple, Optional, Sequence, Tuple, Union, cast
+from typing import TYPE_CHECKING, Generator, NamedTuple, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
 import torch
 from torch.utils.data import Dataset as TorchDataset
 
-from replay.utils.model_handler import deprecation_warning
+from replay.utils import deprecation_warning
 
-from .schema import TensorFeatureInfo, TensorMap, TensorSchema
-from .sequential_dataset import SequentialDataset
+if TYPE_CHECKING:
+    from .schema import TensorFeatureInfo, TensorMap, TensorSchema
+    from .sequential_dataset import SequentialDataset
 
 
 # We do not use dataclasses as PyTorch default collate
@@ -19,7 +20,7 @@ class TorchSequentialBatch(NamedTuple):
 
     query_id: torch.LongTensor
     padding_mask: torch.BoolTensor
-    features: TensorMap
+    features: "TensorMap"
 
 
 class TorchSequentialDataset(TorchDataset):
@@ -33,7 +34,7 @@
     )
     def __init__(
         self,
-        sequential: SequentialDataset,
+        sequential: "SequentialDataset",
         max_sequence_length: int,
         sliding_window_step: Optional[int] = None,
         padding_value: int = 0,
@@ -89,7 +90,7 @@
 
     def _generate_tensor_feature(
         self,
-        feature: TensorFeatureInfo,
+        feature: "TensorFeatureInfo",
         sequence_index: int,
         sequence_offset: int,
     ) -> torch.Tensor:
@@ -161,7 +162,7 @@
 
     query_id: torch.LongTensor
     padding_mask: torch.BoolTensor
-    features: TensorMap
+    features: "TensorMap"
     ground_truth: torch.LongTensor
     train: torch.LongTensor
 
@@ -181,9 +182,9 @@
     )
     def __init__(
         self,
-        sequential: SequentialDataset,
-        ground_truth: SequentialDataset,
-        train: SequentialDataset,
+        sequential: "SequentialDataset",
+        ground_truth: "SequentialDataset",
+        train: "SequentialDataset",
         max_sequence_length: int,
         padding_value: int = 0,
         sliding_window_step: Optional[int] = None,
@@ -280,8 +281,8 @@
     @classmethod
     def _check_if_schema_match(
         cls,
-        sequential_schema: TensorSchema,
-        ground_truth_schema: TensorSchema,
+        sequential_schema: "TensorSchema",
+        ground_truth_schema: "TensorSchema",
     ) -> None:
         sequential_item_feature = sequential_schema.item_id_features.item()
         ground_truth_item_feature = ground_truth_schema.item_id_features.item()
replay/data/nn/utils.py CHANGED
@@ -30,7 +30,7 @@ def groupby_sequences(events: DataFrameLike, groupby_col: str, sort_col: Optiona
         events = events.sort_values(event_cols_without_iterable)
 
         grouped_sequences = (
-            events.groupby(groupby_col).agg({col: list for col in event_cols_without_groupby}).reset_index()
+            events.groupby(groupby_col).agg(dict.fromkeys(event_cols_without_groupby, list)).reset_index()
         )
     elif isinstance(events, PolarsDataFrame):
         event_cols_without_groupby = events.columns
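
Note: the two aggregation specs are equivalent; `dict.fromkeys` builds the same column-to-`list` mapping without a comprehension:

    cols = ["item_id", "timestamp"]  # illustrative column names
    assert {col: list for col in cols} == dict.fromkeys(cols, list)
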
replay/metrics/base_metric.py CHANGED
@@ -145,7 +145,7 @@ class Metric(ABC):
 
     def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) -> Dict:
         return (
-            data.sort_values(by=self.rating_column, ascending=False)
+            data.sort_values(by=[self.rating_column, self.item_column], ascending=False, kind="stable")
             .groupby(self.query_column)[self.item_column]
             .apply(list)
             .to_dict()
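
Note: sorting by rating alone leaves tied items in unspecified order; adding the item column as a tie-breaker with a stable sort makes the per-query lists deterministic. A sketch with hypothetical column names mirroring `self.query_column`, `self.item_column`, and `self.rating_column`:

    import pandas as pd

    df = pd.DataFrame({"query_id": [1, 1], "item_id": [10, 20], "rating": [0.5, 0.5]})
    out = (
        df.sort_values(by=["rating", "item_id"], ascending=False, kind="stable")
        .groupby("query_id")["item_id"]
        .apply(list)
        .to_dict()
    )
    print(out)  # {1: [20, 10]} -- ties broken by item_id, the same on every run
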
replay/metrics/coverage.py CHANGED
@@ -102,20 +102,16 @@ class Coverage(Metric):
         return grouped_recs
 
     def _get_enriched_recommendations_polars(self, recommendations: PolarsDataFrame) -> PolarsDataFrame:
-        sorted_by_score_recommendations = recommendations.select(
-            pl.all().sort_by(self.rating_column, descending=True).over(self.query_column)
-        )
-        sorted_by_score_recommendations = sorted_by_score_recommendations.with_columns(
-            sorted_by_score_recommendations.select(
-                pl.col(self.query_column).cum_count().over(self.query_column).alias("rank")
+        return (
+            recommendations.with_columns(
+                pl.col(self.rating_column)
+                .rank(method="ordinal", descending=True)
+                .over(self.query_column)
+                .alias("__rank")
             )
-        )
-        grouped_recs = (
-            sorted_by_score_recommendations.select(self.item_column, "rank")
             .group_by(self.item_column)
-            .agg(pl.col("rank").min().alias("best_position"))
+            .agg(pl.col("__rank").min().alias("best_position"))
         )
-        return grouped_recs
 
     def _spark_compute(self, recs: SparkDataFrame, train: SparkDataFrame) -> MetricsMeanReturnType:
         """
replay/metrics/experiment.py CHANGED
@@ -84,12 +84,12 @@ class Experiment:
     >>> ex.add_result("model", recommendations)
     >>> ex.results
               NDCG@2    NDCG@3  Surprisal@3
-    baseline  0.204382  0.234639  0.608476
-    model     0.333333  0.489760  0.719587
+    baseline  0.333333  0.25512   0.608476
+    model     0.333333  0.48976   0.719587
     >>> ex.compare("baseline")
              NDCG@2   NDCG@3 Surprisal@3
     baseline      –        –           –
-    model    63.09%  108.73%      18.26%
+    model      0.0%   91.97%      18.26%
     >>> ex = Experiment([Precision(3, mode=Median()), Precision(3, mode=ConfidenceInterval(0.95))], groundtruth)
     >>> ex.add_result("baseline", base_rec)
     >>> ex.add_result("model", recommendations)
replay/metrics/offline_metrics.py CHANGED
@@ -121,11 +121,11 @@ class OfflineMetrics:
     ...     base_recommendations={"ALS": base_rec, "KNN": recommendations}
     ... )
     {'Precision@2': 0.3333333333333333,
-     'Unexpectedness_ALS@1': 0.3333333333333333,
+     'Unexpectedness_ALS@1': 0.6666666666666666,
      'Unexpectedness_ALS@2': 0.16666666666666666,
      'Unexpectedness_KNN@1': 0.0,
      'Unexpectedness_KNN@2': 0.0,
-     'Unexpectedness-PerUser_ALS@1': {1: 1.0, 2: 0.0, 3: 0.0},
+     'Unexpectedness-PerUser_ALS@1': {1: 1.0, 2: 1.0, 3: 0.0},
      'Unexpectedness-PerUser_ALS@2': {1: 0.5, 2: 0.0, 3: 0.0},
      'Unexpectedness-PerUser_KNN@1': {1: 0.0, 2: 0.0, 3: 0.0},
      'Unexpectedness-PerUser_KNN@2': {1: 0.0, 2: 0.0, 3: 0.0}}
replay/models/__init__.py CHANGED
@@ -23,3 +23,22 @@ from .thompson_sampling import ThompsonSampling
 from .ucb import UCB
 from .wilson import Wilson
 from .word2vec import Word2VecRec
+
+__all__ = [
+    "KLUCB",
+    "SLIM",
+    "UCB",
+    "ALSWrap",
+    "AssociationRulesItemRec",
+    "CatPopRec",
+    "ClusterRec",
+    "ItemKNN",
+    "LinUCB",
+    "PopRec",
+    "QueryPopRec",
+    "RandomRec",
+    "Recommender",
+    "ThompsonSampling",
+    "Wilson",
+    "Word2VecRec",
+]
replay/models/association_rules.py CHANGED
@@ -142,6 +142,7 @@ class AssociationRulesItemRec(NeighbourRec):
         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
             If not set, then ann will not be used.
         """
+        self.init_index_builder(index_builder)
 
         self.session_column = session_column
         self.min_item_count = min_item_count
@@ -149,10 +150,6 @@
         self.num_neighbours = num_neighbours
         self.use_rating = use_rating
         self.similarity_metric = similarity_metric
-        if isinstance(index_builder, (IndexBuilder, type(None))):
-            self.index_builder = index_builder
-        elif isinstance(index_builder, dict):
-            self.init_builder_from_dict(index_builder)
 
     @property
     def _init_args(self):
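
Note: the removed branching moves into `init_index_builder`, presumably provided by the new `replay/models/common.py` (+69 lines in this release). A hypothetical reconstruction of what such a helper could look like — not the actual implementation:

    from typing import Optional, Union

    class IndexBuilderInitSketch:
        """Hypothetical stand-in for the behavior consolidated in replay/models/common.py."""

        def init_index_builder(self, index_builder: Optional[Union["IndexBuilder", dict]]) -> None:
            if isinstance(index_builder, dict):
                # Re-create the builder from its serialized form, as the removed branch did.
                self.init_builder_from_dict(index_builder)
            else:
                # An IndexBuilder instance (or None, disabling ANN) is stored directly.
                self.index_builder = index_builder
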