replay-rec 0.20.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +10 -9
  3. replay/data/dataset_utils/dataset_label_encoder.py +5 -4
  4. replay/data/nn/schema.py +9 -18
  5. replay/data/nn/sequence_tokenizer.py +26 -18
  6. replay/data/nn/sequential_dataset.py +22 -18
  7. replay/data/nn/torch_sequential_dataset.py +17 -16
  8. replay/data/nn/utils.py +2 -1
  9. replay/data/schema.py +3 -12
  10. replay/metrics/base_metric.py +11 -10
  11. replay/metrics/categorical_diversity.py +8 -8
  12. replay/metrics/coverage.py +4 -4
  13. replay/metrics/experiment.py +3 -3
  14. replay/metrics/hitrate.py +1 -3
  15. replay/metrics/map.py +1 -3
  16. replay/metrics/mrr.py +1 -3
  17. replay/metrics/ndcg.py +1 -2
  18. replay/metrics/novelty.py +3 -3
  19. replay/metrics/offline_metrics.py +16 -16
  20. replay/metrics/precision.py +1 -3
  21. replay/metrics/recall.py +1 -3
  22. replay/metrics/rocauc.py +1 -3
  23. replay/metrics/surprisal.py +4 -4
  24. replay/metrics/torch_metrics_builder.py +13 -12
  25. replay/metrics/unexpectedness.py +2 -2
  26. replay/models/als.py +2 -2
  27. replay/models/association_rules.py +4 -3
  28. replay/models/base_neighbour_rec.py +3 -2
  29. replay/models/base_rec.py +11 -10
  30. replay/models/cat_pop_rec.py +2 -1
  31. replay/models/extensions/ann/ann_mixin.py +2 -1
  32. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
  33. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
  34. replay/models/lin_ucb.py +57 -11
  35. replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
  36. replay/models/nn/sequential/bert4rec/dataset.py +5 -18
  37. replay/models/nn/sequential/bert4rec/lightning.py +3 -3
  38. replay/models/nn/sequential/bert4rec/model.py +2 -2
  39. replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
  40. replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
  41. replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
  42. replay/models/nn/sequential/postprocessors/_base.py +2 -3
  43. replay/models/nn/sequential/postprocessors/postprocessors.py +11 -11
  44. replay/models/nn/sequential/sasrec/dataset.py +3 -16
  45. replay/models/nn/sequential/sasrec/lightning.py +3 -3
  46. replay/models/nn/sequential/sasrec/model.py +8 -8
  47. replay/models/slim.py +2 -2
  48. replay/models/ucb.py +2 -2
  49. replay/models/word2vec.py +3 -3
  50. replay/preprocessing/discretizer.py +8 -7
  51. replay/preprocessing/filters.py +4 -4
  52. replay/preprocessing/history_based_fp.py +6 -6
  53. replay/preprocessing/label_encoder.py +8 -7
  54. replay/scenarios/fallback.py +4 -3
  55. replay/splitters/base_splitter.py +3 -3
  56. replay/splitters/cold_user_random_splitter.py +4 -4
  57. replay/splitters/k_folds.py +4 -4
  58. replay/splitters/last_n_splitter.py +10 -10
  59. replay/splitters/new_users_splitter.py +4 -4
  60. replay/splitters/random_splitter.py +4 -4
  61. replay/splitters/ratio_splitter.py +10 -10
  62. replay/splitters/time_splitter.py +6 -6
  63. replay/splitters/two_stage_splitter.py +4 -4
  64. replay/utils/__init__.py +1 -1
  65. replay/utils/common.py +1 -1
  66. replay/utils/session_handler.py +2 -2
  67. replay/utils/spark_utils.py +6 -5
  68. replay/utils/types.py +3 -1
  69. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/METADATA +7 -1
  70. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/RECORD +73 -74
  71. replay/utils/warnings.py +0 -26
  72. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/WHEEL +0 -0
  73. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/LICENSE +0 -0
  74. {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/NOTICE +0 -0
replay/models/base_rec.py CHANGED
@@ -13,8 +13,9 @@ Base abstract classes:
13
13
 
14
14
  import warnings
15
15
  from abc import ABC, abstractmethod
16
+ from collections.abc import Iterable
16
17
  from os.path import join
17
- from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
18
+ from typing import Any, Optional, Union
18
19
 
19
20
  import numpy as np
20
21
  import pandas as pd
@@ -55,14 +56,14 @@ class IsSavable(ABC):
55
56
 
56
57
  @property
57
58
  @abstractmethod
58
- def _init_args(self) -> Dict:
59
+ def _init_args(self) -> dict:
59
60
  """
60
61
  Dictionary of the model attributes passed during model initialization.
61
62
  Used for model saving and loading
62
63
  """
63
64
 
64
65
  @property
65
- def _dataframes(self) -> Dict:
66
+ def _dataframes(self) -> dict:
66
67
  """
67
68
  Dictionary of the model dataframes required for inference.
68
69
  Used for model saving and loading
@@ -508,7 +509,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
508
509
  or None if `file_path` is provided
509
510
  """
510
511
  if dataset is not None:
511
- interactions, query_features, item_features, pairs = [
512
+ interactions, query_features, item_features, pairs = (
512
513
  convert2spark(df)
513
514
  for df in [
514
515
  dataset.interactions,
@@ -516,7 +517,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
516
517
  dataset.item_features,
517
518
  pairs,
518
519
  ]
519
- ]
520
+ )
520
521
  if set(pairs.columns) != {self.item_column, self.query_column}:
521
522
  msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
522
523
  raise ValueError(msg)
@@ -590,7 +591,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
590
591
 
591
592
  def _get_features_wrap(
592
593
  self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
593
- ) -> Optional[Tuple[SparkDataFrame, int]]:
594
+ ) -> Optional[tuple[SparkDataFrame, int]]:
594
595
  if self.query_column not in ids.columns and self.item_column not in ids.columns:
