replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/nn/sequential/bert4rec/__init__.py
@@ -6,9 +6,9 @@ if TORCH_AVAILABLE:
         Bert4RecPredictionDataset,
         Bert4RecTrainingBatch,
         Bert4RecTrainingDataset,
+        Bert4RecUniformMasker,
         Bert4RecValidationBatch,
         Bert4RecValidationDataset,
-        Bert4RecUniformMasker,
     )
     from .lightning import Bert4Rec
     from .model import Bert4RecModel
replay/models/nn/sequential/bert4rec/dataset.py
@@ -27,7 +27,6 @@ class Bert4RecTrainingBatch(NamedTuple):
     labels: torch.LongTensor


-# pylint: disable=too-few-public-methods
 class Bert4RecMasker(abc.ABC):
     """
     Interface for a token masking strategy during BERT model training
@@ -44,7 +43,6 @@ class Bert4RecMasker(abc.ABC):
     """


-# pylint: disable=too-few-public-methods
 class Bert4RecUniformMasker(Bert4RecMasker):
     """
     Token masking strategy that mask random token with uniform distribution.
@@ -90,7 +88,6 @@ class Bert4RecTrainingDataset(TorchDataset):
     Dataset that generates samples to train BERT-like model
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -121,13 +118,16 @@ class Bert4RecTrainingDataset(TorchDataset):
         super().__init__()
         if label_feature_name:
             if label_feature_name not in sequential.schema:
-                raise ValueError("Label feature name not found in provided schema")
+                msg = "Label feature name not found in provided schema"
+                raise ValueError(msg)

             if not sequential.schema[label_feature_name].is_cat:
-                raise ValueError("Label feature must be categorical")
+                msg = "Label feature must be categorical"
+                raise ValueError(msg)

             if not sequential.schema[label_feature_name].is_seq:
-                raise ValueError("Label feature must be sequential")
+                msg = "Label feature must be sequential"
+                raise ValueError(msg)

         self._max_sequence_length = max_sequence_length
         self._label_feature_name = label_feature_name or sequential.schema.item_id_feature_name
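
The rewrite in this hunk — binding the message to a variable before raising — recurs throughout the release and matches ruff's EM rules (a literal or f-string placed directly inside an exception call is flagged as EM101/EM102); together with the wholesale removal of pylint pragmas and the `# noqa: ARG002` (unused method argument) markers seen below, this reads as a pylint-to-ruff migration. A minimal sketch of the pattern, using a hypothetical check_label helper rather than package code:

    # Before: the message lives inline in the raise expression
    def check_label_old(schema: dict, name: str) -> None:
        if name not in schema:
            raise ValueError(f"Label feature {name!r} not found in provided schema")

    # After: the message is assigned first, then raised (ruff EM102-friendly)
    def check_label_new(schema: dict, name: str) -> None:
        if name not in schema:
            msg = f"Label feature {name!r} not found in provided schema"
            raise ValueError(msg)
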
@@ -230,7 +230,6 @@ class Bert4RecValidationDataset(TorchDataset):
     Dataset that generates samples to infer and validate BERT-like model
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
replay/models/nn/sequential/bert4rec/lightning.py
@@ -1,27 +1,21 @@
 import math
-from typing import Any, Optional, Tuple, Union, cast, Dict
+from typing import Any, Dict, Optional, Tuple, Union, cast

-import lightning as L
+import lightning
 import torch

 from replay.data.nn import TensorMap, TensorSchema
 from replay.models.nn.optimizer_utils import FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
-from .dataset import (
-    Bert4RecPredictionBatch,
-    Bert4RecTrainingBatch,
-    Bert4RecValidationBatch,
-    _shift_features
-)
+
+from .dataset import Bert4RecPredictionBatch, Bert4RecTrainingBatch, Bert4RecValidationBatch, _shift_features
 from .model import Bert4RecModel, CatFeatureEmbedding


-# pylint: disable=too-many-instance-attributes
-class Bert4Rec(L.LightningModule):
+class Bert4Rec(lightning.LightningModule):
     """
     Implements BERT training-validation loop
     """

-    # pylint: disable=too-many-arguments, too-many-locals
     def __init__(
         self,
         tensor_schema: TensorSchema,
@@ -102,8 +96,7 @@ class Bert4Rec(L.LightningModule):
         assert item_count
         self._vocab_size = item_count

-    # pylint: disable=unused-argument, arguments-differ
-    def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor:
+    def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
         """
         :param batch: Batch of training data.
         :param batch_idx: Batch index.
@@ -129,8 +122,9 @@ class Bert4Rec(L.LightningModule):
         """
         return self._model_predict(feature_tensors, padding_mask, tokens_mask)

-    # pylint: disable=unused-argument
-    def predict_step(self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def predict_step(
+        self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch (Bert4RecPredictionBatch): Batch of prediction data.
         :param batch_idx (int): Batch index.
@@ -141,8 +135,9 @@ class Bert4Rec(L.LightningModule):
         batch = self._prepare_prediction_batch(batch)
         return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)

-    # pylint: disable=unused-argument
-    def validation_step(self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def validation_step(
+        self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch: Batch of prediction data.
         :param batch_idx: Batch index.
@@ -166,31 +161,28 @@ class Bert4Rec(L.LightningModule):

     def _prepare_prediction_batch(self, batch: Bert4RecPredictionBatch) -> Bert4RecPredictionBatch:
         if batch.padding_mask.shape[1] > self._model.max_len:
-            raise ValueError(
-                f"The length of the submitted sequence \
+            msg = f"The length of the submitted sequence \
                 must not exceed the maximum length of the sequence. \
                 The length of the sequence is given {batch.padding_mask.shape[1]}, \
-                while the maximum length is {self._model.max_len}")
+                while the maximum length is {self._model.max_len}"
+            raise ValueError(msg)
+
         if batch.padding_mask.shape[1] < self._model.max_len:
             query_id, padding_mask, features, _ = batch
             sequence_item_count = padding_mask.shape[1]
             for feature_name, feature_tensor in features.items():
                 if self._schema[feature_name].is_cat:
                     features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor,
-                        (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
                     )
                 else:
                     features[feature_name] = torch.nn.functional.pad(
                         feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
                         (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        value=0,
                     ).unsqueeze(-1)
             padding_mask = torch.nn.functional.pad(
-                padding_mask,
-                (self._model.max_len - sequence_item_count, 0),
-                value=0
+                padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
             )
             shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
             batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask)
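
For context on the padding calls above: torch.nn.functional.pad takes a (left, right) pad tuple for the last dimension, so (max_len - sequence_item_count, 0) left-pads a short sequence up to the model's maximum length. A standalone sketch with illustrative shapes, not taken from the package:

    import torch

    max_len = 8
    seq = torch.tensor([[3, 1, 4, 1, 5]])  # batch of one sequence of length 5
    # Pad 3 zeros on the left of the last dimension, none on the right
    padded = torch.nn.functional.pad(seq, (max_len - seq.shape[1], 0), value=0)
    print(padded)  # tensor([[0, 0, 0, 3, 1, 4, 1, 5]])
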
@@ -213,17 +205,12 @@ class Bert4Rec(L.LightningModule):

     def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
         if self._loss_type == "BCE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_bce
-            else:
-                loss_func = self._compute_loss_bce_sampled
+            loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
         elif self._loss_type == "CE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_ce
-            else:
-                loss_func = self._compute_loss_ce_sampled
+            loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
         else:
-            raise ValueError(f"Not supported loss type: {self._loss_type}")
+            msg = f"Not supported loss type: {self._loss_type}"
+            raise ValueError(msg)

         loss = loss_func(
             batch.features,
@@ -246,8 +233,10 @@ class Bert4Rec(L.LightningModule):

         labels_mask = (~padding_mask) + tokens_mask
         masked_tokens = ~labels_mask
-        # Take only logits which correspond to non-padded tokens
-        # M = non_zero_count(target_padding_mask)
+        """
+        Take only logits which correspond to non-padded tokens
+        M = non_zero_count(target_padding_mask)
+        """
         logits = logits[masked_tokens]  # [M x V]
         labels = positive_labels[masked_tokens]  # [M]

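Note that this hunk turns two comments into a bare triple-quoted string. A string expression that is not the first statement of a function body is evaluated and discarded at runtime, so it behaves as a block comment rather than a docstring; a tiny illustration (hypothetical function, not package code):

    def f() -> int:
        x = 1
        """
        This literal is a no-op statement, not a docstring,
        because it is not the first statement in the body.
        """
        return x
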
@@ -374,7 +363,8 @@ class Bert4Rec(L.LightningModule):
             else:
                 multinomial_sample_distribution = torch.softmax(positive_logits, dim=-1)
         else:
-            raise NotImplementedError(f"Unknown negative sampling strategy: {self._negative_sampling_strategy}")
+            msg = f"Unknown negative sampling strategy: {self._negative_sampling_strategy}"
+            raise NotImplementedError(msg)
         n_negative_samples = min(n_negative_samples, vocab_size)

         if self._negatives_sharing:
@@ -426,7 +416,8 @@ class Bert4Rec(L.LightningModule):
         if self._loss_type == "CE":
             return torch.nn.CrossEntropyLoss()

-        raise NotImplementedError("Not supported loss_type")
+        msg = "Not supported loss_type"
+        raise NotImplementedError(msg)

     def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
         """
@@ -436,21 +427,22 @@ class Bert4Rec(L.LightningModule):

     def set_item_embeddings_by_size(self, new_vocab_size: int):
         """
-        Set item embeddings initialized with xavier_normal_ by new size of vocabulary
-        to item embedder.
+        Keep the current item embeddings and expand vocabulary with new embeddings
+        initialized with xavier_normal_ for new items.

-        :param new_vocab_size: Size of vocabulary with new items.
+        :param new_vocab_size: Size of vocabulary with new items included.
             Must be greater then already fitted.
         """
         if new_vocab_size <= self._vocab_size:
-            raise ValueError("New vocabulary size must be greater then already fitted")
+            msg = "New vocabulary size must be greater then already fitted"
+            raise ValueError(msg)

         item_tensor_feature_info = self._model.schema.item_id_features.item()
         item_tensor_feature_info._set_cardinality(new_vocab_size)

         weights_new = CatFeatureEmbedding(item_tensor_feature_info)
         torch.nn.init.xavier_normal_(weights_new.weight)
-        weights_new.weight.data[:self._vocab_size, :] = self._model.item_embedder.item_embeddings.data
+        weights_new.weight.data[: self._vocab_size, :] = self._model.item_embedder.item_embeddings.data

         self._set_new_item_embedder_to_model(weights_new, new_vocab_size)

@@ -464,15 +456,18 @@ class Bert4Rec(L.LightningModule):
         shape (n, h), where n - number of all items, h - model hidden size.
         """
         if all_item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of all items, model hidden size) shape")
+            msg = "Input tensor must have (number of all items, model hidden size) shape"
+            raise ValueError(msg)

         new_vocab_size = all_item_embeddings.shape[0]
         if new_vocab_size < self._vocab_size:
-            raise ValueError("New vocabulary size can't be less then already fitted")
+            msg = "New vocabulary size can't be less then already fitted"
+            raise ValueError(msg)

         item_tensor_feature_info = self._model.schema.item_id_features.item()
         if all_item_embeddings.shape[1] != item_tensor_feature_info.embedding_dim:
-            raise ValueError("Input tensor second dimension doesn't match embedding dim")
+            msg = "Input tensor second dimension doesn't match embedding dim"
+            raise ValueError(msg)

         item_tensor_feature_info._set_cardinality(new_vocab_size)

@@ -490,37 +485,39 @@ class Bert4Rec(L.LightningModule):
         n - number of only new items, h - model hidden size.
         """
         if item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of all items, model hidden size) shape")
+            msg = "Input tensor must have (number of all items, model hidden size) shape"
+            raise ValueError(msg)

         new_vocab_size = item_embeddings.shape[0] + self._vocab_size

         item_tensor_feature_info = self._model.schema.item_id_features.item()
         if item_embeddings.shape[1] != item_tensor_feature_info.embedding_dim:
-            raise ValueError("Input tensor second dimension doesn't match embedding dim")
+            msg = "Input tensor second dimension doesn't match embedding dim"
+            raise ValueError(msg)

         item_tensor_feature_info._set_cardinality(new_vocab_size)

         weights_new = CatFeatureEmbedding(item_tensor_feature_info)
         torch.nn.init.xavier_normal_(weights_new.weight)
-        weights_new.weight.data[:self._vocab_size, :] = self._model.item_embedder.item_embeddings.data
-        weights_new.weight.data[self._vocab_size:, :] = item_embeddings.data
+        weights_new.weight.data[: self._vocab_size, :] = self._model.item_embedder.item_embeddings.data
+        weights_new.weight.data[self._vocab_size :, :] = item_embeddings.data

         self._set_new_item_embedder_to_model(weights_new, new_vocab_size)

     def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_embedder.cat_embeddings[self._model.schema.item_id_feature_name] = weights_new
-
         if self._model.enable_embedding_tying is True:
             self._model._head._item_embedder = self._model.item_embedder
             new_bias = torch.Tensor(new_vocab_size)
             new_bias.normal_(0, 0.01)
-            new_bias[:self._vocab_size] = self._model._head.out_bias.data
+            new_bias[: self._vocab_size] = self._model._head.out_bias.data
             self._model._head.out_bias = torch.nn.Parameter(new_bias)
         else:
             new_linear = torch.nn.Linear(self._model.hidden_size, new_vocab_size)
-            new_linear.weight.data[:self._vocab_size, :] = self._model._head.linear.weight.data
-            new_linear.bias.data[:self._vocab_size] = self._model._head.linear.bias.data
+            new_linear.weight.data[: self._vocab_size, :] = self._model._head.linear.weight.data
+            new_linear.bias.data[: self._vocab_size] = self._model._head.linear.bias.data
             self._model._head.linear = new_linear

         self._vocab_size = new_vocab_size
         self._model.item_count = new_vocab_size
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
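
The weight-copy pattern above grows the vocabulary without losing fitted parameters: a larger embedding is created and fully xavier-initialized, then the first `_vocab_size` rows are overwritten with the old table, so only the new tail keeps its random initialization. A self-contained sketch of the same idea with illustrative sizes:

    import torch

    old = torch.nn.Embedding(num_embeddings=100, embedding_dim=64)  # fitted table
    new = torch.nn.Embedding(num_embeddings=120, embedding_dim=64)  # 20 new items

    torch.nn.init.xavier_normal_(new.weight)    # initialize every row...
    new.weight.data[:100, :] = old.weight.data  # ...then restore the fitted rows

    assert torch.equal(new.weight.data[:100], old.weight.data)
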
replay/models/nn/sequential/bert4rec/model.py
@@ -1,18 +1,18 @@
+import contextlib
 import math
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Union, cast, Dict
+from typing import Dict, Optional, Tuple, Union, cast

 import torch

 from replay.data.nn import TensorFeatureInfo, TensorMap, TensorSchema


-# pylint: disable=too-many-instance-attributes
 class Bert4RecModel(torch.nn.Module):
     """
     BERT model
     """
-    # pylint: disable=too-many-arguments
+
     def __init__(
         self,
         schema: TensorSchema,
@@ -137,12 +137,7 @@ class Bert4RecModel(torch.nn.Module):
         """
         return self._head(out_embeddings, item_ids)

-    def get_query_embeddings(
-        self,
-        inputs: TensorMap,
-        pad_mask: torch.BoolTensor,
-        token_mask: torch.BoolTensor
-    ):
+    def get_query_embeddings(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor):
         """
         :param inputs: Batch of features.
         :param pad_mask: Padding mask where 0 - <PAD>, 1 otherwise.
@@ -159,13 +154,10 @@ class Bert4RecModel(torch.nn.Module):

     def _init(self) -> None:
         for _, param in self.named_parameters():
-            try:
+            with contextlib.suppress(ValueError):
                 torch.nn.init.xavier_normal_(param.data)
-            except ValueError:
-                pass


-# pylint: disable=too-many-instance-attributes
 class BertEmbedding(torch.nn.Module):
     """
     BERT Embedding which is consisted with under features
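
`contextlib.suppress(ValueError)` here is behavior-preserving shorthand for the removed try/except/pass: xavier_normal_ raises ValueError for tensors with fewer than two dimensions (e.g. bias vectors), and the init loop simply skips them. A standalone sketch:

    import contextlib

    import torch

    params = [torch.empty(16, 8), torch.empty(16)]  # a weight matrix and a 1-D bias
    for param in params:
        # xavier_normal_ raises ValueError on tensors with fewer than 2 dims;
        # suppress() skips those exactly like the old "except ValueError: pass"
        with contextlib.suppress(ValueError):
            torch.nn.init.xavier_normal_(param)
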
@@ -174,7 +166,6 @@ class BertEmbedding(torch.nn.Module):
     sum of all these features are output of BertEmbedding
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -206,19 +197,18 @@ class BertEmbedding(torch.nn.Module):

         for feature_name, tensor_info in schema.items():
             if not tensor_info.is_seq:
-                raise NotImplementedError("Non-sequential features is not yet supported")
+                msg = "Non-sequential features is not yet supported"
+                raise NotImplementedError(msg)

-            if tensor_info.is_cat:
-                dim = tensor_info.embedding_dim
-            else:
-                dim = tensor_info.tensor_dim
+            dim = tensor_info.embedding_dim if tensor_info.is_cat else tensor_info.tensor_dim

             if aggregation_method == "sum":
                 if common_dim is None:
                     common_dim = dim

                 if dim != common_dim:
-                    raise ValueError("Dimension of all features must be the same for sum aggregation")
+                    msg = "Dimension of all features must be the same for sum aggregation"
+                    raise ValueError(msg)
             else:
                 raise NotImplementedError()

@@ -242,7 +232,7 @@ class BertEmbedding(torch.nn.Module):
         :returns: Embeddings for input features.
         """
         if self.aggregation_method == "sum":
-            aggregated_embedding: torch.Tensor = None  # type: ignore
+            aggregated_embedding: torch.Tensor = None

             for feature_name in self.schema.categorical_features:
                 x = inputs[feature_name]
@@ -307,7 +297,7 @@ class BertEmbedding(torch.nn.Module):
         embeddings = {
             "item_embedding": self.item_embeddings.data.detach().clone(),
         }
-        for feature_name, _ in self.schema.items():
+        for feature_name in self.schema:
             if feature_name != self.schema.item_id_feature_name:
                 embeddings[feature_name] = self.cat_embeddings[feature_name].weight.data.detach().clone()
         if self.enable_positional_embedding:
@@ -335,7 +325,6 @@ class PositionalEmbedding(torch.nn.Module):
     Positional embedding.
     """

-    # pylint: disable=invalid-name
     def __init__(self, max_len: int, d_model: int) -> None:
         """
         :param max_len: Max sequence length.
@@ -477,7 +466,6 @@ class TransformerBlock(torch.nn.Module):

         self.dropout = torch.nn.Dropout(p=dropout)

-    # pylint: disable=invalid-name
     def forward(
         self,
         x: torch.Tensor,
@@ -537,7 +525,6 @@ class MultiHeadedAttention(torch.nn.Module):
     Take in model size and number of heads.
     """

-    # pylint: disable=invalid-name
    def __init__(self, h: int, d_model: int, dropout: float = 0.1) -> None:
        """
        :param h: Head sizes of multi-head attention.
replay/models/nn/sequential/callbacks/__init__.py
@@ -2,8 +2,8 @@ from .prediction_callbacks import (
     BasePredictionCallback,
     PandasPredictionCallback,
     PolarsPredictionCallback,
+    QueryEmbeddingsPredictionCallback,
     SparkPredictionCallback,
     TorchPredictionCallback,
-    QueryEmbeddingsPredictionCallback
 )
 from .validation_callback import ValidationMetricsCallback
replay/models/nn/sequential/callbacks/prediction_callbacks.py
@@ -1,39 +1,37 @@
 import abc
 from typing import Generic, List, Optional, Protocol, Tuple, TypeVar, cast

-import lightning as L
+import lightning
 import torch

 from replay.models.nn.sequential import Bert4Rec
 from replay.models.nn.sequential.postprocessors import BasePostProcessor
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame, MissingImportType
+from replay.utils import PYSPARK_AVAILABLE, MissingImportType, PandasDataFrame, PolarsDataFrame, SparkDataFrame

 if PYSPARK_AVAILABLE:  # pragma: no cover
+    import pyspark.sql.functions as sf
     from pyspark.sql import SparkSession
-    import pyspark.sql.functions as F
     from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
 else:
     SparkSession = MissingImportType


-# pylint: disable=too-few-public-methods
 class PredictionBatch(Protocol):
     """
     Prediction callback batch
     """
+
     query_id: torch.LongTensor


 _T = TypeVar("_T")


-# pylint: disable=too-many-instance-attributes
-class BasePredictionCallback(L.Callback, Generic[_T]):
+class BasePredictionCallback(lightning.Callback, Generic[_T]):
     """
     Base callback for prediction stage
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         top_k: int,
@@ -59,21 +57,21 @@ class BasePredictionCallback(L.Callback, Generic[_T]):
         self._item_batches: List[torch.Tensor] = []
         self._item_scores: List[torch.Tensor] = []

-    # pylint: disable=unused-argument
-    def on_predict_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_predict_epoch_start(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
+    ) -> None:
         self._query_batches.clear()
         self._item_batches.clear()
         self._item_scores.clear()

-    # pylint: disable=unused-argument, too-many-arguments
     def on_predict_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,  # noqa: ARG002
+        pl_module: lightning.LightningModule,  # noqa: ARG002
         outputs: torch.Tensor,
         batch: PredictionBatch,
-        batch_idx: int,
-        dataloader_idx: int = 0,
+        batch_idx: int,  # noqa: ARG002
+        dataloader_idx: int = 0,  # noqa: ARG002
     ) -> None:
         query_ids, scores = self._compute_pipeline(batch.query_id, outputs)
         top_scores, top_item_ids = torch.topk(scores, k=self._top_k, dim=1)
@@ -157,7 +155,6 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
     Callback for prediction stage with spark data frame
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         top_k: int,
@@ -206,7 +203,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
                 ),
                 schema=schema,
             )
-            .withColumn("exploded_columns", F.explode(F.arrays_zip(self.item_column, self.rating_column)))
+            .withColumn("exploded_columns", sf.explode(sf.arrays_zip(self.item_column, self.rating_column)))
             .select(self.query_column, f"exploded_columns.{self.item_column}", f"exploded_columns.{self.rating_column}")
         )
         return prediction
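
Only the import alias changes in this hunk (F becomes sf); the expression itself zips the per-user item and rating arrays element-wise and explodes each pair into its own row. A small usage sketch, assuming a local SparkSession and illustrative column names:

    import pyspark.sql.functions as sf
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, [10, 20], [0.9, 0.7])], ["user_id", "item_id", "rating"])

    exploded = (
        df.withColumn("exploded_columns", sf.explode(sf.arrays_zip("item_id", "rating")))
        .select("user_id", "exploded_columns.item_id", "exploded_columns.rating")
    )
    exploded.show()  # one row per (item_id, rating) pair
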
@@ -247,26 +244,27 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
         )


-class QueryEmbeddingsPredictionCallback(L.Callback):
+class QueryEmbeddingsPredictionCallback(lightning.Callback):
     """
     Callback for prediction stage to get query embeddings.
     """
+
     def __init__(self):
         self._embeddings_per_batch: List[torch.Tensor] = []

-    # pylint: disable=unused-argument
-    def on_predict_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_predict_epoch_start(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
+    ) -> None:
         self._embeddings_per_batch.clear()

-    # pylint: disable=unused-argument, too-many-arguments
     def on_predict_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
-        outputs: torch.Tensor,
+        trainer: lightning.Trainer,  # noqa: ARG002
+        pl_module: lightning.LightningModule,
+        outputs: torch.Tensor,  # noqa: ARG002
         batch: PredictionBatch,
-        batch_idx: int,
-        dataloader_idx: int = 0,
+        batch_idx: int,  # noqa: ARG002
+        dataloader_idx: int = 0,  # noqa: ARG002
     ) -> None:
         args = [batch.features, batch.padding_mask]
         if isinstance(pl_module, Bert4Rec):