replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/nn/sequential/callbacks/validation_callback.py CHANGED
@@ -1,13 +1,12 @@
-from typing import Any, List, Optional, Protocol, Tuple, Literal
+from typing import Any, List, Literal, Optional, Protocol, Tuple
 
-import lightning as L
+import lightning
 import torch
 from lightning.pytorch.utilities.rank_zero import rank_zero_only
 
 from replay.metrics.torch_metrics_builder import TorchMetricsBuilder, metrics_to_df
 from replay.models.nn.sequential.postprocessors import BasePostProcessor
 
-
 CallbackMetricName = Literal[
     "recall",
     "precision",
@@ -19,17 +18,17 @@ CallbackMetricName = Literal[
 ]
 
 
-# pylint: disable=too-few-public-methods
 class ValidationBatch(Protocol):
     """
     Validation callback batch
     """
+
     query_id: torch.LongTensor
     ground_truth: torch.LongTensor
     train: torch.LongTensor
 
 
-class ValidationMetricsCallback(L.Callback):
+class ValidationMetricsCallback(lightning.Callback):
     """
     Callback for validation and testing stages.
 
@@ -37,7 +36,6 @@ class ValidationMetricsCallback(L.Callback):
     the suffix of the metric name will contain the serial number of the dataloader.
     """
 
-    # pylint: disable=invalid-name
     def __init__(
         self,
         metrics: Optional[List[CallbackMetricName]] = None,
@@ -63,8 +61,9 @@
             return [len(dataloaders)]
         return [len(dataloader) for dataloader in dataloaders]
 
-    # pylint: disable=unused-argument
-    def on_validation_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_validation_epoch_start(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
+    ) -> None:
         self._dataloaders_size = self._get_dataloaders_size(trainer.val_dataloaders)
         self._metrics_builders = [
             TorchMetricsBuilder(self._metrics, self._ks, self._item_count) for _ in self._dataloaders_size
@@ -72,8 +71,11 @@
         for builder in self._metrics_builders:
             builder.reset()
 
-    # pylint: disable=unused-argument
-    def on_test_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:  # pragma: no cover
+    def on_test_epoch_start(
+        self,
+        trainer: lightning.Trainer,
+        pl_module: lightning.LightningModule,  # noqa: ARG002
+    ) -> None:  # pragma: no cover
         self._dataloaders_size = self._get_dataloaders_size(trainer.test_dataloaders)
         self._metrics_builders = [
             TorchMetricsBuilder(self._metrics, self._ks, self._item_count) for _ in self._dataloaders_size
@@ -88,11 +90,10 @@
             query_ids, scores, ground_truth = postprocessor.on_validation(query_ids, scores, ground_truth)
         return query_ids, scores, ground_truth
 
-    # pylint: disable=too-many-arguments
     def on_validation_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,
+        pl_module: lightning.LightningModule,
         outputs: torch.Tensor,
         batch: ValidationBatch,
         batch_idx: int,
@@ -100,11 +101,10 @@
     ) -> None:
         self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
 
-    # pylint: disable=unused-argument, too-many-arguments
     def on_test_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,
+        pl_module: lightning.LightningModule,
         outputs: torch.Tensor,
         batch: ValidationBatch,
         batch_idx: int,
@@ -112,11 +112,10 @@
     ) -> None:  # pragma: no cover
         self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
 
-    # pylint: disable=too-many-arguments
     def _batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,  # noqa: ARG002
+        pl_module: lightning.LightningModule,
         outputs: torch.Tensor,
         batch: ValidationBatch,
         batch_idx: int,
@@ -131,31 +130,29 @@
             self._metrics_builders[dataloader_idx].get_metrics(),
             on_epoch=True,
             sync_dist=True,
-            add_dataloader_idx=True
+            add_dataloader_idx=True,
         )
 
-    # pylint: disable=unused-argument
-    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule) -> None:
         self._epoch_end(trainer, pl_module)
 
-    # pylint: disable=unused-argument
-    def on_test_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:  # pragma: no cover
+    def on_test_epoch_end(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule
+    ) -> None:  # pragma: no cover
        self._epoch_end(trainer, pl_module)
 
-    # pylint: disable=unused-argument
-    def _epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
-        # pylint: disable=W0212
+    def _epoch_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule) -> None:  # noqa: ARG002
        @rank_zero_only
        def print_metrics() -> None:
            metrics = {}
            for name, value in trainer.logged_metrics.items():
-                if '@' in name:
+                if "@" in name:
                    metrics[name] = value.item()
 
            if metrics:
                metrics_df = metrics_to_df(metrics)
 
-                print(metrics_df)
-                print()
+                print(metrics_df)  # noqa: T201
+                print()  # noqa: T201
 
        print_metrics()
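Note: the dominant changes in this file are mechanical: `import lightning as L` becomes a plain `import lightning`, pylint suppressions give way to ruff-style `# noqa` codes, and error messages are bound to a variable before being raised. Below is a minimal sketch of that raise convention as it recurs across the release; the assumption that ruff's flake8-errmsg rules motivate it is inferred from the noqa codes, not stated in the diff, and `check_feature` and its arguments are hypothetical names.

```python
# Hypothetical helper illustrating the msg-then-raise convention used in 0.17.0:
# the message is assigned to a variable so the raise statement stays short and
# the literal is not repeated inside the exception call.
def check_feature(schema: dict, name: str) -> None:
    if name not in schema:
        msg = f"Feature '{name}' not found in provided schema"
        raise ValueError(msg)


check_feature({"item_id": "cat"}, "item_id")  # passes silently
```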
replay/models/nn/sequential/postprocessors/postprocessors.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
 import torch
 
 from replay.data.nn import SequentialDataset
+
 from ._base import BasePostProcessor
 
 
@@ -85,7 +86,6 @@ class SampleItems(BasePostProcessor):
     Generates negative samples to compute sampled metrics
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         grouped_validation_items: pd.DataFrame,
replay/models/nn/sequential/sasrec/dataset.py CHANGED
@@ -30,7 +30,6 @@ class SasRecTrainingDataset(TorchDataset):
     Dataset that generates samples to train SasRec-like model
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -56,13 +55,16 @@
         super().__init__()
         if label_feature_name:
             if label_feature_name not in sequential.schema:
-                raise ValueError("Label feature name not found in provided schema")
+                msg = "Label feature name not found in provided schema"
+                raise ValueError(msg)
 
             if not sequential.schema[label_feature_name].is_cat:
-                raise ValueError("Label feature must be categorical")
+                msg = "Label feature must be categorical"
+                raise ValueError(msg)
 
             if not sequential.schema[label_feature_name].is_seq:
-                raise ValueError("Label feature must be sequential")
+                msg = "Label feature must be sequential"
+                raise ValueError(msg)
 
         self._sequence_shift = sequence_shift
         self._max_sequence_length = max_sequence_length + sequence_shift
@@ -83,8 +85,8 @@
         query_id, padding_mask, features = self._inner[index]
 
         assert self._label_feature_name
-        labels = features[self._label_feature_name][self._sequence_shift :]  # noqa: E203
-        labels_padding_mask = padding_mask[self._sequence_shift :]  # noqa: E203
+        labels = features[self._label_feature_name][self._sequence_shift :]
+        labels_padding_mask = padding_mask[self._sequence_shift :]
 
         output_features: MutableTensorMap = {}
         for feature_name in self._schema:
@@ -165,7 +167,6 @@ class SasRecValidationDataset(TorchDataset):
     Dataset that generates samples to infer and validate SasRec-like model
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
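The slicing kept unchanged in the hunk above (`[self._sequence_shift :]`) is what turns the stored sequence into labels aligned with next-step prediction. A toy, hypothetical illustration of that shift follows; the tensor values are invented and not taken from the package.

```python
import torch

# Toy illustration of the sequence_shift slicing in SasRecTrainingDataset:
# dropping the first `shift` positions yields labels that supervise each
# position with the item that comes `shift` steps later.
shift = 1
item_sequence = torch.tensor([10, 11, 12, 13])          # encoded item ids
padding_mask = torch.tensor([True, True, True, True])   # all positions are real

labels = item_sequence[shift:]              # tensor([11, 12, 13])
labels_padding_mask = padding_mask[shift:]  # tensor([True, True, True])
```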
replay/models/nn/sequential/sasrec/lightning.py CHANGED
@@ -1,17 +1,17 @@
 import math
-from typing import Any, Optional, Tuple, Union, cast, Dict
+from typing import Any, Dict, Optional, Tuple, Union, cast
 
-import lightning as L
+import lightning
 import torch
 
 from replay.data.nn import TensorMap, TensorSchema
 from replay.models.nn.optimizer_utils import FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
+
 from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidationBatch
 from .model import SasRecModel
 
 
-# pylint: disable=too-many-instance-attributes
-class SasRec(L.LightningModule):
+class SasRec(lightning.LightningModule):
     """
     SASRec Lightning module.
 
@@ -19,7 +19,6 @@ class SasRec(L.LightningModule):
     for object of SasRec instance.
     """
 
-    # pylint: disable=too-many-arguments, too-many-locals
     def __init__(
         self,
         tensor_schema: TensorSchema,
@@ -94,7 +93,6 @@ class SasRec(L.LightningModule):
         assert item_count
         self._vocab_size = item_count
 
-    # pylint: disable=unused-argument, arguments-differ
     def training_step(self, batch: SasRecTrainingBatch, batch_idx: int) -> torch.Tensor:
         """
         :param batch (SasRecTrainingBatch): Batch of training data.
@@ -108,7 +106,6 @@ class SasRec(L.LightningModule):
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
         return loss
 
-    # pylint: disable=arguments-differ
     def forward(self, feature_tensors: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor:  # pragma: no cover
         """
         :param feature_tensors: Batch of features.
@@ -118,8 +115,9 @@ class SasRec(L.LightningModule):
         """
         return self._model_predict(feature_tensors, padding_mask)
 
-    # pylint: disable=unused-argument
-    def predict_step(self, batch: SasRecPredictionBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def predict_step(
+        self, batch: SasRecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch: Batch of prediction data.
         :param batch_idx: Batch index.
@@ -130,8 +128,9 @@ class SasRec(L.LightningModule):
         batch = self._prepare_prediction_batch(batch)
         return self._model_predict(batch.features, batch.padding_mask)
 
-    # pylint: disable=unused-argument, arguments-differ
-    def validation_step(self, batch: SasRecValidationBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def validation_step(
+        self, batch: SasRecValidationBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch (SasRecValidationBatch): Batch of prediction data.
         :param batch_idx (int): Batch index.
@@ -155,57 +154,46 @@ class SasRec(L.LightningModule):
 
     def _prepare_prediction_batch(self, batch: SasRecPredictionBatch) -> SasRecPredictionBatch:
         if batch.padding_mask.shape[1] > self._model.max_len:
-            raise ValueError(
-                f"The length of the submitted sequence \
+            msg = f"The length of the submitted sequence \
                 must not exceed the maximum length of the sequence. \
                 The length of the sequence is given {batch.padding_mask.shape[1]}, \
-                while the maximum length is {self._model.max_len}")
+                while the maximum length is {self._model.max_len}"
+            raise ValueError(msg)
+
         if batch.padding_mask.shape[1] < self._model.max_len:
             query_id, padding_mask, features = batch
             sequence_item_count = padding_mask.shape[1]
             for feature_name, feature_tensor in features.items():
                 if self._schema[feature_name].is_cat:
                     features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor,
-                        (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
                     )
                 else:
                     features[feature_name] = torch.nn.functional.pad(
                         feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
                         (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        value=0,
                     ).unsqueeze(-1)
             padding_mask = torch.nn.functional.pad(
-                padding_mask,
-                (self._model.max_len - sequence_item_count, 0),
-                value=0
+                padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
             )
             batch = SasRecPredictionBatch(query_id, padding_mask, features)
         return batch
 
     def _model_predict(self, feature_tensors: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor:
         model: SasRecModel
-        if isinstance(self._model, torch.nn.DataParallel):
-            model = cast(SasRecModel, self._model.module)  # multigpu
-        else:
-            model = self._model
+        model = cast(SasRecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
         scores = model.predict(feature_tensors, padding_mask)
         return scores
 
     def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor:
         if self._loss_type == "BCE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_bce
-            else:
-                loss_func = self._compute_loss_bce_sampled
+            loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
         elif self._loss_type == "CE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_ce
-            else:
-                loss_func = self._compute_loss_ce_sampled
+            loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
         else:
-            raise ValueError(f"Not supported loss type: {self._loss_type}")
+            msg = f"Not supported loss type: {self._loss_type}"
+            raise ValueError(msg)
 
         loss = loss_func(
             batch.features,
@@ -225,8 +213,10 @@ class SasRec(L.LightningModule):
         # [B x L x V]
         logits = self._model.forward(feature_tensors, padding_mask)
 
-        # Take only logits which correspond to non-padded tokens
-        # M = non_zero_count(target_padding_mask)
+        """
+        Take only logits which correspond to non-padded tokens
+        M = non_zero_count(target_padding_mask)
+        """
         logits = logits[target_padding_mask]  # [M x V]
         labels = positive_labels[target_padding_mask]  # [M]
 
@@ -318,7 +308,6 @@ class SasRec(L.LightningModule):
         loss = self._loss(logits, labels_flat)
         return loss
 
-    # pylint: disable=too-many-locals
     def _get_sampled_logits(
         self,
         feature_tensors: TensorMap,
@@ -354,7 +343,8 @@ class SasRec(L.LightningModule):
             else:
                 multinomial_sample_distribution = torch.softmax(positive_logits, dim=-1)
         else:
-            raise NotImplementedError(f"Unknown negative sampling strategy: {self._negative_sampling_strategy}")
+            msg = f"Unknown negative sampling strategy: {self._negative_sampling_strategy}"
+            raise NotImplementedError(msg)
         n_negative_samples = min(n_negative_samples, vocab_size)
 
         if self._negatives_sharing:
@@ -405,7 +395,8 @@ class SasRec(L.LightningModule):
         if self._loss_type == "CE":
             return torch.nn.CrossEntropyLoss()
 
-        raise NotImplementedError("Not supported loss_type")
+        msg = "Not supported loss_type"
+        raise NotImplementedError(msg)
 
     def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
         """
@@ -415,17 +406,18 @@ class SasRec(L.LightningModule):
 
     def set_item_embeddings_by_size(self, new_vocab_size: int):
         """
-        Set item embeddings initialized with xavier_normal_ by new size of vocabulary
-        to item embedder.
+        Keep the current item embeddings and expand vocabulary with new embeddings
+        initialized with xavier_normal_ for new items.
 
-        :param new_vocab_size: Size of vocabulary with new items.
+        :param new_vocab_size: Size of vocabulary with new items included.
             Must be greater then already fitted.
         """
         old_vocab_size = self._model.item_embedder.item_emb.weight.data.shape[0] - 1
         hidden_size = self._model.hidden_size
 
         if new_vocab_size <= old_vocab_size:
-            raise ValueError("New vocabulary size must be greater then already fitted")
+            msg = "New vocabulary size must be greater then already fitted"
+            raise ValueError(msg)
 
         new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
         torch.nn.init.xavier_normal_(new_embedding.weight)
@@ -443,16 +435,19 @@ class SasRec(L.LightningModule):
             shape (n, h), where n - number of all items, h - model hidden size.
         """
         if all_item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of all items, model hidden size) shape")
+            msg = "Input tensor must have (number of all items, model hidden size) shape"
+            raise ValueError(msg)
 
         old_vocab_size = self._model.item_embedder.item_emb.weight.data.shape[0] - 1
         new_vocab_size = all_item_embeddings.shape[0]
         hidden_size = self._model.hidden_size
 
         if new_vocab_size < old_vocab_size:
-            raise ValueError("New vocabulary size can't be less then already fitted")
+            msg = "New vocabulary size can't be less then already fitted"
+            raise ValueError(msg)
         if all_item_embeddings.shape[1] != hidden_size:
-            raise ValueError("Input tensor second dimension doesn't match model hidden size")
+            msg = "Input tensor second dimension doesn't match model hidden size"
+            raise ValueError(msg)
 
         new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
         new_embedding.weight.data[:-1, :] = all_item_embeddings
@@ -467,14 +462,16 @@ class SasRec(L.LightningModule):
             n - number of only new items, h - model hidden size.
         """
         if item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of new items, model hidden size) shape")
+            msg = "Input tensor must have (number of new items, model hidden size) shape"
+            raise ValueError(msg)
 
         old_vocab_size = self._model.item_embedder.item_emb.weight.data.shape[0] - 1
         new_vocab_size = item_embeddings.shape[0] + old_vocab_size
         hidden_size = self._model.hidden_size
 
         if item_embeddings.shape[1] != hidden_size:
-            raise ValueError("Input tensor second dimension doesn't match model hidden size")
+            msg = "Input tensor second dimension doesn't match model hidden size"
+            raise ValueError(msg)
 
         new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
         new_embedding.weight.data[:old_vocab_size, :] = self._model.item_embedder.item_emb.weight.data[:-1, :]
@@ -489,3 +486,11 @@ class SasRec(L.LightningModule):
         self._model.item_count = new_vocab_size
         self._model.padding_idx = new_vocab_size
         self._model.masking.padding_idx = new_vocab_size
+        self._model.candidates_to_score = torch.tensor(
+            list(range(new_embedding.weight.data.shape[0] - 1)),
+            device=self._model.candidates_to_score.device,
+            dtype=torch.long,
+        )
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(
+            new_embedding.weight.data.shape[0] - 1
+        )
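The final hunk above adds bookkeeping when the item embedding table is replaced, refreshing `candidates_to_score` and the schema cardinality. A standalone sketch of the embedding-expansion pattern these methods share is given below; the sizes are illustrative and not taken from the package.

```python
import torch

# Sketch of the vocabulary-expansion pattern: build a larger embedding table
# with the padding row kept last, re-initialize it with xavier_normal_, then
# copy over the previously learned item rows (everything except the old
# padding row). The new rows keep their fresh initialization.
old_vocab_size, new_vocab_size, hidden_size = 100, 120, 64

old_embedding = torch.nn.Embedding(old_vocab_size + 1, hidden_size, padding_idx=old_vocab_size)
new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
torch.nn.init.xavier_normal_(new_embedding.weight)

# keep the old item rows; the 20 new rows retain their fresh initialization
new_embedding.weight.data[:old_vocab_size, :] = old_embedding.weight.data[:-1, :]
```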
replay/models/nn/sequential/sasrec/model.py CHANGED
@@ -1,18 +1,17 @@
 import abc
-from typing import Any, Optional, Tuple, Union, cast, Dict
+import contextlib
+from typing import Any, Dict, Optional, Tuple, Union, cast
 
 import torch
 
 from replay.data.nn import TensorMap, TensorSchema
 
 
-# pylint: disable=too-many-instance-attributes
 class SasRecModel(torch.nn.Module):
     """
     SasRec model
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -189,13 +188,10 @@
 
     def _init(self) -> None:
         for _, param in self.named_parameters():
-            try:
+            with contextlib.suppress(ValueError):
                 torch.nn.init.xavier_normal_(param.data)
-            except ValueError:
-                pass
 
 
-# pylint: disable=too-few-public-methods
 class SasRecMasks:
     """
     SasRec Masks
@@ -316,7 +312,6 @@ class SasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
     Link: https://arxiv.org/pdf/1808.09781.pdf
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -406,11 +401,7 @@ class SasRecLayers(torch.nn.Module):
         """
         super().__init__()
         self.attention_layers = self._layers_stacker(
-            num_blocks,
-            torch.nn.MultiheadAttention,
-            hidden_size,
-            num_heads,
-            dropout
+            num_blocks, torch.nn.MultiheadAttention, hidden_size, num_heads, dropout
         )
         self.attention_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
         self.forward_layers = self._layers_stacker(num_blocks, SasRecPointWiseFeedForward, hidden_size, dropout)
@@ -513,7 +504,6 @@ class SasRecPositionalEmbedding(torch.nn.Module):
     Positional embedding.
     """
 
-    # pylint: disable=invalid-name
     def __init__(self, max_len: int, d_model: int) -> None:
         """
         :param max_len: Max sequence length.
@@ -542,7 +532,6 @@ class TiSasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
     Link: https://cseweb.ucsd.edu/~jmcauley/pdfs/wsdm20b.pdf
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -678,7 +667,6 @@ class TiSasRecLayers(torch.nn.Module):
         self.attention_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
         self.forward_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
 
-    # pylint: disable=too-many-arguments
     def forward(
         self,
         seqs: torch.Tensor,
@@ -738,7 +726,6 @@ class TiSasRecAttention(torch.nn.Module):
         self.head_size = hidden_size // head_num
         self.dropout_rate = dropout_rate
 
-    # pylint: disable=too-many-arguments, invalid-name, too-many-locals
     def forward(
         self,
         queries: torch.LongTensor,
replay/models/pop_rec.py CHANGED
@@ -1,8 +1,8 @@
-
 from replay.data.dataset import Dataset
-from .base_rec import NonPersonalizedRecommender
 from replay.utils import PYSPARK_AVAILABLE
 
+from .base_rec import NonPersonalizedRecommender
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
@@ -23,7 +23,11 @@ class PopRec(NonPersonalizedRecommender):
     >>> import pandas as pd
     >>> from replay.data.dataset import Dataset, FeatureSchema, FeatureInfo, FeatureHint, FeatureType
     >>> from replay.utils.spark_utils import convert2spark
-    >>> data_frame = pd.DataFrame({"user_id": [1, 1, 2, 2, 3, 4], "item_id": [1, 2, 2, 3, 3, 3], "rating": [0.5, 1, 0.1, 0.8, 0.7, 1]})
+    >>> data_frame = pd.DataFrame(
+    ...     {"user_id": [1, 1, 2, 2, 3, 4],
+    ...     "item_id": [1, 2, 2, 3, 3, 3],
+    ...     "rating": [0.5, 1, 0.1, 0.8, 0.7, 1]}
+    ... )
     >>> data_frame
        user_id  item_id  rating
     0        1        1     0.5
@@ -104,9 +108,7 @@ class PopRec(NonPersonalizedRecommender):
         `Cold_weight` value should be in interval (0, 1].
         """
         self.use_rating = use_rating
-        super().__init__(
-            add_cold_items=add_cold_items, cold_weight=cold_weight
-        )
+        super().__init__(add_cold_items=add_cold_items, cold_weight=cold_weight)
 
     @property
     def _init_args(self):
@@ -120,7 +122,6 @@ class PopRec(NonPersonalizedRecommender):
         self,
         dataset: Dataset,
     ) -> None:
-
         agg_func = sf.countDistinct(self.query_column).alias(self.rating_column)
         if self.use_rating:
             agg_func = sf.sum(self.rating_column).alias(self.rating_column)
@@ -128,9 +129,7 @@ class PopRec(NonPersonalizedRecommender):
         self.item_popularity = (
             dataset.interactions.groupBy(self.item_column)
             .agg(agg_func)
-            .withColumn(
-                self.rating_column, sf.col(self.rating_column) / sf.lit(self.queries_count)
-            )
+            .withColumn(self.rating_column, sf.col(self.rating_column) / sf.lit(self.queries_count))
         )
 
         self.item_popularity.cache().count()
replay/models/query_pop_rec.py CHANGED
@@ -1,8 +1,8 @@
-
 from replay.data import Dataset
-from .base_rec import Recommender
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_rec import Recommender
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
@@ -76,7 +76,6 @@ class QueryPopRec(Recommender):
         self,
         dataset: Dataset,
     ) -> None:
-
         query_rating_sum = (
             dataset.interactions.groupBy(self.query_column)
             .agg(sf.sum(self.rating_column).alias("query_rel_sum"))
@@ -94,9 +93,7 @@ class QueryPopRec(Recommender):
             .select(
                 self.query_column,
                 self.item_column,
-                (sf.col("query_item_rel_sum") / sf.col("query_rel_sum")).alias(
-                    self.rating_column
-                ),
+                (sf.col("query_item_rel_sum") / sf.col("query_rel_sum")).alias(self.rating_column),
             )
         )
         self.query_item_popularity.cache().count()
@@ -105,20 +102,15 @@ class QueryPopRec(Recommender):
         if hasattr(self, "query_item_popularity"):
             self.query_item_popularity.unpersist()
 
-    # pylint: disable=too-many-arguments
     def _predict(
         self,
-        dataset: Dataset,
-        k: int,
+        dataset: Dataset,  # noqa: ARG002
+        k: int,  # noqa: ARG002
         queries: SparkDataFrame,
         items: SparkDataFrame,
         filter_seen_items: bool = True,
     ) -> SparkDataFrame:
         if filter_seen_items:
-            self.logger.warning(
-                "QueryPopRec can't predict new items, recommendations will not be filtered"
-            )
+            self.logger.warning("QueryPopRec can't predict new items, recommendations will not be filtered")
 
-        return self.query_item_popularity.join(queries, on=self.query_column).join(
-            items, on=self.item_column
-        )
+        return self.query_item_popularity.join(queries, on=self.query_column).join(items, on=self.item_column)