replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0

replay/models/nn/sequential/bert4rec/dataset.py +70 -29

@@ -17,7 +17,7 @@ from replay.data.nn import (
 class Bert4RecTrainingBatch(NamedTuple):
     """
     Batch of data for training.
-    Generated by `Bert4RecTrainingDataset`.
+    Generated by ``Bert4RecTrainingDataset``.
     """
 
     query_id: torch.LongTensor
@@ -26,6 +26,15 @@ class Bert4RecTrainingBatch(NamedTuple):
     tokens_mask: torch.BoolTensor
     labels: torch.LongTensor
 
+    def convert_to_dict(self) -> dict:
+        return {
+            "query_id": self.query_id,
+            "pad_mask": self.padding_mask,
+            "inputs": self.features,
+            "token_mask": self.tokens_mask,
+            "positive_labels": self.labels,
+        }
+
 
 class Bert4RecMasker(abc.ABC):
     """
@@ -85,7 +94,12 @@ class Bert4RecUniformMasker(Bert4RecMasker):
 
 class Bert4RecTrainingDataset(TorchDataset):
     """
-    Dataset that generates samples to train BERT-like model
+    Dataset that generates samples to train a Bert4Rec model.
+
+    As a result of the dataset iteration, a dictionary is formed.
+    The keys in the dictionary match the names of the arguments in the model's `forward` function.
+    There are also additional keys needed to calculate losses - `positive_labels`.
+    The `query_id` key is required for possible debugging and calling additional lightning callbacks.
     """
 
     def __init__(
@@ -143,26 +157,26 @@ class Bert4RecTrainingDataset(TorchDataset):
     def __len__(self) -> int:
         return len(self._inner)
 
-    def __getitem__(self, index: int) -> Bert4RecTrainingBatch:
+    def __getitem__(self, index: int) -> dict:
         query_id, padding_mask, features = self._inner[index]
         tokens_mask = self._masker.mask(padding_mask)
 
         assert self._label_feature_name
         labels = features[self._label_feature_name]
 
-        return Bert4RecTrainingBatch(
-            query_id=query_id,
-            padding_mask=padding_mask,
-            features=features,
-            tokens_mask=tokens_mask,
-            labels=cast(torch.LongTensor, labels),
-        )
+        return {
+            "query_id": query_id,
+            "pad_mask": padding_mask,
+            "inputs": features,
+            "token_mask": tokens_mask,
+            "positive_labels": labels,
+        }
 
 
 class Bert4RecPredictionBatch(NamedTuple):
     """
     Batch of data for model inference.
-    Generated by `Bert4RecPredictionDataset`.
+    Generated by ``Bert4RecPredictionDataset``.
     """
 
     query_id: torch.LongTensor
@@ -170,10 +184,22 @@ class Bert4RecPredictionBatch(NamedTuple):
     features: TensorMap
     tokens_mask: torch.BoolTensor
 
+    def convert_to_dict(self) -> dict:
+        return {
+            "query_id": self.query_id,
+            "pad_mask": self.padding_mask,
+            "inputs": self.features,
+            "token_mask": self.tokens_mask,
+        }
+
 
 class Bert4RecPredictionDataset(TorchDataset):
     """
-    Dataset that generates samples to infer BERT-like model
+    Dataset that generates samples for Bert4Rec model inference.
+
+    As a result of the dataset iteration, a dictionary is formed.
+    The keys in the dictionary match the names of the arguments in the model's `forward` function.
+    The `query_id` key is required for possible debugging and calling additional lightning callbacks.
     """
 
     def __init__(
@@ -198,23 +224,23 @@ class Bert4RecPredictionDataset(TorchDataset):
     def __len__(self) -> int:
         return len(self._inner)
 
-    def __getitem__(self, index: int) -> Bert4RecPredictionBatch:
+    def __getitem__(self, index: int) -> dict:
         query_id, padding_mask, features = self._inner[index]
 
         shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
 
-        return Bert4RecPredictionBatch(
-            query_id=query_id,
-            padding_mask=shifted_padding_mask,
-            features=shifted_features,
-            tokens_mask=tokens_mask,
-        )
+        return {
+            "query_id": query_id,
+            "pad_mask": shifted_padding_mask,
+            "inputs": shifted_features,
+            "token_mask": tokens_mask,
+        }
 
 
 class Bert4RecValidationBatch(NamedTuple):
     """
     Batch of data for validation.
-    Generated by `Bert4RecValidationDataset`.
+    Generated by ``Bert4RecValidationDataset``.
     """
 
     query_id: torch.LongTensor
@@ -224,10 +250,25 @@ class Bert4RecValidationBatch(NamedTuple):
     ground_truth: torch.LongTensor
     train: torch.LongTensor
 
+    def convert_to_dict(self) -> dict:
+        return {
+            "query_id": self.query_id,
+            "pad_mask": self.padding_mask,
+            "inputs": self.features,
+            "token_mask": self.tokens_mask,
+            "ground_truth": self.ground_truth,
+            "train": self.train,
+        }
+
 
 class Bert4RecValidationDataset(TorchDataset):
     """
     Dataset that generates samples to infer and validate BERT-like model
+
+    As a result of the dataset iteration, a dictionary is formed.
+    The keys in the dictionary match the names of the arguments in the model's `forward` function.
+    The `query_id` key is required for possible debugging and calling additional lightning callbacks.
+    The `ground_truth` and `train` keys are required for metrics calculation on the validation stage.
     """
 
     def __init__(
@@ -263,19 +304,19 @@ class Bert4RecValidationDataset(TorchDataset):
     def __len__(self) -> int:
         return len(self._inner)
 
-    def __getitem__(self, index: int) -> Bert4RecValidationBatch:
+    def __getitem__(self, index: int) -> dict:
         query_id, padding_mask, features, ground_truth, train = self._inner[index]
 
         shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
 
-        return Bert4RecValidationBatch(
-            query_id=query_id,
-            padding_mask=shifted_padding_mask,
-            features=shifted_features,
-            tokens_mask=tokens_mask,
-            ground_truth=ground_truth,
-            train=train,
-        )
+        return {
+            "query_id": query_id,
+            "pad_mask": shifted_padding_mask,
+            "inputs": shifted_features,
+            "token_mask": tokens_mask,
+            "ground_truth": ground_truth,
+            "train": train,
+        }
 
 
 def _shift_features(
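
For illustration, a dictionary with the layout that `Bert4RecTrainingDataset.__getitem__` now returns (per the hunks above) can be built by hand. The snippet below is not part of the package; the tensor values and the `item_id` feature name are invented for the example.

```python
import torch

seq_len = 5
item_ids = torch.randint(0, 100, (seq_len,))

# Sample in the new dictionary format; keys mirror Bert4RecModel.forward arguments,
# plus "positive_labels" for the loss and "query_id" for callbacks/debugging.
sample = {
    "query_id": torch.tensor(0),
    "pad_mask": torch.ones(seq_len, dtype=torch.bool),    # 0 - <PAD>, 1 - real token
    "inputs": {"item_id": item_ids},                       # TensorMap of features
    "token_mask": torch.ones(seq_len, dtype=torch.bool),   # 0 - masked position, 1 - otherwise
    "positive_labels": item_ids,                           # labels consumed by the loss
}

# A legacy NamedTuple batch exposes the same layout via the new convert_to_dict() helper.
assert set(sample) == {"query_id", "pad_mask", "inputs", "token_mask", "positive_labels"}
```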
replay/models/nn/sequential/bert4rec/lightning.py +97 -36

@@ -1,4 +1,5 @@
 import math
+import warnings
 from typing import Any, Literal, Optional, Union, cast
 
 import lightning
@@ -29,13 +30,13 @@ class Bert4Rec(lightning.LightningModule):
         enable_embedding_tying: bool = False,
         loss_type: Literal["BCE", "CE", "CE_restricted"] = "CE",
         loss_sample_count: Optional[int] = None,
-        negative_sampling_strategy: str = "global_uniform",
+        negative_sampling_strategy: Literal["global_uniform", "inbatch"] = "global_uniform",
         negatives_sharing: bool = False,
         optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
         lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
     ):
         """
-        :param tensor_schema (TensorSchema): Tensor schema of features.
+        :param tensor_schema: Tensor schema of features.
         :param block_count: Number of Transformer blocks.
             Default: ``2``.
         :param head_count: Number of Attention heads.
@@ -44,7 +45,7 @@ class Bert4Rec(lightning.LightningModule):
             Default: ``256``.
         :param max_seq_len: Max length of sequence.
             Default: ``100``.
-        :param dropout_rate (float): Dropout rate.
+        :param dropout_rate: Dropout rate.
             Default: ``0.1``.
         :param pass_per_transformer_block_count: Number of times to pass data over each Transformer block.
             Default: ``1``.
@@ -54,19 +55,18 @@ class Bert4Rec(lightning.LightningModule):
             If `True` - result scores are calculated by dot product of input and output embeddings,
             if `False` - default linear layer is applied to calculate logits for each item.
             Default: ``False``.
-        :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``, ``"CE_restricted"``.
+        :param loss_type: Loss type.
            Default: ``CE``.
-        :param loss_sample_count (Optional[int]): Sample count to calculate loss.
+        :param loss_sample_count: Sample count to calculate loss.
            Default: ``None``.
         :param negative_sampling_strategy: Negative sampling strategy to calculate loss on sampled negatives.
-            Is used when large count of items in dataset.
-            Possible values: ``"global_uniform"``, ``"inbatch"``
+            Is used when large count of items in dataset.\n
            Default: ``global_uniform``.
-        :param negatives_sharing: Apply negative sharing in calculating sampled logits.
+        :param negatives_sharing: Apply negative sharing in calculating sampled logits.\n
            Default: ``False``.
-        :param optimizer_factory: Optimizer factory.
+        :param optimizer_factory: Optimizer factory.\n
            Default: ``FatOptimizerFactory``.
-        :param lr_scheduler_factory: Learning rate schedule factory.
+        :param lr_scheduler_factory: Learning rate schedule factory.\n
            Default: ``None``.
         """
         super().__init__()
@@ -97,7 +97,7 @@ class Bert4Rec(lightning.LightningModule):
         self._vocab_size = item_count
         self.candidates_to_score = None
 
-    def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
+    def training_step(self, batch: Union[Bert4RecTrainingBatch, dict], batch_idx: int) -> torch.Tensor:  # noqa: ARG002
         """
         :param batch: Batch of training data.
         :param batch_idx: Batch index.
@@ -109,7 +109,7 @@ class Bert4Rec(lightning.LightningModule):
         return loss
 
     def predict_step(
-        self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+        self, batch: Union[Bert4RecPredictionBatch, dict], batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
     ) -> torch.Tensor:
         """
         :param batch (Bert4RecPredictionBatch): Batch of prediction data.
@@ -118,23 +118,49 @@ class Bert4Rec(lightning.LightningModule):
 
         :returns: Calculated scores on prediction batch.
         """
+        if isinstance(batch, Bert4RecPredictionBatch):
+            warnings.warn(
+                "`Bert4RecPredictionBatch` class will be removed in future versions. "
+                "Instead, you should use simple dictionary",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            batch = batch.convert_to_dict()
         batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
-        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)
+        return self._model_predict(
+            feature_tensors=batch["inputs"],
+            padding_mask=batch["pad_mask"],
+            tokens_mask=batch["token_mask"],
+        )
 
     def predict(
         self,
-        batch: Bert4RecPredictionBatch,
+        batch: Union[Bert4RecPredictionBatch, dict],
         candidates_to_score: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         """
-        :param batch (Bert4RecPredictionBatch): Batch of prediction data.
+        :param batch: Batch of prediction data.
         :param candidates_to_score: Item ids to calculate scores.
             Default: ``None``.
 
         :returns: Calculated scores on prediction batch.
         """
+        if isinstance(batch, Bert4RecPredictionBatch):
+            warnings.warn(
+                "`Bert4RecPredictionBatch` class will be removed in future versions. "
+                "Instead, you should use simple dictionary",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            batch = batch.convert_to_dict()
+
         batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
-        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask, candidates_to_score)
+        return self._model_predict(
+            feature_tensors=batch["inputs"],
+            padding_mask=batch["pad_mask"],
+            tokens_mask=batch["token_mask"],
+            candidates_to_score=candidates_to_score,
+        )
 
     def forward(
         self,
@@ -152,10 +178,15 @@ class Bert4Rec(lightning.LightningModule):
 
         :returns: Calculated scores.
         """
-        return self._model_predict(feature_tensors, padding_mask, tokens_mask, candidates_to_score)
+        return self._model_predict(
+            feature_tensors=feature_tensors,
+            padding_mask=padding_mask,
+            tokens_mask=tokens_mask,
+            candidates_to_score=candidates_to_score,
+        )
 
     def validation_step(
-        self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+        self, batch: Union[Bert4RecValidationBatch, dict], batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
     ) -> torch.Tensor:
         """
         :param batch: Batch of prediction data.
@@ -163,7 +194,20 @@ class Bert4Rec(lightning.LightningModule):
 
         :returns: Calculated scores on validation batch.
         """
-        return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)
+        if isinstance(batch, Bert4RecValidationBatch):
+            warnings.warn(
+                "`Bert4RecValidationBatch` class will be removed in future versions. "
+                "Instead, you should use simple dictionary",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            batch = batch.convert_to_dict()
+
+        return self._model_predict(
+            feature_tensors=batch["inputs"],
+            padding_mask=batch["pad_mask"],
+            tokens_mask=batch["token_mask"],
+        )
 
     def configure_optimizers(self) -> Any:
         """
@@ -189,10 +233,15 @@ class Bert4Rec(lightning.LightningModule):
            cast(Bert4RecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
        )
        candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
-        scores = model.predict(feature_tensors, padding_mask, tokens_mask, candidates_to_score)
+        scores = model.predict(
+            inputs=feature_tensors,
+            pad_mask=padding_mask,
+            token_mask=tokens_mask,
+            candidates_to_score=candidates_to_score,
+        )
        return scores
 
-    def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
+    def _compute_loss(self, batch: Union[Bert4RecTrainingBatch, dict]) -> torch.Tensor:
        if self._loss_type == "BCE":
            loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
        elif self._loss_type == "CE":
@@ -203,11 +252,20 @@ class Bert4Rec(lightning.LightningModule):
            msg = f"Not supported loss type: {self._loss_type}"
            raise ValueError(msg)
 
+        if isinstance(batch, Bert4RecTrainingBatch):
+            warnings.warn(
+                "`Bert4RecTrainingBatch` class will be removed in future versions. "
+                "Instead, you should use simple dictionary",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            batch = batch.convert_to_dict()
+
        loss = loss_func(
-            batch.features,
-            batch.labels,
-            batch.padding_mask,  # 0 - padding_idx, 1 - other tokens
-            batch.tokens_mask,  # 0 - masked token, 1 - non-masked token
+            batch["inputs"],
+            batch["positive_labels"],
+            batch["pad_mask"],
+            batch["token_mask"],
        )
 
        return loss
@@ -253,7 +311,7 @@ class Bert4Rec(lightning.LightningModule):
        padding_mask: torch.BoolTensor,
        tokens_mask: torch.BoolTensor,
    ) -> torch.Tensor:
-        (positive_logits, negative_logits, *_) = self._get_sampled_logits(
+        positive_logits, negative_logits, *_ = self._get_sampled_logits(
            feature_tensors, positive_labels, padding_mask, tokens_mask
        )
 
@@ -300,7 +358,7 @@ class Bert4Rec(lightning.LightningModule):
        tokens_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        assert self._loss_sample_count is not None
-        (positive_logits, negative_logits, positive_labels, negative_labels, vocab_size) = self._get_sampled_logits(
+        positive_logits, negative_logits, positive_labels, negative_labels, vocab_size = self._get_sampled_logits(
            feature_tensors, positive_labels, padding_mask, tokens_mask
        )
        n_negative_samples = min(self._loss_sample_count, vocab_size)
@@ -325,7 +383,7 @@ class Bert4Rec(lightning.LightningModule):
        padding_mask: torch.BoolTensor,
        tokens_mask: torch.BoolTensor,
    ) -> torch.Tensor:
-        (logits, labels) = self._get_restricted_logits_for_ce_loss(
+        logits, labels = self._get_restricted_logits_for_ce_loss(
            feature_tensors, positive_labels, padding_mask, tokens_mask
        )
 
@@ -588,20 +646,20 @@ class Bert4Rec(lightning.LightningModule):
        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
 
 
-def _prepare_prediction_batch(
-    schema: TensorSchema, max_len: int, batch: Bert4RecPredictionBatch
-) -> Bert4RecPredictionBatch:
-    if batch.padding_mask.shape[1] > max_len:
+def _prepare_prediction_batch(schema: TensorSchema, max_len: int, batch: dict) -> dict:
+    seq_len = batch["pad_mask"].shape[1]
+    if seq_len > max_len:
        msg = (
            f"The length of the submitted sequence "
            "must not exceed the maximum length of the sequence. "
-            f"The length of the sequence is given {batch.padding_mask.shape[1]}, "
+            f"The length of the sequence is given {seq_len}, "
            f"while the maximum length is {max_len}"
        )
        raise ValueError(msg)
 
-    if batch.padding_mask.shape[1] < max_len:
-        query_id, padding_mask, features, _ = batch
+    if seq_len < max_len:
+        padding_mask = batch["pad_mask"]
+        features = batch["inputs"].copy()
        sequence_item_count = padding_mask.shape[1]
        for feature_name, feature_tensor in features.items():
            if schema[feature_name].is_cat:
@@ -618,5 +676,8 @@ def _prepare_prediction_batch(
            ).unsqueeze(-1)
        padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
        shifted_features, shifted_padding_mask, tokens_mask = _shift_features(schema, features, padding_mask)
-        batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask)
+
+        batch["pad_mask"] = shifted_padding_mask
+        batch["inputs"] = shifted_features
+        batch["token_mask"] = tokens_mask
    return batch
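
The lightning.py hunks apply one recurring pattern: each step accepts either the legacy NamedTuple batch or a plain dict, emits a `DeprecationWarning` for the former, and normalizes it with `convert_to_dict()` before indexing the dict keys. The standalone sketch below reproduces that shim with a stand-in NamedTuple instead of the real batch classes; it is illustrative only and not code from the package.

```python
import warnings
from typing import NamedTuple, Union

import torch


class LegacyBatch(NamedTuple):
    # Stand-in for Bert4RecTrainingBatch / Bert4RecPredictionBatch.
    query_id: torch.Tensor
    padding_mask: torch.Tensor
    features: dict
    tokens_mask: torch.Tensor

    def convert_to_dict(self) -> dict:
        return {
            "query_id": self.query_id,
            "pad_mask": self.padding_mask,
            "inputs": self.features,
            "token_mask": self.tokens_mask,
        }


def normalize_batch(batch: Union[LegacyBatch, dict]) -> dict:
    # Same idea as the shim added to training_step/validation_step/predict in the diff.
    if isinstance(batch, LegacyBatch):
        warnings.warn(
            "NamedTuple batches will be removed in future versions; pass a dict instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        batch = batch.convert_to_dict()
    return batch


legacy = LegacyBatch(
    query_id=torch.arange(2),
    padding_mask=torch.ones(2, 4, dtype=torch.bool),
    features={"item_id": torch.zeros(2, 4, dtype=torch.long)},
    tokens_mask=torch.ones(2, 4, dtype=torch.bool),
)
print(sorted(normalize_batch(legacy)))  # warns once, then returns the dict keys
```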
replay/models/nn/sequential/bert4rec/model.py +11 -11

@@ -88,8 +88,8 @@ class Bert4RecModel(torch.nn.Module):
     def forward(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor) -> torch.Tensor:
         """
         :param inputs: Batch of features.
-        :param pad_mask: Padding mask where 0 - <PAD>, 1 otherwise.
-        :param token_mask: Token mask where 0 - <MASK> tokens, 1 otherwise.
+        :param pad_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
+        :param token_mask: Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
 
         :returns: Calculated scores.
         """
@@ -107,12 +107,12 @@ class Bert4RecModel(torch.nn.Module):
     ) -> torch.Tensor:
         """
         :param inputs: Batch of features.
-        :param pad_mask: Padding mask where 0 - <PAD>, 1 otherwise.
-        :param token_mask: Token mask where 0 - <MASK> tokens, 1 otherwise.
-        :param candidates_to_score: Item ids to calculate scores.
-            if `None` predicts for all items
+        :param pad_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
+        :param token_mask: Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
+        :param candidates_to_score: Item ids to calculate scores.\n
+            If ``None`` then predicts for all items. Default: ``None``.
 
-        :returns: Calculated scores among canditates_to_score items.
+        :returns: Calculated scores among ``candidates_to_score`` items.
         """
         # final_emb: [B x E]
         final_emb = self.get_query_embeddings(inputs, pad_mask, token_mask)
@@ -123,8 +123,8 @@ class Bert4RecModel(torch.nn.Module):
         """
 
         :param inputs (TensorMap): Batch of features.
-        :param pad_mask (torch.BoolTensor): Padding mask where 0 - <PAD>, 1 otherwise.
-        :param token_mask (torch.BoolTensor): Token mask where 0 - <MASK> tokens, 1 otherwise.
+        :param pad_mask (torch.BoolTensor): Padding mask where 0 - ``<PAD>``, 1 - otherwise.
+        :param token_mask (torch.BoolTensor): Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
 
         :returns: Output embeddings.
         """
@@ -158,8 +158,8 @@ class Bert4RecModel(torch.nn.Module):
     def get_query_embeddings(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor):
         """
         :param inputs: Batch of features.
-        :param pad_mask: Padding mask where 0 - <PAD>, 1 otherwise.
-        :param token_mask: Token mask where 0 - <MASK> tokens, 1 otherwise.
+        :param pad_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
+        :param token_mask: Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
 
         :returns: Query embeddings.
         """
replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8

@@ -1,8 +1,10 @@
 import abc
-from typing import Generic, Optional, Protocol, TypeVar, cast
+import inspect
+from typing import Generic, Optional, Protocol, TypeVar, Union, cast
 
 import lightning
 import torch
+from typing_extensions import deprecated
 
 from replay.models.nn.sequential import Bert4Rec
 from replay.models.nn.sequential.postprocessors import BasePostProcessor
@@ -16,6 +18,7 @@ else:
     SparkSession = MissingImport
 
 
+@deprecated("`PredictionBatch` class is deprecated.", stacklevel=2)
 class PredictionBatch(Protocol):
     """
     Prediction callback batch
@@ -27,6 +30,10 @@ class PredictionBatch(Protocol):
 _T = TypeVar("_T")
 
 
+@deprecated(
+    "`BasePredictionCallback` class is deprecated. Use `replay.nn.lightning.callback.TopItemsCallbackBase` instead.",
+    stacklevel=2,
+)
 class BasePredictionCallback(lightning.Callback, Generic[_T]):
     """
     Base callback for prediction stage
@@ -48,6 +55,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         :param postprocessors: postprocessors to apply.
         """
         super().__init__()
+
         self.query_column = query_column
         self.item_column = item_column
         self.rating_column = rating_column
@@ -74,11 +82,14 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         trainer: lightning.Trainer,  # noqa: ARG002
         pl_module: lightning.LightningModule,  # noqa: ARG002
         outputs: torch.Tensor,
-        batch: PredictionBatch,
+        batch: Union[PredictionBatch, dict],
         batch_idx: int,  # noqa: ARG002
         dataloader_idx: int = 0,  # noqa: ARG002
     ) -> None:
-        query_ids, scores = self._compute_pipeline(batch.query_id, outputs)
+        query_ids, scores = self._compute_pipeline(
+            batch["query_id"] if isinstance(batch, dict) else batch.query_id,
+            outputs,
+        )
         top_scores, top_item_ids = torch.topk(scores, k=self._top_k, dim=1)
         self._query_batches.append(query_ids)
         self._item_batches.append(top_item_ids)
@@ -112,6 +123,10 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         pass
 
 
+@deprecated(
+    "`PandasPredictionCallback` class is deprecated. "
+    "Use `replay.nn.lightning.callback.PandasTopItemsCallback` instead."
+)
 class PandasPredictionCallback(BasePredictionCallback[PandasDataFrame]):
     """
     Callback for predition stage with pandas data frame
@@ -133,6 +148,10 @@ class PandasPredictionCallback(BasePredictionCallback[PandasDataFrame]):
         return prediction.explode([self.item_column, self.rating_column])
 
 
+@deprecated(
+    "`PolarsPredictionCallback` class is deprecated. "
+    "Use `replay.nn.lightning.callback.PolarsTopItemsCallback` instead."
+)
 class PolarsPredictionCallback(BasePredictionCallback[PolarsDataFrame]):
     """
     Callback for predition stage with polars data frame
@@ -154,6 +173,10 @@ class PolarsPredictionCallback(BasePredictionCallback[PolarsDataFrame]):
         return prediction.explode([self.item_column, self.rating_column])
 
 
+@deprecated(
+    "`SparkPredictionCallback` class is deprecated. "
+    "Use `replay.nn.lightning.callback.SparkTopItemsCallback` instead."
+)
 class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
     """
     Callback for prediction stage with spark data frame
@@ -213,6 +236,10 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
         return prediction
 
 
+@deprecated(
+    "`TorchPredictionCallback` class is deprecated. "
+    "Use `replay.nn.lightning.callback.TorchTopItemsCallback` instead."
+)
 class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
     """
     Callback for predition stage with tuple of tensors
@@ -248,6 +275,10 @@ class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, tor
         )
 
 
+@deprecated(
+    "`QueryEmbeddingsPredictionCallback` class is deprecated. "
+    "Use `replay.nn.lightning.callback.HiddenStatesCallback` instead."
+)
 class QueryEmbeddingsPredictionCallback(lightning.Callback):
     """
     Callback for prediction stage to get query embeddings.
@@ -266,15 +297,26 @@ class QueryEmbeddingsPredictionCallback(lightning.Callback):
         trainer: lightning.Trainer,  # noqa: ARG002
         pl_module: lightning.LightningModule,
         outputs: torch.Tensor,  # noqa: ARG002
-        batch: PredictionBatch,
+        batch: Union[PredictionBatch, dict],
         batch_idx: int,  # noqa: ARG002
         dataloader_idx: int = 0,  # noqa: ARG002
     ) -> None:
-        args = [batch.features, batch.padding_mask]
-        if isinstance(pl_module, Bert4Rec):
-            args.append(batch.tokens_mask)
+        if isinstance(batch, dict):
+            modified_batch = {
+                k: v
+                for k, v in batch.items()
+                if k in inspect.signature(pl_module._model.get_query_embeddings).parameters
+            }
+            query_embeddings = pl_module._model.get_query_embeddings(**modified_batch)
+        else:
+            args = [
+                batch.features,
+                batch.padding_mask,
+            ]
+            if isinstance(pl_module, Bert4Rec):
+                args.append(batch.tokens_mask)
+            query_embeddings = pl_module._model.get_query_embeddings(*args)
 
-        query_embeddings = pl_module._model.get_query_embeddings(*args)
         self._embeddings_per_batch.append(query_embeddings)
 
     def get_result(self):
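
The dict branch added to `QueryEmbeddingsPredictionCallback` filters the batch down to the keyword arguments that `get_query_embeddings` actually accepts. The standalone sketch below reproduces that `inspect.signature` filtering with a made-up function and batch; it is illustrative only and not code from the package.

```python
import inspect

import torch


def get_query_embeddings(inputs: dict, pad_mask: torch.Tensor, token_mask: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model method; simply pools the item ids it is given.
    return inputs["item_id"].float().mean(dim=-1, keepdim=True)


batch = {
    "query_id": torch.arange(2),
    "inputs": {"item_id": torch.randint(0, 10, (2, 5))},
    "pad_mask": torch.ones(2, 5, dtype=torch.bool),
    "token_mask": torch.ones(2, 5, dtype=torch.bool),
    "positive_labels": torch.randint(0, 10, (2, 5)),  # extra key the function does not accept
}

# Keep only the keys that match parameters of the target callable, as the callback does.
accepted = inspect.signature(get_query_embeddings).parameters
filtered = {k: v for k, v in batch.items() if k in accepted}
embeddings = get_query_embeddings(**filtered)
print(sorted(filtered), embeddings.shape)
```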