replay-rec 0.20.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +10 -9
  3. replay/data/dataset_utils/dataset_label_encoder.py +5 -4
  4. replay/data/nn/schema.py +9 -18
  5. replay/data/nn/sequence_tokenizer.py +26 -18
  6. replay/data/nn/sequential_dataset.py +22 -18
  7. replay/data/nn/torch_sequential_dataset.py +17 -16
  8. replay/data/nn/utils.py +2 -1
  9. replay/data/schema.py +3 -12
  10. replay/metrics/base_metric.py +11 -10
  11. replay/metrics/categorical_diversity.py +8 -8
  12. replay/metrics/coverage.py +4 -4
  13. replay/metrics/experiment.py +3 -3
  14. replay/metrics/hitrate.py +1 -3
  15. replay/metrics/map.py +1 -3
  16. replay/metrics/mrr.py +1 -3
  17. replay/metrics/ndcg.py +1 -2
  18. replay/metrics/novelty.py +3 -3
  19. replay/metrics/offline_metrics.py +16 -16
  20. replay/metrics/precision.py +1 -3
  21. replay/metrics/recall.py +1 -3
  22. replay/metrics/rocauc.py +1 -3
  23. replay/metrics/surprisal.py +4 -4
  24. replay/metrics/torch_metrics_builder.py +13 -12
  25. replay/metrics/unexpectedness.py +2 -2
  26. replay/models/als.py +2 -2
  27. replay/models/association_rules.py +4 -3
  28. replay/models/base_neighbour_rec.py +3 -2
  29. replay/models/base_rec.py +11 -10
  30. replay/models/cat_pop_rec.py +2 -1
  31. replay/models/extensions/ann/ann_mixin.py +2 -1
  32. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
  33. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
  34. replay/models/lin_ucb.py +57 -11
  35. replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
  36. replay/models/nn/sequential/bert4rec/dataset.py +5 -18
  37. replay/models/nn/sequential/bert4rec/lightning.py +3 -3
  38. replay/models/nn/sequential/bert4rec/model.py +2 -2
  39. replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
  40. replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
  41. replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
  42. replay/models/nn/sequential/postprocessors/_base.py +2 -3
  43. replay/models/nn/sequential/postprocessors/postprocessors.py +11 -11
  44. replay/models/nn/sequential/sasrec/dataset.py +3 -16
  45. replay/models/nn/sequential/sasrec/lightning.py +3 -3
  46. replay/models/nn/sequential/sasrec/model.py +8 -8
  47. replay/models/slim.py +2 -2
  48. replay/models/ucb.py +2 -2
  49. replay/models/word2vec.py +3 -3
  50. replay/preprocessing/discretizer.py +8 -7
  51. replay/preprocessing/filters.py +4 -4
  52. replay/preprocessing/history_based_fp.py +6 -6
  53. replay/preprocessing/label_encoder.py +8 -7
  54. replay/scenarios/fallback.py +4 -3
  55. replay/splitters/base_splitter.py +3 -3
  56. replay/splitters/cold_user_random_splitter.py +4 -4
  57. replay/splitters/k_folds.py +4 -4
  58. replay/splitters/last_n_splitter.py +10 -10
  59. replay/splitters/new_users_splitter.py +4 -4
  60. replay/splitters/random_splitter.py +4 -4
  61. replay/splitters/ratio_splitter.py +10 -10
  62. replay/splitters/time_splitter.py +6 -6
  63. replay/splitters/two_stage_splitter.py +4 -4
  64. replay/utils/__init__.py +1 -1
  65. replay/utils/common.py +1 -1
  66. replay/utils/session_handler.py +2 -2
  67. replay/utils/spark_utils.py +6 -5
  68. replay/utils/types.py +3 -1
  69. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/METADATA +7 -1
  70. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/RECORD +73 -74
  71. replay/utils/warnings.py +0 -26
  72. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/WHEEL +0 -0
  73. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/LICENSE +0 -0
  74. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/NOTICE +0 -0
replay/models/nn/sequential/sasrec/dataset.py CHANGED
@@ -10,7 +10,6 @@ from replay.data.nn import (
      TorchSequentialDataset,
      TorchSequentialValidationDataset,
  )
- from replay.utils import deprecation_warning
  
  
  class SasRecTrainingBatch(NamedTuple):
@@ -31,17 +30,13 @@ class SasRecTrainingDataset(TorchDataset):
      Dataset that generates samples to train SasRec-like model
      """
  
-     @deprecation_warning(
-         "`padding_value` parameter will be removed in future versions. "
-         "Instead, you should specify `padding_value` for each column in TensorSchema"
-     )
      def __init__(
          self,
          sequential: SequentialDataset,
          max_sequence_length: int,
          sequence_shift: int = 1,
          sliding_window_step: Optional[None] = None,
-         padding_value: int = 0,
+         padding_value: Optional[int] = None,
          label_feature_name: Optional[str] = None,
      ) -> None:
          """
@@ -127,15 +122,11 @@ class SasRecPredictionDataset(TorchDataset):
      Dataset that generates samples to infer SasRec-like model
      """
  
-     @deprecation_warning(
-         "`padding_value` parameter will be removed in future versions. "
-         "Instead, you should specify `padding_value` for each column in TensorSchema"
-     )
      def __init__(
          self,
          sequential: SequentialDataset,
          max_sequence_length: int,
-         padding_value: int = 0,
+         padding_value: Optional[int] = None,
      ) -> None:
          """
          :param sequential: Sequential dataset with data to make predictions at.
@@ -179,17 +170,13 @@ class SasRecValidationDataset(TorchDataset):
      Dataset that generates samples to infer and validate SasRec-like model
      """
  
-     @deprecation_warning(
-         "`padding_value` parameter will be removed in future versions. "
-         "Instead, you should specify `padding_value` for each column in TensorSchema"
-     )
      def __init__(
          self,
          sequential: SequentialDataset,
          ground_truth: SequentialDataset,
          train: SequentialDataset,
          max_sequence_length: int,
-         padding_value: int = 0,
+         padding_value: Optional[int] = None,
          label_feature_name: Optional[str] = None,
      ):
          """
replay/models/nn/sequential/sasrec/lightning.py CHANGED
@@ -1,5 +1,5 @@
  import math
- from typing import Any, Dict, Literal, Optional, Tuple, Union, cast
+ from typing import Any, Literal, Optional, Union, cast
  
  import lightning
  import torch
@@ -341,7 +341,7 @@ class SasRec(lightning.LightningModule):
          positive_labels: torch.LongTensor,
          padding_mask: torch.BoolTensor,
          target_padding_mask: torch.BoolTensor,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
          assert self._loss_sample_count is not None
          n_negative_samples = self._loss_sample_count
          positive_labels = cast(
@@ -428,7 +428,7 @@ class SasRec(lightning.LightningModule):
              msg = "Not supported loss_type"
              raise NotImplementedError(msg)
  
-     def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
+     def get_all_embeddings(self) -> dict[str, torch.nn.Embedding]:
          """
          :returns: copy of all embeddings as a dictionary.
          """
replay/models/nn/sequential/sasrec/model.py CHANGED
@@ -1,6 +1,6 @@
  import abc
  import contextlib
- from typing import Any, Dict, Optional, Tuple, Union, cast
+ from typing import Any, Optional, Union, cast
  
  import torch
  
@@ -212,7 +212,7 @@ class SasRecMasks:
          self,
          feature_tensor: TensorMap,
          padding_mask: torch.BoolTensor,
-     ) -> Tuple[torch.BoolTensor, torch.BoolTensor, TensorMap]:
+     ) -> tuple[torch.BoolTensor, torch.BoolTensor, TensorMap]:
          """
          :param feature_tensor: Batch of features.
          :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
@@ -260,7 +260,7 @@ class BaseSasRecEmbeddings(abc.ABC):
      """
  
      @abc.abstractmethod
-     def get_all_embeddings(self) -> Dict[str, torch.Tensor]:
+     def get_all_embeddings(self) -> dict[str, torch.Tensor]:
          """
          :returns: copy of all embeddings presented in a layer as a dict.
          """
@@ -366,7 +366,7 @@ class SasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
          # Last one is reserved for padding, so we remove it
          return self.item_emb.weight[:-1, :]
  
-     def get_all_embeddings(self) -> Dict[str, torch.Tensor]:
+     def get_all_embeddings(self) -> dict[str, torch.Tensor]:
          """
          :returns: copy of all embeddings presented in this layer as a dict.
          """
@@ -579,7 +579,7 @@ class TiSasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
          self,
          feature_tensor: TensorMap,
          padding_mask: torch.BoolTensor,
-     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
          """
          :param feature_tensor: Batch of features.
          :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
@@ -628,7 +628,7 @@ class TiSasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
          # Last one is reserved for padding, so we remove it
          return self.item_emb.weight[:-1, :]
  
-     def get_all_embeddings(self) -> Dict[str, torch.Tensor]:
+     def get_all_embeddings(self) -> dict[str, torch.Tensor]:
          """
          :returns: copy of all embeddings presented in this layer as a dict.
          """
@@ -674,7 +674,7 @@ class TiSasRecLayers(torch.nn.Module):
          seqs: torch.Tensor,
          attention_mask: torch.BoolTensor,
          padding_mask: torch.BoolTensor,
-         ti_embeddings: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+         ti_embeddings: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
          device: torch.device,
      ) -> torch.Tensor:
          """
@@ -734,7 +734,7 @@ class TiSasRecAttention(torch.nn.Module):
          keys: torch.LongTensor,
          time_mask: torch.LongTensor,
          attn_mask: torch.LongTensor,
-         ti_embeddings: Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor],
+         ti_embeddings: tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor],
          device: torch.device,
      ) -> torch.Tensor:
          """
replay/models/slim.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, Dict, Optional
+ from typing import Any, Optional
  
  import numpy as np
  import pandas as pd
@@ -21,7 +21,7 @@ class SLIM(NeighbourRec):
      """`SLIM: Sparse Linear Methods for Top-N Recommender Systems
      <http://glaros.dtc.umn.edu/gkhome/fetch/papers/SLIM2011icdm.pdf>`_"""
  
-     def _get_ann_infer_params(self) -> Dict[str, Any]:
+     def _get_ann_infer_params(self) -> dict[str, Any]:
          return {
              "features_col": None,
          }
replay/models/ucb.py CHANGED
@@ -1,5 +1,5 @@
  import math
- from typing import Any, Dict, List, Optional
+ from typing import Any, Optional
  
  from replay.data.dataset import Dataset
  from replay.metrics import NDCG, Metric
@@ -103,7 +103,7 @@ class UCB(NonPersonalizedRecommender):
          self,
          train_dataset: Dataset,  # noqa: ARG002
          test_dataset: Dataset,  # noqa: ARG002
-         param_borders: Optional[Dict[str, List[Any]]] = None,  # noqa: ARG002
+         param_borders: Optional[dict[str, list[Any]]] = None,  # noqa: ARG002
          criterion: Metric = NDCG,  # noqa: ARG002
          k: int = 10,  # noqa: ARG002
          budget: int = 10,  # noqa: ARG002
replay/models/word2vec.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, Dict, Optional
+ from typing import Any, Optional
  
  from replay.data import Dataset
  from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -24,7 +24,7 @@ class Word2VecRec(ANNMixin, Recommender, ItemVectorModel):
      Trains word2vec model where items are treated as words and queries as sentences.
      """
  
-     def _get_ann_infer_params(self) -> Dict[str, Any]:
+     def _get_ann_infer_params(self) -> dict[str, Any]:
          self.index_builder.index_params.dim = self.rank
          return {
              "features_col": "query_vector",
@@ -36,7 +36,7 @@ class Word2VecRec(ANNMixin, Recommender, ItemVectorModel):
          query_vectors = query_vectors.select(self.query_column, vector_to_array("query_vector").alias("query_vector"))
          return query_vectors
  
-     def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+     def _configure_index_builder(self, interactions: SparkDataFrame) -> dict[str, Any]:
          item_vectors = self._get_item_vectors()
          item_vectors = item_vectors.select(self.item_column, vector_to_array("item_vector").alias("item_vector"))
  
replay/preprocessing/discretizer.py CHANGED
@@ -2,8 +2,9 @@ import abc
  import json
  import os
  import warnings
+ from collections.abc import Sequence
  from pathlib import Path
- from typing import Dict, List, Literal, Sequence
+ from typing import Literal
  
  import numpy as np
  import polars as pl
@@ -114,7 +115,7 @@ class GreedyDiscretizingRule(BaseDiscretizingRule):
          max_bin: int,
          total_cnt: int,
          min_data_in_bin: int,
-     ) -> List[float]:
+     ) -> list[float]:
          """
          Computes bound for bins.
  
@@ -363,7 +364,7 @@ Set 'keep' or 'skip' for processing NaN."
      @classmethod
      def load(cls, path: str) -> "GreedyDiscretizingRule":
          base_path = Path(path).with_suffix(".replay").resolve()
-         with open(base_path / "init_args.json", "r") as file:
+         with open(base_path / "init_args.json") as file:
              discretizer_rule_dict = json.loads(file.read())
  
          discretizer_rule = cls(**discretizer_rule_dict["init_args"])
@@ -590,7 +591,7 @@ Set 'keep' or 'skip' for processing NaN."
      @classmethod
      def load(cls, path: str) -> "QuantileDiscretizingRule":
          base_path = Path(path).with_suffix(".replay").resolve()
-         with open(base_path / "init_args.json", "r") as file:
+         with open(base_path / "init_args.json") as file:
              discretizer_rule_dict = json.loads(file.read())
  
          discretizer_rule = cls(**discretizer_rule_dict["init_args"])
@@ -655,7 +656,7 @@ class Discretizer:
          """
          return self.fit(df).transform(df)
  
-     def set_handle_invalid(self, handle_invalid_rules: Dict[str, HandleInvalidStrategies]) -> None:
+     def set_handle_invalid(self, handle_invalid_rules: dict[str, HandleInvalidStrategies]) -> None:
          """
          Modify handle_invalid strategy on already fitted Discretizer.
  
@@ -704,13 +705,13 @@ class Discretizer:
      @classmethod
      def load(cls, path: str) -> "Discretizer":
          base_path = Path(path).with_suffix(".replay").resolve()
-         with open(base_path / "init_args.json", "r") as file:
+         with open(base_path / "init_args.json") as file:
              discretizer_dict = json.loads(file.read())
          rules = []
          for root, dirs, files in os.walk(str(base_path) + "/rules/"):
              for d in dirs:
                  if d.split(".")[0] in discretizer_dict["rule_names"]:
-                     with open(root + d + "/init_args.json", "r") as file:
+                     with open(root + d + "/init_args.json") as file:
                          discretizer_rule_dict = json.loads(file.read())
                          rules.append(globals()[discretizer_rule_dict["_class_name"]].load(root + d))
  
replay/preprocessing/filters.py CHANGED
@@ -4,7 +4,7 @@ Select or remove data by some criteria
  
  from abc import ABC, abstractmethod
  from datetime import datetime, timedelta
- from typing import Callable, Literal, Optional, Tuple, Union
+ from typing import Callable, Literal, Optional, Union
  from uuid import uuid4
  
  import numpy as np
@@ -182,7 +182,7 @@ class InteractionEntriesFilter(_BaseFilter):
          non_agg_column: str,
          min_inter: Optional[int] = None,
          max_inter: Optional[int] = None,
-     ) -> Tuple[PandasDataFrame, int, int]:
+     ) -> tuple[PandasDataFrame, int, int]:
          filtered_interactions = interactions.copy(deep=True)
  
          filtered_interactions["count"] = filtered_interactions.groupby(agg_column, sort=False)[
@@ -207,7 +207,7 @@ class InteractionEntriesFilter(_BaseFilter):
          non_agg_column: str,
          min_inter: Optional[int] = None,
          max_inter: Optional[int] = None,
-     ) -> Tuple[SparkDataFrame, int, int]:
+     ) -> tuple[SparkDataFrame, int, int]:
          filtered_interactions = interactions.withColumn(
              "count", sf.count(non_agg_column).over(Window.partitionBy(agg_column))
          )
@@ -233,7 +233,7 @@ class InteractionEntriesFilter(_BaseFilter):
          non_agg_column: str,
          min_inter: Optional[int] = None,
          max_inter: Optional[int] = None,
-     ) -> Tuple[PolarsDataFrame, int, int]:
+     ) -> tuple[PolarsDataFrame, int, int]:
          filtered_interactions = interactions.with_columns(
              pl.col(non_agg_column).count().over(pl.col(agg_column)).alias("count")
          )
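
The pandas branch of InteractionEntriesFilter counts rows per aggregation column with a grouped transform and then filters on that count. A standalone sketch of the same idea on a toy frame (column names and bounds here are placeholders, not the filter's actual configuration):

    import pandas as pd

    interactions = pd.DataFrame(
        {"user_id": [1, 1, 1, 2, 2, 3], "item_id": [10, 11, 12, 10, 13, 14]}
    )

    # number of interactions per user, broadcast back onto every row
    counts = interactions.groupby("user_id", sort=False)["item_id"].transform("count")

    # keep users with between 2 and 3 interactions, inclusive
    filtered = interactions[(counts >= 2) & (counts <= 3)]
    print(filtered)
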
replay/preprocessing/history_based_fp.py CHANGED
@@ -9,7 +9,7 @@ Contains classes for users' and items' features generation based on interactions
  """
  
  from datetime import datetime
- from typing import Dict, List, Optional
+ from typing import Optional
  
  from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
  
@@ -64,7 +64,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
      user_log_features: Optional[SparkDataFrame] = None
      item_log_features: Optional[SparkDataFrame] = None
  
-     def _create_log_aggregates(self, agg_col: str = "user_idx") -> List:
+     def _create_log_aggregates(self, agg_col: str = "user_idx") -> list:
          """
          Create features based on relevance type
          (binary or not) and whether timestamp is present.
@@ -289,12 +289,12 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):
      If user features are provided, item features will be generated and vice versa.
      """
  
-     conditional_pop_dict: Optional[Dict[str, SparkDataFrame]]
+     conditional_pop_dict: Optional[dict[str, SparkDataFrame]]
      entity_name: str
  
      def __init__(
          self,
-         cat_features_list: List,
+         cat_features_list: list,
      ):
          """
          :param cat_features_list: List of columns with categorical features to use
@@ -397,8 +397,8 @@ class HistoryBasedFeaturesProcessor:
          self,
          use_log_features: bool = True,
          use_conditional_popularity: bool = True,
-         user_cat_features_list: Optional[List] = None,
-         item_cat_features_list: Optional[List] = None,
+         user_cat_features_list: Optional[list] = None,
+         item_cat_features_list: Optional[list] = None,
      ):
          """
          :param use_log_features: if add statistical log-based features
replay/preprocessing/label_encoder.py CHANGED
@@ -10,8 +10,9 @@ import abc
  import json
  import os
  import warnings
+ from collections.abc import Mapping, Sequence
  from pathlib import Path
- from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union
+ from typing import Literal, Optional, Union
  
  import polars as pl
  
@@ -162,7 +163,7 @@ class LabelEncodingRule(BaseLabelEncodingRule):
      def _make_inverse_mapping(self) -> Mapping:
          return {val: key for key, val in self.get_mapping().items()}
  
-     def _make_inverse_mapping_list(self) -> List:
+     def _make_inverse_mapping_list(self) -> list:
          inverse_mapping_list = [0 for _ in range(len(self.get_mapping()))]
          for k, value in self.get_mapping().items():
              inverse_mapping_list[value] = k
@@ -543,7 +544,7 @@ Convert type to string, integer, or float."
      @classmethod
      def load(cls, path: str) -> "LabelEncodingRule":
          base_path = Path(path).with_suffix(".replay").resolve()
-         with open(base_path / "init_args.json", "r") as file:
+         with open(base_path / "init_args.json") as file:
              encoder_rule_dict = json.loads(file.read())
  
          string_column_type = encoder_rule_dict["fitted_args"]["column_type"]
@@ -901,7 +902,7 @@ class LabelEncoder:
          """
          return self.fit(df).transform(df)
  
-     def set_handle_unknowns(self, handle_unknown_rules: Dict[str, HandleUnknownStrategies]) -> None:
+     def set_handle_unknowns(self, handle_unknown_rules: dict[str, HandleUnknownStrategies]) -> None:
          """
          Modify handle unknown strategy on already fitted encoder.
  
@@ -923,7 +924,7 @@ class LabelEncoder:
          rule = list(filter(lambda x: x.column == column, self.rules))
          rule[0].set_handle_unknown(handle_unknown)
  
-     def set_default_values(self, default_value_rules: Dict[str, Optional[Union[int, str]]]) -> None:
+     def set_default_values(self, default_value_rules: dict[str, Optional[Union[int, str]]]) -> None:
          """
          Modify handle unknown strategy on already fitted encoder.
          Default value that will fill the unknown labels
@@ -974,13 +975,13 @@ class LabelEncoder:
      @classmethod
      def load(cls, path: str) -> "LabelEncoder":
          base_path = Path(path).with_suffix(".replay").resolve()
-         with open(base_path / "init_args.json", "r") as file:
+         with open(base_path / "init_args.json") as file:
              encoder_dict = json.loads(file.read())
          rules = []
          for root, dirs, files in os.walk(str(base_path) + "/rules/"):
              for d in dirs:
                  if d.split(".")[0] in encoder_dict["rule_names"]:
-                     with open(root + d + "/init_args.json", "r") as file:
+                     with open(root + d + "/init_args.json") as file:
                          encoder_rule_dict = json.loads(file.read())
                          rules.append(globals()[encoder_rule_dict["_class_name"]].load(root + d))
  
replay/scenarios/fallback.py CHANGED
@@ -1,4 +1,5 @@
- from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+ from collections.abc import Iterable
+ from typing import Any, Optional, Union
  
  from replay.data import Dataset
  from replay.metrics import NDCG, Metric
@@ -125,12 +126,12 @@ class Fallback(BaseRecommender):
          self,
          train_dataset: Dataset,
          test_dataset: Dataset,
-         param_borders: Optional[Dict[str, Dict[str, List[Any]]]] = None,
+         param_borders: Optional[dict[str, dict[str, list[Any]]]] = None,
          criterion: Metric = NDCG,
          k: int = 10,
          budget: int = 10,
          new_study: bool = True,
-     ) -> Tuple[Dict[str, Any]]:
+     ) -> tuple[dict[str, Any]]:
          """
          Searches best parameters with optuna.
  
replay/splitters/base_splitter.py CHANGED
@@ -1,7 +1,7 @@
  import json
  from abc import ABC, abstractmethod
  from pathlib import Path
- from typing import Optional, Tuple
+ from typing import Optional
  
  import polars as pl
  
@@ -20,7 +20,7 @@ if PYSPARK_AVAILABLE:
      )
  
  
- SplitterReturnType = Tuple[DataFrameLike, DataFrameLike]
+ SplitterReturnType = tuple[DataFrameLike, DataFrameLike]
  
  
  class Splitter(ABC):
@@ -90,7 +90,7 @@ class Splitter(ABC):
          Method for loading splitter from `.replay` directory.
          """
          base_path = Path(path).with_suffix(".replay").resolve()
-         with open(base_path / "init_args.json", "r") as file:
+         with open(base_path / "init_args.json") as file:
              splitter_dict = json.loads(file.read())
          splitter = cls(**splitter_dict["init_args"])
  
replay/splitters/cold_user_random_splitter.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Optional, Tuple
+ from typing import Optional
  
  import polars as pl
  
@@ -62,7 +62,7 @@ class ColdUserRandomSplitter(Splitter):
  
      def _core_split_pandas(
          self, interactions: PandasDataFrame, threshold: float
-     ) -> Tuple[PandasDataFrame, PandasDataFrame]:
+     ) -> tuple[PandasDataFrame, PandasDataFrame]:
          users = PandasDataFrame(interactions[self.query_column].unique(), columns=[self.query_column])
          train_users = users.sample(frac=(1 - threshold), random_state=self.seed)
          train_users["is_test"] = False
@@ -78,7 +78,7 @@ class ColdUserRandomSplitter(Splitter):
  
      def _core_split_spark(
          self, interactions: SparkDataFrame, threshold: float
-     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+     ) -> tuple[SparkDataFrame, SparkDataFrame]:
          users = interactions.select(self.query_column).distinct()
          train_users, _ = users.randomSplit(
              [1 - threshold, threshold],
@@ -97,7 +97,7 @@ class ColdUserRandomSplitter(Splitter):
  
      def _core_split_polars(
          self, interactions: PolarsDataFrame, threshold: float
-     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
+     ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
          train_users = (
              interactions.select(self.query_column)
              .unique(maintain_order=True)
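
ColdUserRandomSplitter samples a share of users and routes all of their interactions to the test split, so test users are unseen ("cold") at train time. A toy pandas sketch of that idea outside the class (column names and the threshold are placeholders):

    import pandas as pd

    interactions = pd.DataFrame(
        {"user_id": [1, 1, 2, 3, 3, 4], "item_id": [10, 11, 12, 13, 14, 15]}
    )
    threshold, seed = 0.25, 42  # share of users that become cold test users

    users = interactions[["user_id"]].drop_duplicates()
    train_users = users.sample(frac=1 - threshold, random_state=seed)

    is_test = ~interactions["user_id"].isin(train_users["user_id"])
    train, test = interactions[~is_test], interactions[is_test]
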
replay/splitters/k_folds.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Literal, Optional, Tuple
+ from typing import Literal, Optional
  
  import polars as pl
  
@@ -83,7 +83,7 @@ class KFolds(Splitter):
          """
          return self._core_split(interactions)
  
-     def _query_split_spark(self, interactions: SparkDataFrame) -> Tuple[SparkDataFrame, SparkDataFrame]:
+     def _query_split_spark(self, interactions: SparkDataFrame) -> tuple[SparkDataFrame, SparkDataFrame]:
          dataframe = interactions.withColumn("_rand", sf.rand(self.seed))
          dataframe = dataframe.withColumn(
              "fold",
@@ -100,7 +100,7 @@ class KFolds(Splitter):
              test = self._drop_cold_items_and_users(train, test)
              yield train, test
  
-     def _query_split_pandas(self, interactions: PandasDataFrame) -> Tuple[PandasDataFrame, PandasDataFrame]:
+     def _query_split_pandas(self, interactions: PandasDataFrame) -> tuple[PandasDataFrame, PandasDataFrame]:
          dataframe = interactions.sample(frac=1, random_state=self.seed).sort_values(self.query_column)
          dataframe["fold"] = (dataframe.groupby(self.query_column, sort=False).cumcount() + 1) % self.n_folds
          for i in range(self.n_folds):
@@ -115,7 +115,7 @@ class KFolds(Splitter):
              test = self._drop_cold_items_and_users(train, test)
              yield train, test
  
-     def _query_split_polars(self, interactions: PolarsDataFrame) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
+     def _query_split_polars(self, interactions: PolarsDataFrame) -> tuple[PolarsDataFrame, PolarsDataFrame]:
          dataframe = interactions.sample(fraction=1, shuffle=True, seed=self.seed).sort(self.query_column)
          dataframe = dataframe.with_columns(
              (pl.cum_count(self.query_column).over(self.query_column) % self.n_folds).alias("fold")
replay/splitters/last_n_splitter.py CHANGED
@@ -1,4 +1,4 @@
- from typing import List, Literal, Optional, Tuple
+ from typing import Literal, Optional
  
  import numpy as np
  import pandas as pd
@@ -240,7 +240,7 @@ class LastNSplitter(Splitter):
  
          return interactions
  
-     def _partial_split_interactions(self, interactions: DataFrameLike, n: int) -> Tuple[DataFrameLike, DataFrameLike]:
+     def _partial_split_interactions(self, interactions: DataFrameLike, n: int) -> tuple[DataFrameLike, DataFrameLike]:
          res = self._add_time_partition(interactions)
          if isinstance(interactions, SparkDataFrame):
              return self._partial_split_interactions_spark(res, n)
@@ -250,7 +250,7 @@ class LastNSplitter(Splitter):
  
      def _partial_split_interactions_pandas(
          self, interactions: PandasDataFrame, n: int
-     ) -> Tuple[PandasDataFrame, PandasDataFrame]:
+     ) -> tuple[PandasDataFrame, PandasDataFrame]:
          interactions["count"] = interactions.groupby(self.divide_column, sort=False)[self.divide_column].transform(len)
          interactions["is_test"] = interactions["row_num"] > (interactions["count"] - float(n))
          if self.session_id_column:
@@ -263,7 +263,7 @@ class LastNSplitter(Splitter):
  
      def _partial_split_interactions_spark(
          self, interactions: SparkDataFrame, n: int
-     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+     ) -> tuple[SparkDataFrame, SparkDataFrame]:
          interactions = interactions.withColumn(
              "count",
              sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
@@ -281,7 +281,7 @@ class LastNSplitter(Splitter):
  
      def _partial_split_interactions_polars(
          self, interactions: PolarsDataFrame, n: int
-     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
+     ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
          interactions = interactions.with_columns(
              pl.col(self.timestamp_column).count().over(self.divide_column).alias("count")
          )
@@ -296,7 +296,7 @@ class LastNSplitter(Splitter):
  
      def _partial_split_timedelta(
          self, interactions: DataFrameLike, timedelta: int
-     ) -> Tuple[DataFrameLike, DataFrameLike]:
+     ) -> tuple[DataFrameLike, DataFrameLike]:
          if isinstance(interactions, SparkDataFrame):
              return self._partial_split_timedelta_spark(interactions, timedelta)
          if isinstance(interactions, PandasDataFrame):
@@ -305,7 +305,7 @@ class LastNSplitter(Splitter):
  
      def _partial_split_timedelta_pandas(
          self, interactions: PandasDataFrame, timedelta: int
-     ) -> Tuple[PandasDataFrame, PandasDataFrame]:
+     ) -> tuple[PandasDataFrame, PandasDataFrame]:
          res = interactions.copy(deep=True)
          res["diff_timestamp"] = (
              res.groupby(self.divide_column)[self.timestamp_column].transform(max) - res[self.timestamp_column]
@@ -321,7 +321,7 @@ class LastNSplitter(Splitter):
  
      def _partial_split_timedelta_spark(
          self, interactions: SparkDataFrame, timedelta: int
-     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+     ) -> tuple[SparkDataFrame, SparkDataFrame]:
          inter_with_max_time = interactions.withColumn(
              "max_timestamp",
              sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
@@ -343,7 +343,7 @@ class LastNSplitter(Splitter):
  
      def _partial_split_timedelta_polars(
          self, interactions: PolarsDataFrame, timedelta: int
-     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
+     ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
          res = interactions.with_columns(
              (pl.col(self.timestamp_column).max().over(self.divide_column) - pl.col(self.timestamp_column)).alias(
                  "diff_timestamp"
@@ -358,7 +358,7 @@ class LastNSplitter(Splitter):
  
          return train, test
  
-     def _core_split(self, interactions: DataFrameLike) -> List[DataFrameLike]:
+     def _core_split(self, interactions: DataFrameLike) -> list[DataFrameLike]:
          if self.strategy == "timedelta":
              interactions = self._to_unix_timestamp(interactions)
          train, test = getattr(self, "_partial_split_" + self.strategy)(interactions, self.N)
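
LastNSplitter's pandas branch ranks each user's interactions by time and marks the last n rows as test. A toy sketch of the same computation outside the class (column names are placeholders):

    import pandas as pd

    interactions = pd.DataFrame(
        {
            "user_id": [1, 1, 1, 2, 2],
            "item_id": [10, 11, 12, 20, 21],
            "timestamp": [1, 2, 3, 5, 6],
        }
    )
    n = 1

    df = interactions.sort_values(["user_id", "timestamp"])
    df["row_num"] = df.groupby("user_id", sort=False).cumcount() + 1
    df["count"] = df.groupby("user_id", sort=False)["user_id"].transform("count")
    df["is_test"] = df["row_num"] > (df["count"] - n)

    train, test = df[~df["is_test"]], df[df["is_test"]]
    print(test[["user_id", "item_id"]])  # the last interaction of each user
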
replay/splitters/new_users_splitter.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Optional, Tuple
+ from typing import Optional
  
  import polars as pl
  
@@ -100,7 +100,7 @@ class NewUsersSplitter(Splitter):
  
      def _core_split_pandas(
          self, interactions: PandasDataFrame, threshold: float
-     ) -> Tuple[PandasDataFrame, PandasDataFrame]:
+     ) -> tuple[PandasDataFrame, PandasDataFrame]:
          start_date_by_user = (
              interactions.groupby(self.query_column).agg(_start_dt_by_user=(self.timestamp_column, "min")).reset_index()
          )
@@ -134,7 +134,7 @@ class NewUsersSplitter(Splitter):
  
      def _core_split_spark(
          self, interactions: SparkDataFrame, threshold: float
-     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
+     ) -> tuple[SparkDataFrame, SparkDataFrame]:
          start_date_by_user = interactions.groupby(self.query_column).agg(
              sf.min(self.timestamp_column).alias("_start_dt_by_user")
          )
@@ -171,7 +171,7 @@ class NewUsersSplitter(Splitter):
  
      def _core_split_polars(
          self, interactions: PolarsDataFrame, threshold: float
-     ) -> Tuple[PolarsDataFrame, PolarsDataFrame]:
+     ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
          start_date_by_user = interactions.group_by(self.query_column).agg(
              pl.col(self.timestamp_column).min().alias("_start_dt_by_user")
          )
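
NewUsersSplitter keys the split on each user's first interaction time, so users whose history starts late end up in the test split. A polars sketch of the aggregation shown above; the real splitter derives the time boundary from its threshold parameter, while a fixed boundary stands in here:

    import polars as pl

    interactions = pl.DataFrame(
        {"user_id": [1, 1, 2, 2, 3], "timestamp": [1, 4, 5, 6, 9]}
    )

    # first interaction time per user, as in _core_split_polars
    start_by_user = interactions.group_by("user_id").agg(
        pl.col("timestamp").min().alias("_start_dt_by_user")
    )

    boundary = 5  # placeholder; derived from `threshold` in the real splitter
    new_users = start_by_user.filter(pl.col("_start_dt_by_user") >= boundary)["user_id"]
    test = interactions.filter(pl.col("user_id").is_in(new_users))
    train = interactions.filter(~pl.col("user_id").is_in(new_users))
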