replay-rec 0.20.0__py3-none-any.whl → 0.20.0rc0__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (128)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +10 -9
  3. replay/data/dataset_utils/dataset_label_encoder.py +5 -4
  4. replay/data/nn/schema.py +9 -18
  5. replay/data/nn/sequence_tokenizer.py +16 -15
  6. replay/data/nn/sequential_dataset.py +4 -4
  7. replay/data/nn/torch_sequential_dataset.py +5 -4
  8. replay/data/nn/utils.py +2 -1
  9. replay/data/schema.py +3 -12
  10. replay/experimental/__init__.py +0 -0
  11. replay/experimental/metrics/__init__.py +62 -0
  12. replay/experimental/metrics/base_metric.py +603 -0
  13. replay/experimental/metrics/coverage.py +97 -0
  14. replay/experimental/metrics/experiment.py +175 -0
  15. replay/experimental/metrics/hitrate.py +26 -0
  16. replay/experimental/metrics/map.py +30 -0
  17. replay/experimental/metrics/mrr.py +18 -0
  18. replay/experimental/metrics/ncis_precision.py +31 -0
  19. replay/experimental/metrics/ndcg.py +49 -0
  20. replay/experimental/metrics/precision.py +22 -0
  21. replay/experimental/metrics/recall.py +25 -0
  22. replay/experimental/metrics/rocauc.py +49 -0
  23. replay/experimental/metrics/surprisal.py +90 -0
  24. replay/experimental/metrics/unexpectedness.py +76 -0
  25. replay/experimental/models/__init__.py +50 -0
  26. replay/experimental/models/admm_slim.py +257 -0
  27. replay/experimental/models/base_neighbour_rec.py +200 -0
  28. replay/experimental/models/base_rec.py +1386 -0
  29. replay/experimental/models/base_torch_rec.py +234 -0
  30. replay/experimental/models/cql.py +454 -0
  31. replay/experimental/models/ddpg.py +932 -0
  32. replay/experimental/models/dt4rec/__init__.py +0 -0
  33. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  34. replay/experimental/models/dt4rec/gpt1.py +401 -0
  35. replay/experimental/models/dt4rec/trainer.py +127 -0
  36. replay/experimental/models/dt4rec/utils.py +264 -0
  37. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  38. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  39. replay/experimental/models/hierarchical_recommender.py +331 -0
  40. replay/experimental/models/implicit_wrap.py +131 -0
  41. replay/experimental/models/lightfm_wrap.py +303 -0
  42. replay/experimental/models/mult_vae.py +332 -0
  43. replay/experimental/models/neural_ts.py +986 -0
  44. replay/experimental/models/neuromf.py +406 -0
  45. replay/experimental/models/scala_als.py +293 -0
  46. replay/experimental/models/u_lin_ucb.py +115 -0
  47. replay/experimental/nn/data/__init__.py +1 -0
  48. replay/experimental/nn/data/schema_builder.py +102 -0
  49. replay/experimental/preprocessing/__init__.py +3 -0
  50. replay/experimental/preprocessing/data_preparator.py +839 -0
  51. replay/experimental/preprocessing/padder.py +229 -0
  52. replay/experimental/preprocessing/sequence_generator.py +208 -0
  53. replay/experimental/scenarios/__init__.py +1 -0
  54. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  55. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  56. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  57. replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
  58. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  59. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  60. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  61. replay/experimental/utils/__init__.py +0 -0
  62. replay/experimental/utils/logger.py +24 -0
  63. replay/experimental/utils/model_handler.py +186 -0
  64. replay/experimental/utils/session_handler.py +44 -0
  65. replay/metrics/base_metric.py +11 -10
  66. replay/metrics/categorical_diversity.py +8 -8
  67. replay/metrics/coverage.py +4 -4
  68. replay/metrics/experiment.py +3 -3
  69. replay/metrics/hitrate.py +1 -3
  70. replay/metrics/map.py +1 -3
  71. replay/metrics/mrr.py +1 -3
  72. replay/metrics/ndcg.py +1 -2
  73. replay/metrics/novelty.py +3 -3
  74. replay/metrics/offline_metrics.py +16 -16
  75. replay/metrics/precision.py +1 -3
  76. replay/metrics/recall.py +1 -3
  77. replay/metrics/rocauc.py +1 -3
  78. replay/metrics/surprisal.py +4 -4
  79. replay/metrics/torch_metrics_builder.py +13 -12
  80. replay/metrics/unexpectedness.py +2 -2
  81. replay/models/als.py +2 -2
  82. replay/models/association_rules.py +4 -3
  83. replay/models/base_neighbour_rec.py +3 -2
  84. replay/models/base_rec.py +11 -10
  85. replay/models/cat_pop_rec.py +2 -1
  86. replay/models/extensions/ann/ann_mixin.py +2 -1
  87. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
  88. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
  89. replay/models/lin_ucb.py +3 -3
  90. replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
  91. replay/models/nn/sequential/bert4rec/dataset.py +2 -2
  92. replay/models/nn/sequential/bert4rec/lightning.py +3 -3
  93. replay/models/nn/sequential/bert4rec/model.py +2 -2
  94. replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
  95. replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
  96. replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
  97. replay/models/nn/sequential/postprocessors/_base.py +2 -3
  98. replay/models/nn/sequential/postprocessors/postprocessors.py +10 -10
  99. replay/models/nn/sequential/sasrec/lightning.py +3 -3
  100. replay/models/nn/sequential/sasrec/model.py +8 -8
  101. replay/models/slim.py +2 -2
  102. replay/models/ucb.py +2 -2
  103. replay/models/word2vec.py +3 -3
  104. replay/preprocessing/discretizer.py +8 -7
  105. replay/preprocessing/filters.py +4 -4
  106. replay/preprocessing/history_based_fp.py +6 -6
  107. replay/preprocessing/label_encoder.py +8 -7
  108. replay/scenarios/fallback.py +4 -3
  109. replay/splitters/base_splitter.py +3 -3
  110. replay/splitters/cold_user_random_splitter.py +4 -4
  111. replay/splitters/k_folds.py +4 -4
  112. replay/splitters/last_n_splitter.py +10 -10
  113. replay/splitters/new_users_splitter.py +4 -4
  114. replay/splitters/random_splitter.py +4 -4
  115. replay/splitters/ratio_splitter.py +10 -10
  116. replay/splitters/time_splitter.py +6 -6
  117. replay/splitters/two_stage_splitter.py +4 -4
  118. replay/utils/__init__.py +1 -0
  119. replay/utils/common.py +1 -1
  120. replay/utils/session_handler.py +2 -2
  121. replay/utils/spark_utils.py +6 -5
  122. replay/utils/types.py +3 -1
  123. {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/METADATA +17 -17
  124. replay_rec-0.20.0rc0.dist-info/RECORD +194 -0
  125. replay_rec-0.20.0.dist-info/RECORD +0 -139
  126. {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/WHEEL +0 -0
  127. {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/licenses/LICENSE +0 -0
  128. {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/licenses/NOTICE +0 -0
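Most of the hunks below follow one mechanical pattern: deprecated typing aliases (List, Dict, Tuple, Set) are replaced by the builtin generics that have accepted type parameters since Python 3.9 (PEP 585), and the abstract container types (Iterable, Iterator, Mapping) are imported from collections.abc instead of typing. A minimal before/after sketch of the pattern (illustrative only, not taken from the package):

    # Before: typing aliases, deprecated since Python 3.9
    from typing import Dict, Iterable, List

    def index(items: Iterable[str]) -> Dict[str, List[int]]: ...

    # After: builtin generics plus collections.abc
    from collections.abc import Iterable

    def index(items: Iterable[str]) -> dict[str, list[int]]: ...

The remaining changes add the new replay/experimental subpackage and update the wheel metadata.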
replay/metrics/recall.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric
 
 
@@ -65,7 +63,7 @@ class Recall(Metric):
     """
 
     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
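For reference, Recall@k is the share of a user's ground-truth items recovered among the top-k predictions; a minimal sketch of the per-user computation (standard definition, not copied from the package):

    def recall_at_k(ks: list[int], ground_truth: list, pred: list) -> list[float]:
        # Recall@k = |top-k predictions intersected with ground truth| / |ground truth|
        set_gt = set(ground_truth)
        return [len(set(pred[:k]) & set_gt) / len(set_gt) for k in ks]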
replay/metrics/rocauc.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric
 
 
@@ -74,7 +72,7 @@ class RocAuc(Metric):
     """
 
    @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/surprisal.py CHANGED
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Union
+from typing import Union
 
 import numpy as np
 import polars as pl
@@ -82,7 +82,7 @@ class Surprisal(Metric):
    <BLANKLINE>
    """
 
-    def _get_weights(self, train: Dict) -> Dict:
+    def _get_weights(self, train: dict) -> dict:
         n_users = len(train.keys())
         items_counter = defaultdict(set)
         for user, items in train.items():
@@ -93,7 +93,7 @@ class Surprisal(Metric):
             weights[item] = np.log2(n_users / len(users)) / np.log2(n_users)
         return weights
 
-    def _get_recommendation_weights(self, recommendations: Dict, train: Dict) -> Dict:
+    def _get_recommendation_weights(self, recommendations: dict, train: dict) -> dict:
         weights = self._get_weights(train)
         recs_with_weights = {}
         for user, items in recommendations.items():
@@ -183,7 +183,7 @@ class Surprisal(Metric):
         )
 
     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], pred_item_ids: List, pred_weights: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], pred_item_ids: list, pred_weights: list) -> list[float]:
         if not pred_item_ids:
             return [0.0 for _ in ks]
         res = []
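For context, the weight assigned in _get_weights is the item's self-information normalized by the log of the user count: items consumed by every user get weight 0, items consumed by a single user get weight 1. A standalone sketch with hypothetical toy data:

    import numpy as np
    from collections import defaultdict

    # Hypothetical interaction history: user -> consumed items
    train = {"u1": ["a", "b"], "u2": ["a"], "u3": ["a", "c"]}

    n_users = len(train)
    users_per_item = defaultdict(set)
    for user, items in train.items():
        for item in items:
            users_per_item[item].add(user)

    # Normalized self-information, as in the hunk above
    weights = {
        item: np.log2(n_users / len(users)) / np.log2(n_users)
        for item, users in users_per_item.items()
    }
    print(weights)  # {'a': 0.0, 'b': 1.0, 'c': 1.0}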
replay/metrics/torch_metrics_builder.py CHANGED
@@ -1,6 +1,7 @@
 import abc
+from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Any, Dict, List, Literal, Mapping, Optional, Set
+from typing import Any, Literal, Optional
 
 import numpy as np
 
@@ -19,13 +20,13 @@ MetricName = Literal[
     "coverage",
 ]
 
-DEFAULT_METRICS: List[MetricName] = [
+DEFAULT_METRICS: list[MetricName] = [
     "map",
     "ndcg",
     "recall",
 ]
 
-DEFAULT_KS: List[int] = [1, 5, 10, 20]
+DEFAULT_KS: list[int] = [1, 5, 10, 20]
 
 
 @dataclass
@@ -34,7 +35,7 @@ class _MetricRequirements:
     Stores description of metrics which need to be computed
     """
 
-    top_k: List[int]
+    top_k: list[int]
     need_recall: bool
     need_precision: bool
     need_ndcg: bool
@@ -68,14 +69,14 @@ class _MetricRequirements:
         self._metric_names = metrics
 
     @property
-    def metric_names(self) -> List[str]:
+    def metric_names(self) -> list[str]:
         """
         Getting metric names
         """
         return self._metric_names
 
     @classmethod
-    def from_metrics(cls, metrics: Set[str], top_k: List[int]) -> "_MetricRequirements":
+    def from_metrics(cls, metrics: set[str], top_k: list[int]) -> "_MetricRequirements":
         """
         Creating a class based on a given list of metrics and K values
         """
@@ -96,7 +97,7 @@ class _CoverageHelper:
     Computes coverage metric over multiple batches
     """
 
-    def __init__(self, top_k: List[int], item_count: Optional[int]) -> None:
+    def __init__(self, top_k: list[int], item_count: Optional[int]) -> None:
         """
         :param top_k: (list): Consider the highest k scores in the ranking.
         :param item_count: (optional, int): the total number of items in the dataset.
@@ -110,7 +111,7 @@ class _CoverageHelper:
         Reload the metric counter
         """
         self._train_hist = torch.zeros(self.item_count)
-        self._pred_hist: Dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}
+        self._pred_hist: dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}
 
     def _ensure_hists_on_device(self, device: torch.device) -> None:
         self._train_hist = self._train_hist.to(device)
@@ -197,8 +198,8 @@ class TorchMetricsBuilder(_MetricBuilder):
 
     def __init__(
         self,
-        metrics: List[MetricName] = DEFAULT_METRICS,
-        top_k: Optional[List[int]] = DEFAULT_KS,
+        metrics: list[MetricName] = DEFAULT_METRICS,
+        top_k: Optional[list[int]] = DEFAULT_KS,
         item_count: Optional[int] = None,
     ) -> None:
         """
@@ -331,8 +332,8 @@ class TorchMetricsBuilder(_MetricBuilder):
 
     def _compute_metrics_sum(
         self, predictions: torch.LongTensor, ground_truth: torch.LongTensor, train: Optional[torch.LongTensor]
-    ) -> List[float]:
-        result: List[float] = []
+    ) -> list[float]:
+        result: list[float] = []
 
         # Getting a tensor of the same size as predictions
         # The tensor contains information about whether the item from the prediction is present in the test set
replay/metrics/unexpectedness.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
@@ -152,7 +152,7 @@ class Unexpectedness(Metric):
         )
 
     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], base_recs: Optional[List], recs: Optional[List]) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], base_recs: Optional[list], recs: Optional[list]) -> list[float]:
         if not base_recs or not recs:
             return [0.0 for _ in ks]
         return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]
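As the final hunk shows, Unexpectedness@k is simply the fraction of the top-k recommendations that do not also appear in the base model's top-k; a toy check with hypothetical values:

    recs = ["a", "b", "c"]       # recommendations under evaluation
    base_recs = ["a", "x", "y"]  # recommendations from the base model
    k = 3
    # 1 - |overlap| / k = 1 - 1/3
    assert 1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k == 1 - 1 / 3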
replay/models/als.py CHANGED
@@ -1,5 +1,5 @@
 from os.path import join
-from typing import Optional, Tuple
+from typing import Optional
 
 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -148,7 +148,7 @@ class ALSWrap(Recommender, ItemVectorModel):
 
     def _get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
-    ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
         entity = "user" if self.query_column in ids.columns else "item"
         entity_col = self.query_column if self.query_column in ids.columns else self.item_column
 
replay/models/association_rules.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Iterable, List, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union
 
 import numpy as np
 
@@ -97,13 +98,13 @@ class AssociationRulesItemRec(NeighbourRec):
     In this case all items in sessions should have the same rating.
     """
 
-    def _get_ann_infer_params(self) -> Dict[str, Any]:
+    def _get_ann_infer_params(self) -> dict[str, Any]:
         return {
             "features_col": None,
         }
 
     can_predict_item_to_item = True
-    item_to_item_metrics: List[str] = ["lift", "confidence", "confidence_gain"]
+    item_to_item_metrics: list[str] = ["lift", "confidence", "confidence_gain"]
     similarity: SparkDataFrame
     can_change_metric = True
     _search_space = {
replay/models/base_neighbour_rec.py CHANGED
@@ -4,7 +4,8 @@ Part of set of abstract classes (from base_rec.py)
 """
 
 from abc import ABC
-from typing import Any, Dict, Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union
 
 from replay.data.dataset import Dataset
 from replay.utils import PYSPARK_AVAILABLE, MissingImport, SparkDataFrame
@@ -187,7 +188,7 @@ class NeighbourRec(ANNMixin, Recommender, ABC):
             "similarity" if metric is None else metric,
         )
 
-    def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> dict[str, Any]:
         similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
         self.index_builder.index_params.items_count = interactions.select(sf.max(self.item_column)).first()[0] + 1
         return similarity_df, {
replay/models/base_rec.py CHANGED
@@ -13,8 +13,9 @@ Base abstract classes:
 
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from os.path import join
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -55,14 +56,14 @@ class IsSavable(ABC):
 
     @property
     @abstractmethod
-    def _init_args(self) -> Dict:
+    def _init_args(self) -> dict:
         """
         Dictionary of the model attributes passed during model initialization.
         Used for model saving and loading
         """
 
     @property
-    def _dataframes(self) -> Dict:
+    def _dataframes(self) -> dict:
         """
         Dictionary of the model dataframes required for inference.
         Used for model saving and loading
@@ -508,7 +509,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
             or None if `file_path` is provided
         """
         if dataset is not None:
-            interactions, query_features, item_features, pairs = [
+            interactions, query_features, item_features, pairs = (
                 convert2spark(df)
                 for df in [
                     dataset.interactions,
@@ -516,7 +517,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
                     dataset.item_features,
                     pairs,
                 ]
-            ]
+            )
             if set(pairs.columns) != {self.item_column, self.query_column}:
                 msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
                 raise ValueError(msg)
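The two hunks above do more than retype annotations: the right-hand side of the unpacking assignment changes from a list comprehension to a generator expression. Both forms bind the same four names; the parenthesized form just avoids materializing an intermediate list. A minimal illustration with hypothetical values:

    frames = ["interactions", "query_features", "item_features", "pairs"]
    # Old form: builds a temporary list, then unpacks it
    a, b, c, d = [name.upper() for name in frames]
    # New form: the unpacking consumes the generator directly
    a, b, c, d = (name.upper() for name in frames)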
@@ -590,7 +591,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
 
     def _get_features_wrap(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
-    ) -> Optional[Tuple[SparkDataFrame, int]]:
+    ) -> Optional[tuple[SparkDataFrame, int]]:
         if self.query_column not in ids.columns and self.item_column not in ids.columns:
             msg = f"{self.query_column} or {self.item_column} missing"
             raise ValueError(msg)
@@ -599,7 +600,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
 
     def _get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
-    ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
         """
         Get embeddings from model
 
@@ -679,7 +680,7 @@ class ItemVectorModel(BaseRecommender):
     """Parent for models generating items' vector representations"""
 
     can_predict_item_to_item: bool = True
-    item_to_item_metrics: List[str] = [
+    item_to_item_metrics: list[str] = [
         "euclidean_distance_sim",
         "cosine_similarity",
         "dot_product",
@@ -899,7 +900,7 @@ class HybridRecommender(BaseRecommender, ABC):
 
     def get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
-    ) -> Optional[Tuple[SparkDataFrame, int]]:
+    ) -> Optional[tuple[SparkDataFrame, int]]:
         """
         Returns query or item feature vectors as a Column with type ArrayType
         If a model does not have a vector for some ids they are not present in the final result.
@@ -1026,7 +1027,7 @@ class Recommender(BaseRecommender, ABC):
             recs_file_path=recs_file_path,
         )
 
-    def get_features(self, ids: SparkDataFrame) -> Optional[Tuple[SparkDataFrame, int]]:
+    def get_features(self, ids: SparkDataFrame) -> Optional[tuple[SparkDataFrame, int]]:
         """
         Returns query or item feature vectors as a Column with type ArrayType
 
replay/models/cat_pop_rec.py CHANGED
@@ -1,5 +1,6 @@
+from collections.abc import Iterable
 from os.path import join
-from typing import Iterable, Optional, Union
+from typing import Optional, Union
 
 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
replay/models/extensions/ann/ann_mixin.py CHANGED
@@ -2,7 +2,8 @@ import importlib
 import logging
 import sys
 from abc import abstractmethod
-from typing import Any, Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union
 
 from replay.data import Dataset
 from replay.models.common import RecommenderCommons
replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py CHANGED
@@ -1,5 +1,6 @@
 import logging
-from typing import Iterator, Optional
+from collections.abc import Iterator
+from typing import Optional
 
 import numpy as np
 
replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py CHANGED
@@ -1,5 +1,6 @@
 import logging
-from typing import Iterator, Optional
+from collections.abc import Iterator
+from typing import Optional
 
 import pandas as pd
 
replay/models/lin_ucb.py CHANGED
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Tuple, Union
+from typing import Union
 
 import numpy as np
 import pandas as pd
@@ -70,7 +70,7 @@ class HybridArm:
         # right-hand side of the regression
         self.b = np.zeros(d, dtype=float)
 
-    def feature_update(self, usr_features, usr_itm_features, relevances) -> Tuple[np.ndarray, np.ndarray]:
+    def feature_update(self, usr_features, usr_itm_features, relevances) -> tuple[np.ndarray, np.ndarray]:
         """
         Function to update featurs or each Lin-UCB hand in the current model.
 
@@ -175,7 +175,7 @@ class LinUCB(HybridRecommender):
         "alpha": {"type": "uniform", "args": [0.001, 10.0]},
     }
     _study = None  # field required for proper optuna's optimization
-    linucb_arms: List[Union[DisjointArm, HybridArm]]  # initialize only when working within fit method
+    linucb_arms: list[Union[DisjointArm, HybridArm]]  # initialize only when working within fit method
     rel_matrix: np.array  # matrix with relevance scores from predict method
 
     def __init__(
replay/models/nn/optimizer_utils/optimizer_factory.py CHANGED
@@ -1,5 +1,5 @@
 import abc
-from typing import Iterator, Tuple
+from collections.abc import Iterator
 
 import torch
 
@@ -47,7 +47,7 @@ class FatOptimizerFactory(OptimizerFactory):
         learning_rate: float = 0.001,
         weight_decay: float = 0.0,
         sgd_momentum: float = 0.0,
-        betas: Tuple[float, float] = (0.9, 0.98),
+        betas: tuple[float, float] = (0.9, 0.98),
     ) -> None:
         super().__init__()
         self.optimizer = optimizer
replay/models/nn/sequential/bert4rec/dataset.py CHANGED
@@ -1,5 +1,5 @@
 import abc
-from typing import NamedTuple, Optional, Tuple, cast
+from typing import NamedTuple, Optional, cast
 
 import torch
 from torch.utils.data import Dataset as TorchDataset
@@ -295,7 +295,7 @@ def _shift_features(
     schema: TensorSchema,
     features: TensorMap,
     padding_mask: torch.BoolTensor,
-) -> Tuple[TensorMap, torch.BoolTensor, torch.BoolTensor]:
+) -> tuple[TensorMap, torch.BoolTensor, torch.BoolTensor]:
     shifted_features: MutableTensorMap = {}
     for feature_name, feature in schema.items():
         if feature.is_seq:
replay/models/nn/sequential/bert4rec/lightning.py CHANGED
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Literal, Optional, Tuple, Union, cast
+from typing import Any, Literal, Optional, Union, cast
 
 import lightning
 import torch
@@ -338,7 +338,7 @@ class Bert4Rec(lightning.LightningModule):
         positive_labels: torch.LongTensor,
         padding_mask: torch.BoolTensor,
         tokens_mask: torch.BoolTensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
         assert self._loss_sample_count is not None
         n_negative_samples = self._loss_sample_count
 
@@ -440,7 +440,7 @@ class Bert4Rec(lightning.LightningModule):
             msg = "Not supported loss_type"
             raise NotImplementedError(msg)
 
-    def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
+    def get_all_embeddings(self) -> dict[str, torch.nn.Embedding]:
         """
         :returns: copy of all embeddings as a dictionary.
         """
replay/models/nn/sequential/bert4rec/model.py CHANGED
@@ -1,6 +1,6 @@
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -303,7 +303,7 @@ class BertEmbedding(torch.nn.Module):
         """
         return self.cat_embeddings[self.schema.item_id_feature_name].weight
 
-    def get_all_embeddings(self) -> Dict[str, torch.Tensor]:
+    def get_all_embeddings(self) -> dict[str, torch.Tensor]:
         """
         :returns: copy of all embeddings presented in this layer as a dict.
         """
replay/models/nn/sequential/callbacks/prediction_callbacks.py CHANGED
@@ -1,5 +1,5 @@
 import abc
-from typing import Generic, List, Optional, Protocol, Tuple, TypeVar, cast
+from typing import Generic, Optional, Protocol, TypeVar, cast
 
 import lightning
 import torch
@@ -38,7 +38,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         query_column: str,
         item_column: str,
         rating_column: str = "rating",
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
     ) -> None:
         """
         :param top_k: Takes the highest k scores in the ranking.
@@ -52,10 +52,10 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         self.item_column = item_column
         self.rating_column = rating_column
         self._top_k = top_k
-        self._postprocessors: List[BasePostProcessor] = postprocessors or []
-        self._query_batches: List[torch.Tensor] = []
-        self._item_batches: List[torch.Tensor] = []
-        self._item_scores: List[torch.Tensor] = []
+        self._postprocessors: list[BasePostProcessor] = postprocessors or []
+        self._query_batches: list[torch.Tensor] = []
+        self._item_batches: list[torch.Tensor] = []
+        self._item_scores: list[torch.Tensor] = []
 
     def on_predict_epoch_start(
         self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
@@ -97,7 +97,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
 
     def _compute_pipeline(
         self, query_ids: torch.LongTensor, scores: torch.Tensor
-    ) -> Tuple[torch.LongTensor, torch.Tensor]:
+    ) -> tuple[torch.LongTensor, torch.Tensor]:
         for postprocessor in self._postprocessors:
             query_ids, scores = postprocessor.on_prediction(query_ids, scores)
         return query_ids, scores
@@ -166,7 +166,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
         item_column: str,
         rating_column: str,
         spark_session: SparkSession,
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
     ) -> None:
         """
         :param top_k: Takes the highest k scores in the ranking.
@@ -213,7 +213,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
         return prediction
 
 
-class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
+class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
     """
     Callback for predition stage with tuple of tensors
     """
@@ -221,7 +221,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
     def __init__(
         self,
         top_k: int,
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
     ) -> None:
         """
         :param top_k: Takes the highest k scores in the ranking.
@@ -240,7 +240,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
         query_ids: torch.Tensor,
         item_ids: torch.Tensor,
         item_scores: torch.Tensor,
-    ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
+    ) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
         return (
             cast(torch.LongTensor, query_ids.flatten().cpu().long()),
             cast(torch.LongTensor, item_ids.cpu().long()),
@@ -254,7 +254,7 @@ class QueryEmbeddingsPredictionCallback(lightning.Callback):
     """
 
     def __init__(self):
-        self._embeddings_per_batch: List[torch.Tensor] = []
+        self._embeddings_per_batch: list[torch.Tensor] = []
 
     def on_predict_epoch_start(
         self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
replay/models/nn/sequential/callbacks/validation_callback.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, List, Literal, Optional, Protocol, Tuple
+from typing import Any, Literal, Optional, Protocol
 
 import lightning
 import torch
@@ -38,9 +38,9 @@ class ValidationMetricsCallback(lightning.Callback):
 
     def __init__(
         self,
-        metrics: Optional[List[CallbackMetricName]] = None,
-        ks: Optional[List[int]] = None,
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        metrics: Optional[list[CallbackMetricName]] = None,
+        ks: Optional[list[int]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
         item_count: Optional[int] = None,
     ):
         """
@@ -52,11 +52,11 @@ class ValidationMetricsCallback(lightning.Callback):
         self._metrics = metrics
         self._ks = ks
         self._item_count = item_count
-        self._metrics_builders: List[TorchMetricsBuilder] = []
-        self._dataloaders_size: List[int] = []
-        self._postprocessors: List[BasePostProcessor] = postprocessors or []
+        self._metrics_builders: list[TorchMetricsBuilder] = []
+        self._dataloaders_size: list[int] = []
+        self._postprocessors: list[BasePostProcessor] = postprocessors or []
 
-    def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> List[int]:
+    def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
         if isinstance(dataloaders, torch.utils.data.DataLoader):
             return [len(dataloaders)]
         return [len(dataloader) for dataloader in dataloaders]
@@ -85,7 +85,7 @@ class ValidationMetricsCallback(lightning.Callback):
 
     def _compute_pipeline(
         self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
-    ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
+    ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
         for postprocessor in self._postprocessors:
             query_ids, scores, ground_truth = postprocessor.on_validation(query_ids, scores, ground_truth)
         return query_ids, scores, ground_truth
replay/models/nn/sequential/compiled/base_compiled_model.py CHANGED
@@ -1,7 +1,7 @@
 import pathlib
 import tempfile
 from abc import abstractmethod
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Literal, Optional, Union
 
 import lightning
 import openvino as ov
@@ -68,7 +68,7 @@ class BaseCompiledModel:
         """
         self._batch_size: int
         self._max_seq_len: int
-        self._inputs_names: List[str]
+        self._inputs_names: list[str]
         self._output_name: str
 
         self._set_inner_params_from_openvino_model(compiled_model)
@@ -171,9 +171,9 @@ class BaseCompiledModel:
     @staticmethod
     def _run_model_compilation(
         lightning_model: lightning.LightningModule,
-        model_input_sample: Tuple[Union[torch.Tensor, Dict[str, torch.Tensor]]],
-        model_input_names: List[str],
-        model_dynamic_axes_in_input: Dict[str, Dict],
+        model_input_sample: tuple[Union[torch.Tensor, dict[str, torch.Tensor]]],
+        model_input_names: list[str],
+        model_dynamic_axes_in_input: dict[str, dict],
         batch_size: int,
         num_candidates_to_score: Union[int, None],
         num_threads: Optional[int] = None,
replay/models/nn/sequential/postprocessors/_base.py CHANGED
@@ -1,5 +1,4 @@
 import abc
-from typing import Tuple
 
 import torch
 
@@ -10,7 +9,7 @@ class BasePostProcessor(abc.ABC):  # pragma: no cover
     """
 
     @abc.abstractmethod
-    def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
+    def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
         """
         Prediction step.
 
@@ -24,7 +23,7 @@ class BasePostProcessor(abc.ABC):  # pragma: no cover
     @abc.abstractmethod
     def on_validation(
         self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
-    ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
+    ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
         """
         Validation step.
 