replay-rec 0.20.0-py3-none-any.whl → 0.20.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +10 -9
  3. replay/data/dataset_utils/dataset_label_encoder.py +5 -4
  4. replay/data/nn/schema.py +9 -18
  5. replay/data/nn/sequence_tokenizer.py +26 -18
  6. replay/data/nn/sequential_dataset.py +22 -18
  7. replay/data/nn/torch_sequential_dataset.py +17 -16
  8. replay/data/nn/utils.py +2 -1
  9. replay/data/schema.py +3 -12
  10. replay/metrics/base_metric.py +11 -10
  11. replay/metrics/categorical_diversity.py +8 -8
  12. replay/metrics/coverage.py +4 -4
  13. replay/metrics/experiment.py +3 -3
  14. replay/metrics/hitrate.py +1 -3
  15. replay/metrics/map.py +1 -3
  16. replay/metrics/mrr.py +1 -3
  17. replay/metrics/ndcg.py +1 -2
  18. replay/metrics/novelty.py +3 -3
  19. replay/metrics/offline_metrics.py +16 -16
  20. replay/metrics/precision.py +1 -3
  21. replay/metrics/recall.py +1 -3
  22. replay/metrics/rocauc.py +1 -3
  23. replay/metrics/surprisal.py +4 -4
  24. replay/metrics/torch_metrics_builder.py +13 -12
  25. replay/metrics/unexpectedness.py +2 -2
  26. replay/models/als.py +2 -2
  27. replay/models/association_rules.py +4 -3
  28. replay/models/base_neighbour_rec.py +3 -2
  29. replay/models/base_rec.py +11 -10
  30. replay/models/cat_pop_rec.py +2 -1
  31. replay/models/extensions/ann/ann_mixin.py +2 -1
  32. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
  33. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
  34. replay/models/lin_ucb.py +57 -11
  35. replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
  36. replay/models/nn/sequential/bert4rec/dataset.py +5 -18
  37. replay/models/nn/sequential/bert4rec/lightning.py +3 -3
  38. replay/models/nn/sequential/bert4rec/model.py +2 -2
  39. replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
  40. replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
  41. replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
  42. replay/models/nn/sequential/postprocessors/_base.py +2 -3
  43. replay/models/nn/sequential/postprocessors/postprocessors.py +11 -11
  44. replay/models/nn/sequential/sasrec/dataset.py +3 -16
  45. replay/models/nn/sequential/sasrec/lightning.py +3 -3
  46. replay/models/nn/sequential/sasrec/model.py +8 -8
  47. replay/models/slim.py +2 -2
  48. replay/models/ucb.py +2 -2
  49. replay/models/word2vec.py +3 -3
  50. replay/preprocessing/discretizer.py +8 -7
  51. replay/preprocessing/filters.py +4 -4
  52. replay/preprocessing/history_based_fp.py +6 -6
  53. replay/preprocessing/label_encoder.py +8 -7
  54. replay/scenarios/fallback.py +4 -3
  55. replay/splitters/base_splitter.py +3 -3
  56. replay/splitters/cold_user_random_splitter.py +4 -4
  57. replay/splitters/k_folds.py +4 -4
  58. replay/splitters/last_n_splitter.py +10 -10
  59. replay/splitters/new_users_splitter.py +4 -4
  60. replay/splitters/random_splitter.py +4 -4
  61. replay/splitters/ratio_splitter.py +10 -10
  62. replay/splitters/time_splitter.py +6 -6
  63. replay/splitters/two_stage_splitter.py +4 -4
  64. replay/utils/__init__.py +1 -1
  65. replay/utils/common.py +1 -1
  66. replay/utils/session_handler.py +2 -2
  67. replay/utils/spark_utils.py +6 -5
  68. replay/utils/types.py +3 -1
  69. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/METADATA +7 -1
  70. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/RECORD +73 -74
  71. replay/utils/warnings.py +0 -26
  72. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/WHEEL +0 -0
  73. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/LICENSE +0 -0
  74. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/NOTICE +0 -0
replay/metrics/base_metric.py CHANGED
@@ -1,6 +1,7 @@
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Optional, Union
+from collections.abc import Mapping
+from typing import Any, Optional, Union

 import numpy as np
 import polars as pl
@@ -14,7 +15,7 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.types import ArrayType, DoubleType, StructType


-MetricsDataFrameLike = Union[DataFrameLike, Dict]
+MetricsDataFrameLike = Union[DataFrameLike, dict]
 MetricsMeanReturnType = Mapping[str, float]
 MetricsPerUserReturnType = Mapping[str, Mapping[Any, float]]
 MetricsReturnType = Union[MetricsMeanReturnType, MetricsPerUserReturnType]
@@ -29,7 +30,7 @@ class Metric(ABC):

     def __init__(
         self,
-        topk: Union[List[int], int],
+        topk: Union[list[int], int],
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
@@ -89,7 +90,7 @@ class Metric(ABC):
         if duplicates_count:
             self._duplicate_warn()

-    def _check_duplicates_dict(self, recommendations: Dict) -> None:
+    def _check_duplicates_dict(self, recommendations: dict) -> None:
         for items in recommendations.values():
             items_set = set(items)
             if len(items) != len(items_set):
@@ -143,7 +144,7 @@ class Metric(ABC):
             ground_truth=ground_truth,
         )

-    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) -> Dict:
+    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) -> dict:
         return (
             data.sort_values(by=[self.rating_column, self.item_column], ascending=False, kind="stable")
             .groupby(self.query_column)[self.item_column]
@@ -151,7 +152,7 @@ class Metric(ABC):
             .to_dict()
         )

-    def _convert_dict_to_dict_with_score(self, data: Dict) -> Dict:
+    def _convert_dict_to_dict_with_score(self, data: dict) -> dict:
         converted_data = {}
         for user, items in data.items():
             is_sorted = True
@@ -164,10 +165,10 @@ class Metric(ABC):
             converted_data[user] = [item for item, _ in items]
         return converted_data

-    def _convert_pandas_to_dict_without_score(self, data: PandasDataFrame) -> Dict:
+    def _convert_pandas_to_dict_without_score(self, data: PandasDataFrame) -> dict:
         return data.groupby(self.query_column)[self.item_column].apply(list).to_dict()

-    def _dict_call(self, users: List, **kwargs: Dict) -> MetricsReturnType:
+    def _dict_call(self, users: list, **kwargs: dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
         kwargs can contain different dicts (for example, ground_truth or train), it depends on the metric.
@@ -287,7 +288,7 @@ class Metric(ABC):
         )
         return self._rearrange_columns(enriched_recommendations)

-    def _aggregate_results_per_user(self, distribution_per_user: Dict[Any, List[float]]) -> MetricsPerUserReturnType:
+    def _aggregate_results_per_user(self, distribution_per_user: dict[Any, list[float]]) -> MetricsPerUserReturnType:
         res: MetricsPerUserReturnType = {}
         for index, val in enumerate(self.topk):
             metric_name = f"{self.__name__}@{val}"
@@ -374,7 +375,7 @@ class Metric(ABC):

     @staticmethod
     @abstractmethod
-    def _get_metric_value_by_user(ks: List[int], *args: List) -> List[float]:  # pragma: no cover
+    def _get_metric_value_by_user(ks: list[int], *args: list) -> list[float]:  # pragma: no cover
         """
         Metric calculation for one user.

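Nearly every hunk in this file, and in the metric and model files that follow, applies the same modernization: container aliases from `typing` (`List`, `Dict`, `Tuple`, `Set`, `Type`) give way to the built-in generics standardized by PEP 585, and ABCs such as `Mapping` and `Iterable` are now imported from `collections.abc`. A minimal standalone sketch of the before/after pattern, with illustrative names that are not taken from the package:

    # 0.20.0 style:
    #   from typing import Dict, List, Mapping
    #   def scores(topk: List[int]) -> Dict[str, float]: ...

    # 0.20.1 style (subscripted built-ins need Python >= 3.9):
    from collections.abc import Mapping
    from typing import Union

    def scores(topk: Union[list[int], int]) -> dict[str, float]:
        """Hypothetical helper: return a placeholder value per requested cutoff."""
        ks = topk if isinstance(topk, list) else [topk]
        return {f"metric@{k}": 0.0 for k in ks}

    PerUser = Mapping[str, Mapping[str, float]]  # collections.abc ABCs are subscriptable too

Only imports and annotations change; runtime behavior is unaffected.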
replay/metrics/categorical_diversity.py CHANGED
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Union
+from typing import Union

 import numpy as np
 import polars as pl
@@ -62,7 +62,7 @@ class CategoricalDiversity(Metric):

     def __init__(
         self,
-        topk: Union[List, int],
+        topk: Union[list, int],
         query_column: str = "query_id",
         category_column: str = "category_id",
         rating_column: str = "rating",
@@ -195,7 +195,7 @@ class CategoricalDiversity(Metric):
             return self._polars_compute_per_user(recs)
         return self._polars_compute_agg(recs)

-    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) -> Dict:
+    def _convert_pandas_to_dict_with_score(self, data: PandasDataFrame) -> dict:
         return (
             data.sort_values(by=self.rating_column, ascending=False)
             .groupby(self.query_column)[self.category_column]
@@ -203,7 +203,7 @@ class CategoricalDiversity(Metric):
             .to_dict()
         )

-    def _precalculate_unique_cats(self, recommendations: Dict) -> Dict:
+    def _precalculate_unique_cats(self, recommendations: dict) -> dict:
         """
         Precalculate unique categories for each prefix for each user.
         """
@@ -217,14 +217,14 @@ class CategoricalDiversity(Metric):
             answer[user] = unique_len
         return answer

-    def _dict_compute_per_user(self, precalculated_answer: Dict) -> MetricsPerUserReturnType:
+    def _dict_compute_per_user(self, precalculated_answer: dict) -> MetricsPerUserReturnType:
         distribution_per_user = defaultdict(list)
         for k in self.topk:
             for user, unique_cats in precalculated_answer.items():
                 distribution_per_user[user].append(unique_cats[min(len(unique_cats), k) - 1] / k)
         return self._aggregate_results_per_user(distribution_per_user)

-    def _dict_compute_mean(self, precalculated_answer: Dict) -> MetricsMeanReturnType:
+    def _dict_compute_mean(self, precalculated_answer: dict) -> MetricsMeanReturnType:
         distribution_list = []
         for unique_cats in precalculated_answer.values():
             metrics_per_user = []
@@ -238,7 +238,7 @@ class CategoricalDiversity(Metric):
         metrics = [self._mode.cpu(distribution[:, k]) for k in range(distribution.shape[1])]
         return self._aggregate_results(metrics)

-    def _dict_call(self, precalculated_answer: Dict) -> MetricsReturnType:
+    def _dict_call(self, precalculated_answer: dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
         """
@@ -247,5 +247,5 @@ class CategoricalDiversity(Metric):
         return self._dict_compute_mean(precalculated_answer)

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], *args: List) -> List[float]:  # pragma: no cover
+    def _get_metric_value_by_user(ks: list[int], *args: list) -> list[float]:  # pragma: no cover
         pass
replay/metrics/coverage.py CHANGED
@@ -1,6 +1,6 @@
 import functools
 import operator
-from typing import Dict, List, Union
+from typing import Union

 import polars as pl

@@ -60,7 +60,7 @@ class Coverage(Metric):

     def __init__(
         self,
-        topk: Union[List, int],
+        topk: Union[list, int],
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
@@ -173,7 +173,7 @@ class Coverage(Metric):
         recs = self._get_enriched_recommendations(recommendations)
         return self._polars_compute(recs, train)

-    def _dict_call(self, recommendations: Dict, train: Dict) -> MetricsReturnType:
+    def _dict_call(self, recommendations: dict, train: dict) -> MetricsReturnType:
         """
         Calculating metrics in dict format.
         """
@@ -229,5 +229,5 @@ class Coverage(Metric):
         return self._dict_call(recommendations, train)

     @staticmethod
-    def _get_metric_value_by_user(ks, *args) -> List[float]:  # pragma: no cover
+    def _get_metric_value_by_user(ks, *args) -> list[float]:  # pragma: no cover
         pass
replay/metrics/experiment.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union

 import pandas as pd

@@ -102,10 +102,10 @@ class Experiment:

     def __init__(
         self,
-        metrics: List[Metric],
+        metrics: list[Metric],
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike] = None,
-        base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]] = None,
+        base_recommendations: Optional[Union[MetricsDataFrameLike, dict[str, MetricsDataFrameLike]]] = None,
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
replay/metrics/hitrate.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -62,7 +60,7 @@ class HitRate(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/map.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -63,7 +61,7 @@ class MAP(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         res = []
replay/metrics/mrr.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -55,7 +53,7 @@ class MRR(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/ndcg.py CHANGED
@@ -1,5 +1,4 @@
 import math
-from typing import List

 from .base_metric import Metric

@@ -80,7 +79,7 @@ class NDCG(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not pred or not ground_truth:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/novelty.py CHANGED
@@ -1,11 +1,11 @@
-from typing import TYPE_CHECKING, List, Type
+from typing import TYPE_CHECKING

 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

 from .base_metric import Metric, MetricsDataFrameLike, MetricsReturnType

 if TYPE_CHECKING:  # pragma: no cover
-    __class__: Type
+    __class__: type


 class Novelty(Metric):
@@ -139,7 +139,7 @@ class Novelty(Metric):
         return self._polars_compute(recs)

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], pred: List, train: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], pred: list, train: list) -> list[float]:
         if not train or not pred:
             return [1.0 for _ in ks]
         set_train = set(train)
replay/metrics/offline_metrics.py CHANGED
@@ -1,5 +1,5 @@
 import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union

 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

@@ -132,7 +132,7 @@ class OfflineMetrics:
     <BLANKLINE>
     """

-    _metrics_call_requirement_map: Dict[str, List[str]] = {
+    _metrics_call_requirement_map: dict[str, list[str]] = {
         "HitRate": ["ground_truth"],
         "MAP": ["ground_truth"],
         "NDCG": ["ground_truth"],
@@ -147,7 +147,7 @@ class OfflineMetrics:

     def __init__(
         self,
-        metrics: List[Metric],
+        metrics: list[Metric],
         query_column: str = "query_id",
         item_column: str = "item_id",
         rating_column: str = "rating",
@@ -174,9 +174,9 @@ class OfflineMetrics:
         :param allow_caching: (bool): The flag for using caching to optimize calculations.
             Default: ``True``.
         """
-        self.unexpectedness_metric: List[Metric] = []
-        self.diversity_metric: List[Metric] = []
-        self.main_metrics: List[Metric] = []
+        self.unexpectedness_metric: list[Metric] = []
+        self.diversity_metric: list[Metric] = []
+        self.main_metrics: list[Metric] = []
         self._allow_caching = allow_caching

         for metric in metrics:
@@ -198,7 +198,7 @@ class OfflineMetrics:
         recommendations: Union[SparkDataFrame, PolarsDataFrame],
         ground_truth: Union[SparkDataFrame, PolarsDataFrame],
         train: Optional[Union[SparkDataFrame, PolarsDataFrame]],
-    ) -> Tuple[Dict[str, Union[SparkDataFrame, PolarsDataFrame]], Optional[Union[SparkDataFrame, PolarsDataFrame]]]:
+    ) -> tuple[dict[str, Union[SparkDataFrame, PolarsDataFrame]], Optional[Union[SparkDataFrame, PolarsDataFrame]]]:
         if len(self.main_metrics) == 0:
             return {}, train
         result_dict = {}
@@ -257,21 +257,21 @@ class OfflineMetrics:

         return result_dict, train

-    def _cache_dataframes(self, dataframes: Dict[str, SparkDataFrame]) -> None:
+    def _cache_dataframes(self, dataframes: dict[str, SparkDataFrame]) -> None:
         for data in dataframes.values():
             data.cache()

-    def _unpersist_dataframes(self, dataframes: Dict[str, SparkDataFrame]) -> None:
+    def _unpersist_dataframes(self, dataframes: dict[str, SparkDataFrame]) -> None:
         for data in dataframes.values():
             data.unpersist()

     def _calculate_metrics(
         self,
-        enriched_recs_dict: Dict[str, Union[SparkDataFrame, PolarsDataFrame]],
+        enriched_recs_dict: dict[str, Union[SparkDataFrame, PolarsDataFrame]],
         train: Optional[Union[SparkDataFrame, PolarsDataFrame]] = None,
         is_spark: bool = True,
     ) -> MetricsReturnType:
-        result: Dict = {}
+        result: dict = {}
         for metric in self.metrics:
             metric_args = {}
             if metric.__class__.__name__ == "Coverage" and train is not None:
@@ -295,7 +295,7 @@ class OfflineMetrics:
         recommendations: MetricsDataFrameLike,
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike],
-        base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]],
+        base_recommendations: Optional[Union[MetricsDataFrameLike, dict[str, MetricsDataFrameLike]]],
     ) -> None:
         types = set()
         types.add(type(recommendations))
@@ -379,8 +379,8 @@ class OfflineMetrics:
         recommendations: MetricsDataFrameLike,
         ground_truth: MetricsDataFrameLike,
         train: Optional[MetricsDataFrameLike] = None,
-        base_recommendations: Optional[Union[MetricsDataFrameLike, Dict[str, MetricsDataFrameLike]]] = None,
-    ) -> Dict[str, float]:
+        base_recommendations: Optional[Union[MetricsDataFrameLike, dict[str, MetricsDataFrameLike]]] = None,
+    ) -> dict[str, float]:
         """
         Compute metrics.

@@ -450,12 +450,12 @@ class OfflineMetrics:
         if is_spark and self._allow_caching:
             self._unpersist_dataframes(enriched_recs_dict)
         else:  # Calculating metrics in dict format
-            current_map: Dict[str, Union[PandasDataFrame, Dict]] = {
+            current_map: dict[str, Union[PandasDataFrame, dict]] = {
                 "ground_truth": ground_truth,
                 "train": train,
            }
            for metric in self.metrics:
-                args_to_call: Dict[str, Union[PandasDataFrame, Dict]] = {"recommendations": recommendations}
+                args_to_call: dict[str, Union[PandasDataFrame, dict]] = {"recommendations": recommendations}
                 for data_name in self._metrics_call_requirement_map[str(metric.__class__.__name__)]:
                     args_to_call[data_name] = current_map[data_name]
                 result.update(metric(**args_to_call))
replay/metrics/precision.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -61,7 +59,7 @@ class Precision(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/recall.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -65,7 +63,7 @@ class Recall(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/rocauc.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric


@@ -74,7 +72,7 @@ class RocAuc(Metric):
     """

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/surprisal.py CHANGED
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Union
+from typing import Union

 import numpy as np
 import polars as pl
@@ -82,7 +82,7 @@ class Surprisal(Metric):
     <BLANKLINE>
     """

-    def _get_weights(self, train: Dict) -> Dict:
+    def _get_weights(self, train: dict) -> dict:
         n_users = len(train.keys())
         items_counter = defaultdict(set)
         for user, items in train.items():
@@ -93,7 +93,7 @@ class Surprisal(Metric):
             weights[item] = np.log2(n_users / len(users)) / np.log2(n_users)
         return weights

-    def _get_recommendation_weights(self, recommendations: Dict, train: Dict) -> Dict:
+    def _get_recommendation_weights(self, recommendations: dict, train: dict) -> dict:
         weights = self._get_weights(train)
         recs_with_weights = {}
         for user, items in recommendations.items():
@@ -183,7 +183,7 @@ class Surprisal(Metric):
         )

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], pred_item_ids: List, pred_weights: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], pred_item_ids: list, pred_weights: list) -> list[float]:
         if not pred_item_ids:
             return [0.0 for _ in ks]
         res = []
replay/metrics/torch_metrics_builder.py CHANGED
@@ -1,6 +1,7 @@
 import abc
+from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Any, Dict, List, Literal, Mapping, Optional, Set
+from typing import Any, Literal, Optional

 import numpy as np

@@ -19,13 +20,13 @@ MetricName = Literal[
     "coverage",
 ]

-DEFAULT_METRICS: List[MetricName] = [
+DEFAULT_METRICS: list[MetricName] = [
     "map",
     "ndcg",
     "recall",
 ]

-DEFAULT_KS: List[int] = [1, 5, 10, 20]
+DEFAULT_KS: list[int] = [1, 5, 10, 20]


 @dataclass
@@ -34,7 +35,7 @@ class _MetricRequirements:
     Stores description of metrics which need to be computed
     """

-    top_k: List[int]
+    top_k: list[int]
     need_recall: bool
     need_precision: bool
     need_ndcg: bool
@@ -68,14 +69,14 @@ class _MetricRequirements:
         self._metric_names = metrics

     @property
-    def metric_names(self) -> List[str]:
+    def metric_names(self) -> list[str]:
         """
         Getting metric names
         """
         return self._metric_names

     @classmethod
-    def from_metrics(cls, metrics: Set[str], top_k: List[int]) -> "_MetricRequirements":
+    def from_metrics(cls, metrics: set[str], top_k: list[int]) -> "_MetricRequirements":
         """
         Creating a class based on a given list of metrics and K values
         """
@@ -96,7 +97,7 @@ class _CoverageHelper:
     Computes coverage metric over multiple batches
     """

-    def __init__(self, top_k: List[int], item_count: Optional[int]) -> None:
+    def __init__(self, top_k: list[int], item_count: Optional[int]) -> None:
         """
         :param top_k: (list): Consider the highest k scores in the ranking.
         :param item_count: (optional, int): the total number of items in the dataset.
@@ -110,7 +111,7 @@ class _CoverageHelper:
         Reload the metric counter
         """
         self._train_hist = torch.zeros(self.item_count)
-        self._pred_hist: Dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}
+        self._pred_hist: dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}

     def _ensure_hists_on_device(self, device: torch.device) -> None:
         self._train_hist = self._train_hist.to(device)
@@ -197,8 +198,8 @@ class TorchMetricsBuilder(_MetricBuilder):

     def __init__(
         self,
-        metrics: List[MetricName] = DEFAULT_METRICS,
-        top_k: Optional[List[int]] = DEFAULT_KS,
+        metrics: list[MetricName] = DEFAULT_METRICS,
+        top_k: Optional[list[int]] = DEFAULT_KS,
         item_count: Optional[int] = None,
     ) -> None:
         """
@@ -331,8 +332,8 @@ class TorchMetricsBuilder(_MetricBuilder):

     def _compute_metrics_sum(
         self, predictions: torch.LongTensor, ground_truth: torch.LongTensor, train: Optional[torch.LongTensor]
-    ) -> List[float]:
-        result: List[float] = []
+    ) -> list[float]:
+        result: list[float] = []

         # Getting a tensor of the same size as predictions
         # The tensor contains information about whether the item from the prediction is present in the test set
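For orientation, here is a constructor call sketched only from the `__init__` signature visible in the hunk above; the builder's update and compute methods are not shown in this diff, so nothing beyond construction is exercised, and the argument values are made up:

    from replay.metrics.torch_metrics_builder import TorchMetricsBuilder

    builder = TorchMetricsBuilder(
        metrics=["map", "ndcg", "recall"],  # MetricName literals, same as DEFAULT_METRICS above
        top_k=[1, 5, 10, 20],               # same as DEFAULT_KS above
        item_count=10_000,                  # hypothetical catalog size, used by the coverage helper
    )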
replay/metrics/unexpectedness.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import Optional, Union

 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame

@@ -152,7 +152,7 @@ class Unexpectedness(Metric):
         )

     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], base_recs: Optional[List], recs: Optional[List]) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], base_recs: Optional[list], recs: Optional[list]) -> list[float]:
         if not base_recs or not recs:
             return [0.0 for _ in ks]
         return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]
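The return expression in that hunk is the entire per-user computation: Unexpectedness@k is the share of a user's top-k recommendations that do not appear in the top-k base recommendations. A small self-contained check of the formula, with made-up item ids:

    def unexpectedness_per_user(ks: list[int], base_recs: list, recs: list) -> list[float]:
        # Mirrors the return line shown in the hunk above.
        if not base_recs or not recs:
            return [0.0 for _ in ks]
        return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]

    print(unexpectedness_per_user([2, 4], base_recs=[1, 2, 3, 4], recs=[1, 5, 6, 2]))
    # [0.5, 0.5]: one of the top-2 items overlaps, two of the top-4 items overlap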
replay/models/als.py CHANGED
@@ -1,5 +1,5 @@
 from os.path import join
-from typing import Optional, Tuple
+from typing import Optional

 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -148,7 +148,7 @@ class ALSWrap(Recommender, ItemVectorModel):

     def _get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
-    ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
         entity = "user" if self.query_column in ids.columns else "item"
         entity_col = self.query_column if self.query_column in ids.columns else self.item_column

replay/models/association_rules.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Iterable, List, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union

 import numpy as np

@@ -97,13 +98,13 @@ class AssociationRulesItemRec(NeighbourRec):
     In this case all items in sessions should have the same rating.
     """

-    def _get_ann_infer_params(self) -> Dict[str, Any]:
+    def _get_ann_infer_params(self) -> dict[str, Any]:
         return {
             "features_col": None,
         }

     can_predict_item_to_item = True
-    item_to_item_metrics: List[str] = ["lift", "confidence", "confidence_gain"]
+    item_to_item_metrics: list[str] = ["lift", "confidence", "confidence_gain"]
     similarity: SparkDataFrame
     can_change_metric = True
     _search_space = {
replay/models/base_neighbour_rec.py CHANGED
@@ -4,7 +4,8 @@ Part of set of abstract classes (from base_rec.py)
 """

 from abc import ABC
-from typing import Any, Dict, Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union

 from replay.data.dataset import Dataset
 from replay.utils import PYSPARK_AVAILABLE, MissingImport, SparkDataFrame
@@ -187,7 +188,7 @@ class NeighbourRec(ANNMixin, Recommender, ABC):
             "similarity" if metric is None else metric,
         )

-    def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> dict[str, Any]:
         similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
         self.index_builder.index_params.items_count = interactions.select(sf.max(self.item_column)).first()[0] + 1
         return similarity_df, {