595
596
  msg = f"{self.query_column} or {self.item_column} missing"
596
597
  raise ValueError(msg)
@@ -599,7 +600,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
599
600
 
600
601
  def _get_features(
601
602
  self, ids: SparkDataFrame, features: Optional[SparkDataFrame] # noqa: ARG002
602
- ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
603
+ ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
603
604
  """
604
605
  Get embeddings from model
605
606
 
@@ -679,7 +680,7 @@ class ItemVectorModel(BaseRecommender):
679
680
  """Parent for models generating items' vector representations"""
680
681
 
681
682
  can_predict_item_to_item: bool = True
682
- item_to_item_metrics: List[str] = [
683
+ item_to_item_metrics: list[str] = [
683
684
  "euclidean_distance_sim",
684
685
  "cosine_similarity",
685
686
  "dot_product",
@@ -899,7 +900,7 @@ class HybridRecommender(BaseRecommender, ABC):
899
900
 
900
901
  def get_features(
901
902
  self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
902
- ) -> Optional[Tuple[SparkDataFrame, int]]:
903
+ ) -> Optional[tuple[SparkDataFrame, int]]:
903
904
  """
904
905
  Returns query or item feature vectors as a Column with type ArrayType
905
906
  If a model does not have a vector for some ids they are not present in the final result.
@@ -1026,7 +1027,7 @@ class Recommender(BaseRecommender, ABC):
1026
1027
  recs_file_path=recs_file_path,
1027
1028
  )
1028
1029
 
1029
- def get_features(self, ids: SparkDataFrame) -> Optional[Tuple[SparkDataFrame, int]]:
1030
+ def get_features(self, ids: SparkDataFrame) -> Optional[tuple[SparkDataFrame, int]]:
1030
1031
  """
1031
1032
  Returns query or item feature vectors as a Column with type ArrayType
1032
1033
 
@@ -1,5 +1,6 @@
1
+ from collections.abc import Iterable
1
2
  from os.path import join
2
- from typing import Iterable, Optional, Union
3
+ from typing import Optional, Union
3
4
 
4
5
  from replay.data import Dataset
5
6
  from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -2,7 +2,8 @@ import importlib
2
2
  import logging
3
3
  import sys
4
4
  from abc import abstractmethod
5
- from typing import Any, Iterable, Optional, Union
5
+ from collections.abc import Iterable
6
+ from typing import Any, Optional, Union
6
7
 
7
8
  from replay.data import Dataset
8
9
  from replay.models.common import RecommenderCommons
@@ -1,5 +1,6 @@
1
1
  import logging
2
- from typing import Iterator, Optional
2
+ from collections.abc import Iterator
3
+ from typing import Optional
3
4
 
4
5
  import numpy as np
5
6
 
@@ -1,5 +1,6 @@
1
1
  import logging
2
- from typing import Iterator, Optional
2
+ from collections.abc import Iterator
3
+ from typing import Optional
3
4
 
4
5
  import pandas as pd
5
6
 
replay/models/lin_ucb.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import warnings
2
- from typing import List, Tuple, Union
2
+ from os.path import join
3
+ from typing import Optional, Union
3
4
 
4
5
  import numpy as np
5
6
  import pandas as pd
@@ -8,7 +9,11 @@ from tqdm import tqdm
8
9
 
9
10
  from replay.data.dataset import Dataset
10
11
  from replay.utils import SparkDataFrame
11
- from replay.utils.spark_utils import convert2spark
12
+ from replay.utils.spark_utils import (
13
+ convert2spark,
14
+ load_pickled_from_parquet,
15
+ save_picklable_to_parquet,
16
+ )
12
17
 
13
18
  from .base_rec import HybridRecommender
14
19
 
@@ -70,7 +75,7 @@ class HybridArm:
70
75
  # right-hand side of the regression
71
76
  self.b = np.zeros(d, dtype=float)
72
77
 
73
- def feature_update(self, usr_features, usr_itm_features, relevances) -> Tuple[np.ndarray, np.ndarray]:
78
+ def feature_update(self, usr_features, usr_itm_features, relevances) -> tuple[np.ndarray, np.ndarray]:
74
79
  """
75
80
  Function to update features for each Lin-UCB hand in the current model.
76
81
 
@@ -175,8 +180,9 @@ class LinUCB(HybridRecommender):
175
180
  "alpha": {"type": "uniform", "args": [0.001, 10.0]},
176
181
  }
177
182
  _study = None # field required for proper optuna's optimization
178
- linucb_arms: List[Union[DisjointArm, HybridArm]] # initialize only when working within fit method
183
+ linucb_arms: list[Union[DisjointArm, HybridArm]] # initialize only when working within fit method
179
184
  rel_matrix: np.array # matrix with relevance scores from predict method
185
+ _num_items: int # number of items/arms
180
186
 
181
187
  def __init__(
182
188
  self,
@@ -195,7 +201,7 @@ class LinUCB(HybridRecommender):
195
201
 
196
202
  @property
197
203
  def _init_args(self):
198
- return {"is_hybrid": self.is_hybrid}
204
+ return {"is_hybrid": self.is_hybrid, "eps": self.eps, "alpha": self.alpha}
199
205
 
200
206
  def _verify_features(self, dataset: Dataset):
201
207
  if dataset.query_features is None:
@@ -230,6 +236,7 @@ class LinUCB(HybridRecommender):
230
236
  self._num_items = item_features.shape[0]
231
237
  self._user_dim_size = user_features.shape[1] - 1
232
238
  self._item_dim_size = item_features.shape[1] - 1
239
+ self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
233
240
 
234
241
  # now initialize an arm object for each potential arm instance
235
242
  if self.is_hybrid:
@@ -248,11 +255,14 @@ class LinUCB(HybridRecommender):
248
255
  ]
249
256
 
250
257
  for i in tqdm(range(self._num_items)):
251
- B = log.loc[log[feature_schema.item_id_column] == i] # noqa: N806
252
- idxs_list = B[feature_schema.query_id_column].values
253
- rel_list = B[feature_schema.interactions_rating_column].values
258
+ B = log.loc[ # noqa: N806
259
+ (log[feature_schema.item_id_column] == i)
260
+ & (log[feature_schema.query_id_column].isin(self._user_idxs_list))
261
+ ]
254
262
  if not B.empty:
255
263
  # if we have at least one user interacting with the hand i
264
+ idxs_list = B[feature_schema.query_id_column].values
265
+ rel_list = B[feature_schema.interactions_rating_column].values
256
266
  cur_usrs = scs.csr_matrix(
257
267
  user_features.query(f"{feature_schema.query_id_column} in @idxs_list")
258
268
  .drop(columns=[feature_schema.query_id_column])
@@ -284,11 +294,14 @@ class LinUCB(HybridRecommender):
284
294
  ]
285
295
 
286
296
  for i in range(self._num_items):
287
- B = log.loc[log[feature_schema.item_id_column] == i] # noqa: N806
288
- idxs_list = B[feature_schema.query_id_column].values # noqa: F841
289
- rel_list = B[feature_schema.interactions_rating_column].values
297
+ B = log.loc[ # noqa: N806
298
+ (log[feature_schema.item_id_column] == i)
299
+ & (log[feature_schema.query_id_column].isin(self._user_idxs_list))
300
+ ]
290
301
  if not B.empty:
291
302
  # if we have at least one user interacting with the hand i
303
+ idxs_list = B[feature_schema.query_id_column].values # noqa: F841
304
+ rel_list = B[feature_schema.interactions_rating_column].values
292
305
  cur_usrs = user_features.query(f"{feature_schema.query_id_column} in @idxs_list").drop(
293
306
  columns=[feature_schema.query_id_column]
294
307
  )
@@ -318,8 +331,10 @@ class LinUCB(HybridRecommender):
318
331
  user_features = dataset.query_features
319
332
  item_features = dataset.item_features
320
333
  big_k = min(oversample * k, item_features.shape[0])
334
+ self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
321
335
 
322
336
  users = users.toPandas()
337
+ users = users[users[feature_schema.query_id_column].isin(self._user_idxs_list)]
323
338
  num_user_pred = users.shape[0]
324
339
  rel_matrix = np.zeros((num_user_pred, self._num_items), dtype=float)
325
340
 
@@ -404,3 +419,34 @@ class LinUCB(HybridRecommender):
404
419
  warnings.warn(warn_msg)
405
420
  dataset.to_spark()
406
421
  return convert2spark(res_df)
422
+
423
+ def _save_model(self, path: str, additional_params: Optional[dict] = None):
424
+ super()._save_model(path, additional_params)
425
+
426
+ save_picklable_to_parquet(self.linucb_arms, join(path, "linucb_arms.dump"))
427
+
428
+ if self.is_hybrid:
429
+ linucb_hybrid_shared_params = {
430
+ "A_0": self.A_0,
431
+ "A_0_inv": self.A_0_inv,
432
+ "b_0": self.b_0,
433
+ "beta": self.beta,
434
+ }
435
+ save_picklable_to_parquet(
436
+ linucb_hybrid_shared_params,
437
+ join(path, "linucb_hybrid_shared_params.dump"),
438
+ )
439
+
440
+ def _load_model(self, path: str):
441
+ super()._load_model(path)
442
+
443
+ loaded_linucb_arms = load_pickled_from_parquet(join(path, "linucb_arms.dump"))
444
+ self.linucb_arms = loaded_linucb_arms
445
+ self._num_items = len(loaded_linucb_arms)
446
+
447
+ if self.is_hybrid:
448
+ loaded_linucb_hybrid_shared_params = load_pickled_from_parquet(
449
+ join(path, "linucb_hybrid_shared_params.dump")
450
+ )
451
+ for param, value in loaded_linucb_hybrid_shared_params.items():
452
+ setattr(self, param, value)
@@ -1,5 +1,5 @@
1
1
  import abc
2
- from typing import Iterator, Tuple
2
+ from collections.abc import Iterator
3
3
 
4
4
  import torch
5
5
 
@@ -47,7 +47,7 @@ class FatOptimizerFactory(OptimizerFactory):
47
47
  learning_rate: float = 0.001,
48
48
  weight_decay: float = 0.0,
49
49
  sgd_momentum: float = 0.0,
50
- betas: Tuple[float, float] = (0.9, 0.98),
50
+ betas: tuple[float, float] = (0.9, 0.98),
51
51
  ) -> None:
52
52
  super().__init__()
53
53
  self.optimizer = optimizer
@@ -1,5 +1,5 @@
1
1
  import abc
2
- from typing import NamedTuple, Optional, Tuple, cast
2
+ from typing import NamedTuple, Optional, cast
3
3
 
4
4
  import torch
5
5
  from torch.utils.data import Dataset as TorchDataset
@@ -12,7 +12,6 @@ from replay.data.nn import (
12
12
  TorchSequentialDataset,
13
13
  TorchSequentialValidationDataset,
14
14
  )
15
- from replay.utils import deprecation_warning
16
15
 
17
16
 
18
17
  class Bert4RecTrainingBatch(NamedTuple):
@@ -89,10 +88,6 @@ class Bert4RecTrainingDataset(TorchDataset):
89
88
  Dataset that generates samples to train BERT-like model
90
89
  """
91
90
 
92
- @deprecation_warning(
93
- "`padding_value` parameter will be removed in future versions. "
94
- "Instead, you should specify `padding_value` for each column in TensorSchema"
95
- )
96
91
  def __init__(
97
92
  self,
98
93
  sequential: SequentialDataset,
@@ -101,7 +96,7 @@ class Bert4RecTrainingDataset(TorchDataset):
101
96
  sliding_window_step: Optional[int] = None,
102
97
  label_feature_name: Optional[str] = None,
103
98
  custom_masker: Optional[Bert4RecMasker] = None,
104
- padding_value: int = 0,
99
+ padding_value: Optional[int] = None,
105
100
  ) -> None:
106
101
  """
107
102
  :param sequential: Sequential dataset with training data.
@@ -181,15 +176,11 @@ class Bert4RecPredictionDataset(TorchDataset):
181
176
  Dataset that generates samples to infer BERT-like model
182
177
  """
183
178
 
184
- @deprecation_warning(
185
- "`padding_value` parameter will be removed in future versions. "
186
- "Instead, you should specify `padding_value` for each column in TensorSchema"
187
- )
188
179
  def __init__(
189
180
  self,
190
181
  sequential: SequentialDataset,
191
182
  max_sequence_length: int,
192
- padding_value: int = 0,
183
+ padding_value: Optional[int] = None,
193
184
  ) -> None:
194
185
  """
195
186
  :param sequential: Sequential dataset with data to make predictions at.
@@ -239,17 +230,13 @@ class Bert4RecValidationDataset(TorchDataset):
239
230
  Dataset that generates samples to infer and validate BERT-like model
240
231
  """
241
232
 
242
- @deprecation_warning(
243
- "`padding_value` parameter will be removed in future versions. "
244
- "Instead, you should specify `padding_value` for each column in TensorSchema"
245
- )
246
233
  def __init__(
247
234
  self,
248
235
  sequential: SequentialDataset,
249
236
  ground_truth: SequentialDataset,
250
237
  train: SequentialDataset,
251
238
  max_sequence_length: int,
252
- padding_value: int = 0,
239
+ padding_value: Optional[int] = None,
253
240
  label_feature_name: Optional[str] = None,
254
241
  ):
255
242
  """
@@ -295,7 +282,7 @@ def _shift_features(
295
282
  schema: TensorSchema,
296
283
  features: TensorMap,
297
284
  padding_mask: torch.BoolTensor,
298
- ) -> Tuple[TensorMap, torch.BoolTensor, torch.BoolTensor]:
285
+ ) -> tuple[TensorMap, torch.BoolTensor, torch.BoolTensor]:
299
286
  shifted_features: MutableTensorMap = {}
300
287
  for feature_name, feature in schema.items():
301
288
  if feature.is_seq:
@@ -1,5 +1,5 @@
1
1
  import math
2
- from typing import Any, Dict, Literal, Optional, Tuple, Union, cast
2
+ from typing import Any, Literal, Optional, Union, cast
3
3
 
4
4
  import lightning
5
5
  import torch
@@ -338,7 +338,7 @@ class Bert4Rec(lightning.LightningModule):
338
338
  positive_labels: torch.LongTensor,
339
339
  padding_mask: torch.BoolTensor,
340
340
  tokens_mask: torch.BoolTensor,
341
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
341
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
342
342
  assert self._loss_sample_count is not None
343
343
  n_negative_samples = self._loss_sample_count
344
344
 
@@ -440,7 +440,7 @@ class Bert4Rec(lightning.LightningModule):
440
440
  msg = "Not supported loss_type"
441
441
  raise NotImplementedError(msg)
442
442
 
443
- def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
443
+ def get_all_embeddings(self) -> dict[str, torch.nn.Embedding]:
444
444
  """
445
445
  :returns: copy of all embeddings as a dictionary.
446
446
  """
@@ -1,6 +1,6 @@
1
1
  import contextlib
2
2
  from abc import ABC, abstractmethod
3
- from typing import Dict, Optional, Union
3
+ from typing import Optional, Union
4
4
 
5
5
  import torch
6
6
  import torch.nn as nn
@@ -303,7 +303,7 @@ class BertEmbedding(torch.nn.Module):
303
303
  """
304
304
  return self.cat_embeddings[self.schema.item_id_feature_name].weight
305
305
 
306
- def get_all_embeddings(self) -> Dict[str, torch.Tensor]:
306
+ def get_all_embeddings(self) -> dict[str, torch.Tensor]:
307
307
  """
308
308
  :returns: copy of all embeddings presented in this layer as a dict.
309
309
  """
@@ -1,5 +1,5 @@
1
1
  import abc
2
- from typing import Generic, List, Optional, Protocol, Tuple, TypeVar, cast
2
+ from typing import Generic, Optional, Protocol, TypeVar, cast
3
3
 
4
4
  import lightning
5
5
  import torch
@@ -38,7 +38,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
38
38
  query_column: str,
39
39
  item_column: str,
40
40
  rating_column: str = "rating",
41
- postprocessors: Optional[List[BasePostProcessor]] = None,
41
+ postprocessors: Optional[list[BasePostProcessor]] = None,
42
42
  ) -> None:
43
43
  """
44
44
  :param top_k: Takes the highest k scores in the ranking.
@@ -52,10 +52,10 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
52
52
  self.item_column = item_column
53
53
  self.rating_column = rating_column
54
54
  self._top_k = top_k
55
- self._postprocessors: List[BasePostProcessor] = postprocessors or []
56
- self._query_batches: List[torch.Tensor] = []
57
- self._item_batches: List[torch.Tensor] = []
58
- self._item_scores: List[torch.Tensor] = []
55
+ self._postprocessors: list[BasePostProcessor] = postprocessors or []
56
+ self._query_batches: list[torch.Tensor] = []
57
+ self._item_batches: list[torch.Tensor] = []
58
+ self._item_scores: list[torch.Tensor] = []
59
59
 
60
60
  def on_predict_epoch_start(
61
61
  self, trainer: lightning.Trainer, pl_module: lightning.LightningModule # noqa: ARG002
@@ -97,7 +97,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
97
97
 
98
98
  def _compute_pipeline(
99
99
  self, query_ids: torch.LongTensor, scores: torch.Tensor
100
- ) -> Tuple[torch.LongTensor, torch.Tensor]:
100
+ ) -> tuple[torch.LongTensor, torch.Tensor]:
101
101
  for postprocessor in self._postprocessors:
102
102
  query_ids, scores = postprocessor.on_prediction(query_ids, scores)
103
103
  return query_ids, scores
@@ -166,7 +166,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
166
166
  item_column: str,
167
167
  rating_column: str,
168
168
  spark_session: SparkSession,
169
- postprocessors: Optional[List[BasePostProcessor]] = None,
169
+ postprocessors: Optional[list[BasePostProcessor]] = None,
170
170
  ) -> None:
171
171
  """
172
172
  :param top_k: Takes the highest k scores in the ranking.
@@ -213,7 +213,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
213
213
  return prediction
214
214
 
215
215
 
216
- class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
216
+ class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
217
217
  """
218
218
  Callback for prediction stage with tuple of tensors
219
219
  """
@@ -221,7 +221,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
221
221
  def __init__(
222
222
  self,
223
223
  top_k: int,
224
- postprocessors: Optional[List[BasePostProcessor]] = None,
224
+ postprocessors: Optional[list[BasePostProcessor]] = None,
225
225
  ) -> None:
226
226
  """
227
227
  :param top_k: Takes the highest k scores in the ranking.
@@ -240,7 +240,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
240
240
  query_ids: torch.Tensor,
241
241
  item_ids: torch.Tensor,
242
242
  item_scores: torch.Tensor,
243
- ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
243
+ ) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
244
244
  return (
245
245
  cast(torch.LongTensor, query_ids.flatten().cpu().long()),
246
246
  cast(torch.LongTensor, item_ids.cpu().long()),
@@ -254,7 +254,7 @@ class QueryEmbeddingsPredictionCallback(lightning.Callback):
254
254
  """
255
255
 
256
256
  def __init__(self):
257
- self._embeddings_per_batch: List[torch.Tensor] = []
257
+ self._embeddings_per_batch: list[torch.Tensor] = []
258
258
 
259
259
  def on_predict_epoch_start(
260
260
  self, trainer: lightning.Trainer, pl_module: lightning.LightningModule # noqa: ARG002
@@ -1,4 +1,4 @@
1
- from typing import Any, List, Literal, Optional, Protocol, Tuple
1
+ from typing import Any, Literal, Optional, Protocol
2
2
 
3
3
  import lightning
4
4
  import torch
@@ -38,9 +38,9 @@ class ValidationMetricsCallback(lightning.Callback):
38
38
 
39
39
  def __init__(
40
40
  self,
41
- metrics: Optional[List[CallbackMetricName]] = None,
42
- ks: Optional[List[int]] = None,
43
- postprocessors: Optional[List[BasePostProcessor]] = None,
41
+ metrics: Optional[list[CallbackMetricName]] = None,
42
+ ks: Optional[list[int]] = None,
43
+ postprocessors: Optional[list[BasePostProcessor]] = None,
44
44
  item_count: Optional[int] = None,
45
45
  ):
46
46
  """
@@ -52,11 +52,11 @@ class ValidationMetricsCallback(lightning.Callback):
52
52
  self._metrics = metrics
53
53
  self._ks = ks
54
54
  self._item_count = item_count
55
- self._metrics_builders: List[TorchMetricsBuilder] = []
56
- self._dataloaders_size: List[int] = []
57
- self._postprocessors: List[BasePostProcessor] = postprocessors or []
55
+ self._metrics_builders: list[TorchMetricsBuilder] = []
56
+ self._dataloaders_size: list[int] = []
57
+ self._postprocessors: list[BasePostProcessor] = postprocessors or []
58
58
 
59
- def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> List[int]:
59
+ def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
60
60
  if isinstance(dataloaders, torch.utils.data.DataLoader):
61
61
  return [len(dataloaders)]
62
62
  return [len(dataloader) for dataloader in dataloaders]
@@ -85,7 +85,7 @@ class ValidationMetricsCallback(lightning.Callback):
85
85
 
86
86
  def _compute_pipeline(
87
87
  self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
88
- ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
88
+ ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
89
89
  for postprocessor in self._postprocessors:
90
90
  query_ids, scores, ground_truth = postprocessor.on_validation(query_ids, scores, ground_truth)
91
91
  return query_ids, scores, ground_truth
@@ -1,7 +1,7 @@
1
1
  import pathlib
2
2
  import tempfile
3
3
  from abc import abstractmethod
4
- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
4
+ from typing import Any, Literal, Optional, Union
5
5
 
6
6
  import lightning
7
7
  import openvino as ov
@@ -68,7 +68,7 @@ class BaseCompiledModel:
68
68
  """
69
69
  self._batch_size: int
70
70
  self._max_seq_len: int
71
- self._inputs_names: List[str]
71
+ self._inputs_names: list[str]
72
72
  self._output_name: str
73
73
 
74
74
  self._set_inner_params_from_openvino_model(compiled_model)
@@ -171,9 +171,9 @@ class BaseCompiledModel:
171
171
  @staticmethod
172
172
  def _run_model_compilation(
173
173
  lightning_model: lightning.LightningModule,
174
- model_input_sample: Tuple[Union[torch.Tensor, Dict[str, torch.Tensor]]],
175
- model_input_names: List[str],
176
- model_dynamic_axes_in_input: Dict[str, Dict],
174
+ model_input_sample: tuple[Union[torch.Tensor, dict[str, torch.Tensor]]],
175
+ model_input_names: list[str],
176
+ model_dynamic_axes_in_input: dict[str, dict],
177
177
  batch_size: int,
178
178
  num_candidates_to_score: Union[int, None],
179
179
  num_threads: Optional[int] = None,
@@ -1,5 +1,4 @@
1
1
  import abc
2
- from typing import Tuple
3
2
 
4
3
  import torch
5
4
 
@@ -10,7 +9,7 @@ class BasePostProcessor(abc.ABC): # pragma: no cover
10
9
  """
11
10
 
12
11
  @abc.abstractmethod
13
- def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
12
+ def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
14
13
  """
15
14
  Prediction step.
16
15
 
@@ -24,7 +23,7 @@ class BasePostProcessor(abc.ABC): # pragma: no cover
24
23
  @abc.abstractmethod
25
24
  def on_validation(
26
25
  self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
27
- ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
26
+ ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
28
27
  """
29
28
  Validation step.
30
29
 
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Set, Tuple, Union, cast
1
+ from typing import Optional, Union, cast
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -22,7 +22,7 @@ class RemoveSeenItems(BasePostProcessor):
22
22
 
23
23
  def on_validation(
24
24
  self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
25
- ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
25
+ ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
26
26
  """
27
27
  Validation step.
28
28
 
@@ -36,7 +36,7 @@ class RemoveSeenItems(BasePostProcessor):
36
36
  modified_scores = self._compute_scores(query_ids, scores)
37
37
  return query_ids, modified_scores, ground_truth
38
38
 
39
- def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
39
+ def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
40
40
  """
41
41
  Prediction step.
42
42
 
@@ -51,7 +51,7 @@ class RemoveSeenItems(BasePostProcessor):
51
51
 
52
52
  def _compute_scores(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> torch.Tensor:
53
53
  flat_seen_item_ids = self._get_flat_seen_item_ids(query_ids)
54
- return self._fill_item_ids(scores, flat_seen_item_ids, -np.inf)
54
+ return self._fill_item_ids(scores.clone(), flat_seen_item_ids, -np.inf)
55
55
 
56
56
  def _fill_item_ids(
57
57
  self,
@@ -124,13 +124,13 @@ class SampleItems(BasePostProcessor):
124
124
  self.sample_count = sample_count
125
125
  users = grouped_validation_items[user_col].to_numpy()
126
126
  items = grouped_validation_items[item_col].to_numpy()
127
- self.items_list: List[Set[int]] = [set() for _ in range(users.shape[0])]
127
+ self.items_list: list[set[int]] = [set() for _ in range(users.shape[0])]
128
128
  for i in range(users.shape[0]):
129
129
  self.items_list[users[i]] = set(items[i])
130
130
 
131
131
  def on_validation(
132
132
  self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
133
- ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
133
+ ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
134
134
  """
135
135
  Validation step.
136
136
 
@@ -143,7 +143,7 @@ class SampleItems(BasePostProcessor):
143
143
  modified_score = self._compute_score(query_ids, scores, ground_truth)
144
144
  return query_ids, modified_score, ground_truth
145
145
 
146
- def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
146
+ def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
147
147
  """
148
148
  Prediction step.
149
149
 
@@ -160,8 +160,8 @@ class SampleItems(BasePostProcessor):
160
160
  ) -> torch.Tensor:
161
161
  batch_size = query_ids.shape[0]
162
162
  item_ids = ground_truth.cpu().numpy() if ground_truth is not None else None
163
- candidate_ids: List[torch.Tensor] = []
164
- candidate_labels: List[torch.Tensor] = []
163
+ candidate_ids: list[torch.Tensor] = []
164
+ candidate_labels: list[torch.Tensor] = []
165
165
  for user in range(batch_size):
166
166
  ground_truth_items = set(item_ids[user]) if ground_truth is not None else set()
167
167
  sample, label = self._generate_samples_for_user(ground_truth_items, self.items_list[user])
@@ -183,8 +183,8 @@ class SampleItems(BasePostProcessor):
183
183
  return new_scores.reshape_as(scores)
184
184
 
185
185
  def _generate_samples_for_user(
186
- self, ground_truth_items: Set[int], input_items: Set[int]
187
- ) -> Tuple[torch.Tensor, torch.Tensor]:
186
+ self, ground_truth_items: set[int], input_items: set[int]
187
+ ) -> tuple[torch.Tensor, torch.Tensor]:
188
188
  negative_sample_count = self.sample_count - len(ground_truth_items)
189
189
  assert negative_sample_count > 0
190
190