replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,8 +1,9 @@
1
- from typing import Any, Literal, Optional, Protocol
1
+ from typing import Any, Literal, Optional, Protocol, Union
2
2
 
3
3
  import lightning
4
4
  import torch
5
5
  from lightning.pytorch.utilities.rank_zero import rank_zero_only
6
+ from typing_extensions import deprecated
6
7
 
7
8
  from replay.metrics.torch_metrics_builder import TorchMetricsBuilder, metrics_to_df
8
9
  from replay.models.nn.sequential.postprocessors import BasePostProcessor
@@ -18,6 +19,7 @@ CallbackMetricName = Literal[
18
19
  ]
19
20
 
20
21
 
22
+ @deprecated("`ValidationBatch` class is deprecated.", stacklevel=2)
21
23
  class ValidationBatch(Protocol):
22
24
  """
23
25
  Validation callback batch
@@ -28,12 +30,19 @@ class ValidationBatch(Protocol):
28
30
  train: torch.LongTensor
29
31
 
30
32
 
33
+ @deprecated(
34
+ "`ValidationMetricsCallback` class is deprecated. "
35
+ "Use `replay.nn.lightning.callback.ComputeMetricsCallback` instead."
36
+ )
31
37
  class ValidationMetricsCallback(lightning.Callback):
32
38
  """
33
39
  Callback for validation and testing stages.
34
40
 
35
41
  If multiple validation/testing dataloaders are used,
36
42
  the suffix of the metric name will contain the serial number of the dataloader.
43
+
44
+ For the callback to work correctly, the batch must contain the `query_id` and `ground_truth` keys.
45
+ If you want to calculate the coverage or novelty metrics then the batch must additionally contain the `train` key.
37
46
  """
38
47
 
39
48
  def __init__(
@@ -95,7 +104,7 @@ class ValidationMetricsCallback(lightning.Callback):
95
104
  trainer: lightning.Trainer,
96
105
  pl_module: lightning.LightningModule,
97
106
  outputs: torch.Tensor,
98
- batch: ValidationBatch,
107
+ batch: Union[ValidationBatch, dict],
99
108
  batch_idx: int,
100
109
  dataloader_idx: int = 0,
101
110
  ) -> None:
@@ -106,7 +115,7 @@ class ValidationMetricsCallback(lightning.Callback):
106
115
  trainer: lightning.Trainer,
107
116
  pl_module: lightning.LightningModule,
108
117
  outputs: torch.Tensor,
109
- batch: ValidationBatch,
118
+ batch: Union[ValidationBatch, dict],
110
119
  batch_idx: int,
111
120
  dataloader_idx: int = 0,
112
121
  ) -> None: # pragma: no cover
@@ -117,13 +126,21 @@ class ValidationMetricsCallback(lightning.Callback):
117
126
  trainer: lightning.Trainer, # noqa: ARG002
118
127
  pl_module: lightning.LightningModule,
119
128
  outputs: torch.Tensor,
120
- batch: ValidationBatch,
129
+ batch: Union[ValidationBatch, dict],
121
130
  batch_idx: int,
122
131
  dataloader_idx: int,
123
132
  ) -> None:
124
- _, seen_scores, seen_ground_truth = self._compute_pipeline(batch.query_id, outputs, batch.ground_truth)
133
+ _, seen_scores, seen_ground_truth = self._compute_pipeline(
134
+ batch["query_id"] if isinstance(batch, dict) else batch.query_id,
135
+ outputs,
136
+ batch["ground_truth"] if isinstance(batch, dict) else batch.ground_truth,
137
+ )
125
138
  sampled_items = torch.topk(seen_scores, k=self._metrics_builders[dataloader_idx].max_k, dim=1).indices
126
- self._metrics_builders[dataloader_idx].add_prediction(sampled_items, seen_ground_truth, batch.train)
139
+ self._metrics_builders[dataloader_idx].add_prediction(
140
+ sampled_items,
141
+ seen_ground_truth,
142
+ batch.get("train") if isinstance(batch, dict) else batch.train,
143
+ )
127
144
 
128
145
  if batch_idx + 1 == self._dataloaders_size[dataloader_idx]:
129
146
  pl_module.log_dict(
@@ -101,17 +101,24 @@ class BaseCompiledModel:
101
101
  )
102
102
  raise ValueError(msg)
103
103
 
104
- def _valilade_predict_input(self, batch: Any, candidates_to_score: Optional[torch.LongTensor] = None) -> None:
104
+ def _validate_predict_input(
105
+ self,
106
+ batch: Any,
107
+ candidates_to_score: Optional[torch.LongTensor] = None,
108
+ padding_mask_key_name: str = "padding_mask",
109
+ ) -> None:
105
110
  if self._num_candidates_to_score is None and candidates_to_score is not None:
106
111
  msg = (
107
112
  "If ``num_candidates_to_score`` is None, "
108
113
  "it is impossible to infer the model with passed ``candidates_to_score``."
109
114
  )
110
115
  raise ValueError(msg)
111
-
112
- if self._batch_size != -1 and batch.padding_mask.shape[0] != self._batch_size:
116
+ input_batch_size = (
117
+ batch[padding_mask_key_name].shape[0] if isinstance(batch, dict) else batch.padding_mask.shape[0]
118
+ )
119
+ if self._batch_size != -1 and input_batch_size != self._batch_size:
113
120
  msg = (
114
- f"The batch is smaller then defined batch_size={self._batch_size}. "
121
+ f"The batch is smaller than defined batch_size={self._batch_size}. "
115
122
  "It is impossible to infer the model with dynamic batch size in ``mode`` = ``batch``. "
116
123
  "Use ``mode`` = ``dynamic_batch_size``."
117
124
  )
@@ -215,6 +222,7 @@ class BaseCompiledModel:
215
222
  input_names=model_input_names,
216
223
  output_names=["scores"],
217
224
  dynamic_axes=model_dynamic_axes_in_input,
225
+ dynamo=False,
218
226
  )
219
227
  del lightning_model
220
228
 
@@ -1,4 +1,5 @@
1
1
  import pathlib
2
+ import warnings
2
3
  from typing import Optional, Union, get_args
3
4
 
4
5
  import openvino as ov
@@ -39,7 +40,7 @@ class Bert4RecCompiled(BaseCompiledModel):
39
40
 
40
41
  def predict(
41
42
  self,
42
- batch: Bert4RecPredictionBatch,
43
+ batch: Union[Bert4RecPredictionBatch, dict],
43
44
  candidates_to_score: Optional[torch.LongTensor] = None,
44
45
  ) -> torch.Tensor:
45
46
  """
@@ -51,13 +52,22 @@ class Bert4RecCompiled(BaseCompiledModel):
51
52
 
52
53
  :return: Tensor with scores.
53
54
  """
54
- self._valilade_predict_input(batch, candidates_to_score)
55
+ self._validate_predict_input(batch, candidates_to_score, "pad_mask")
56
+
57
+ if isinstance(batch, Bert4RecPredictionBatch):
58
+ warnings.warn(
59
+ "`Bert4RecPredictionBatch` class will be removed in future versions. "
60
+ "Instead, you should use simple dictionary",
61
+ DeprecationWarning,
62
+ stacklevel=2,
63
+ )
64
+ batch = batch.convert_to_dict()
55
65
 
56
66
  batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
57
67
  model_inputs = {
58
- self._inputs_names[0]: batch.features[self._inputs_names[0]],
59
- self._inputs_names[1]: batch.padding_mask,
60
- self._inputs_names[2]: batch.tokens_mask,
68
+ self._inputs_names[0]: batch["inputs"][self._inputs_names[0]],
69
+ self._inputs_names[1]: batch["pad_mask"],
70
+ self._inputs_names[2]: batch["token_mask"],
61
71
  }
62
72
  if self._num_candidates_to_score is not None:
63
73
  self._validate_candidates_to_score(candidates_to_score)
@@ -1,4 +1,5 @@
1
1
  import pathlib
2
+ import warnings
2
3
  from typing import Optional, Union, get_args
3
4
 
4
5
  import openvino as ov
@@ -39,7 +40,7 @@ class SasRecCompiled(BaseCompiledModel):
39
40
 
40
41
  def predict(
41
42
  self,
42
- batch: SasRecPredictionBatch,
43
+ batch: Union[SasRecPredictionBatch, dict],
43
44
  candidates_to_score: Optional[torch.LongTensor] = None,
44
45
  ) -> torch.Tensor:
45
46
  """
@@ -51,12 +52,21 @@ class SasRecCompiled(BaseCompiledModel):
51
52
 
52
53
  :return: Tensor with scores.
53
54
  """
54
- self._valilade_predict_input(batch, candidates_to_score)
55
+ self._validate_predict_input(batch, candidates_to_score)
56
+
57
+ if isinstance(batch, SasRecPredictionBatch):
58
+ warnings.warn(
59
+ "`SasRecPredictionBatch` class will be removed in future versions. "
60
+ "Instead, you should use simple dictionary",
61
+ DeprecationWarning,
62
+ stacklevel=2,
63
+ )
64
+ batch = batch.convert_to_dict()
55
65
 
56
66
  batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
57
67
  model_inputs = {
58
- self._inputs_names[0]: batch.features[self._inputs_names[0]],
59
- self._inputs_names[1]: batch.padding_mask,
68
+ self._inputs_names[0]: batch["feature_tensor"][self._inputs_names[0]],
69
+ self._inputs_names[1]: batch["padding_mask"],
60
70
  }
61
71
  if self._num_candidates_to_score is not None:
62
72
  self._validate_candidates_to_score(candidates_to_score)
@@ -77,15 +87,14 @@ class SasRecCompiled(BaseCompiledModel):
77
87
  Model compilation.
78
88
 
79
89
  :param model: Path to lightning SasRec model saved in .ckpt format or the SasRec object itself.
80
- :param mode: Inference mode, defines shape of inputs.
81
- Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
90
+ :param mode: Inference mode, defines shape of inputs.\n
82
91
  ``one_query`` - sets input shape to [1, max_seq_len]\n
83
92
  ``batch`` - sets input shape to [batch_size, max_seq_len]\n
84
93
  ``dynamic_batch_size`` - sets batch_size to dynamic range [?, max_seq_len]\n
85
94
  Default: ``one_query``.
86
95
  :param batch_size: Batch size, required for ``batch`` mode.
87
96
  Default: ``None``.
88
- :param num_candidates_to_score: Number of item ids to calculate scores.
97
+ :param num_candidates_to_score: Number of item ids to calculate scores.\n
89
98
  Could be one of [``None``, ``-1``, ``N``].\n
90
99
  ``-1`` - sets candidates_to_score shape to dynamic range [1, ?]\n
91
100
  ``N`` - sets candidates_to_score shape to [1, N]\n
@@ -1,8 +1,13 @@
1
1
  import abc
2
2
 
3
3
  import torch
4
+ from typing_extensions import deprecated
4
5
 
5
6
 
7
+ @deprecated(
8
+ "`BasePostProcessor` class is deprecated. Use `replay.nn.lightning.postprocessor.PostprocessorBase` instead.",
9
+ stacklevel=2,
10
+ )
6
11
  class BasePostProcessor(abc.ABC): # pragma: no cover
7
12
  """
8
13
  Abstract base class for post processor
@@ -3,12 +3,14 @@ from typing import Optional, Union, cast
3
3
  import numpy as np
4
4
  import pandas as pd
5
5
  import torch
6
+ from typing_extensions import deprecated
6
7
 
7
8
  from replay.data.nn import SequentialDataset
8
9
 
9
10
  from ._base import BasePostProcessor
10
11
 
11
12
 
13
+ @deprecated("`RemoveSeenItems` class is deprecated. Use `replay.nn.lightning.postprocessor.SeenItemsFilter` instead.")
12
14
  class RemoveSeenItems(BasePostProcessor):
13
15
  """
14
16
  Filters out the items that already have been seen in dataset.
@@ -16,6 +18,7 @@ class RemoveSeenItems(BasePostProcessor):
16
18
 
17
19
  def __init__(self, sequential: SequentialDataset) -> None:
18
20
  super().__init__()
21
+
19
22
  self._sequential = sequential
20
23
  self._apply_candidates = False
21
24
  self._candidates = None
@@ -107,6 +110,7 @@ class RemoveSeenItems(BasePostProcessor):
107
110
  self._candidates = candidates
108
111
 
109
112
 
113
+ @deprecated("`SampleItems` class is deprecated.")
110
114
  class SampleItems(BasePostProcessor):
111
115
  """
112
116
  Generates negative samples to compute sampled metrics
@@ -1,7 +1,8 @@
1
- from typing import NamedTuple, Optional, cast
1
+ from typing import NamedTuple, Optional
2
2
 
3
3
  import torch
4
4
  from torch.utils.data import Dataset as TorchDataset
5
+ from typing_extensions import deprecated
5
6
 
6
7
  from replay.data.nn import (
7
8
  MutableTensorMap,
@@ -12,6 +13,10 @@ from replay.data.nn import (
12
13
  )
13
14
 
14
15
 
16
+ @deprecated(
17
+ "`SasRecTrainingBatch` class is deprecated.",
18
+ stacklevel=2,
19
+ )
15
20
  class SasRecTrainingBatch(NamedTuple):
16
21
  """
17
22
  Batch of data for training.
@@ -24,10 +29,25 @@ class SasRecTrainingBatch(NamedTuple):
24
29
  labels: torch.LongTensor
25
30
  labels_padding_mask: torch.BoolTensor
26
31
 
32
+ def convert_to_dict(self) -> dict:
33
+ return {
34
+ "query_id": self.query_id,
35
+ "feature_tensor": self.features,
36
+ "padding_mask": self.padding_mask,
37
+ "positive_labels": self.labels,
38
+ "target_padding_mask": self.labels_padding_mask,
39
+ }
40
+
27
41
 
42
+ @deprecated("`SasRecTrainingDataset` class is deprecated. Use `replay.data.nn.ParquetModule` instead.")
28
43
  class SasRecTrainingDataset(TorchDataset):
29
44
  """
30
- Dataset that generates samples to train SasRec-like model
45
+ Dataset that generates samples to train SasRec model.
46
+
47
+ As a result of the dataset iteration, a dictionary is formed.
48
+ The keys in the dictionary match the names of the arguments in the model's `forward` function.
49
+ There are also additional keys needed to calculate losses - `positive_labels`, `target_padding_mask`.
50
+ The `query_id` key is required for possible debugging and calling additional lightning callbacks.
31
51
  """
32
52
 
33
53
  def __init__(
@@ -81,7 +101,7 @@ class SasRecTrainingDataset(TorchDataset):
81
101
  def __len__(self) -> int:
82
102
  return len(self._inner)
83
103
 
84
- def __getitem__(self, index: int) -> SasRecTrainingBatch:
104
+ def __getitem__(self, index: int) -> dict:
85
105
  query_id, padding_mask, features = self._inner[index]
86
106
 
87
107
  assert self._label_feature_name
@@ -97,15 +117,19 @@ class SasRecTrainingDataset(TorchDataset):
97
117
 
98
118
  output_features_padding_mask = padding_mask[: -self._sequence_shift]
99
119
 
100
- return SasRecTrainingBatch(
101
- query_id=query_id,
102
- features=output_features,
103
- padding_mask=cast(torch.BoolTensor, output_features_padding_mask),
104
- labels=cast(torch.LongTensor, labels),
105
- labels_padding_mask=cast(torch.BoolTensor, labels_padding_mask),
106
- )
120
+ return {
121
+ "query_id": query_id,
122
+ "feature_tensor": output_features,
123
+ "padding_mask": output_features_padding_mask,
124
+ "positive_labels": labels,
125
+ "target_padding_mask": labels_padding_mask,
126
+ }
107
127
 
108
128
 
129
+ @deprecated(
130
+ "`SasRecPredictionBatch` class is deprecated.",
131
+ stacklevel=2,
132
+ )
109
133
  class SasRecPredictionBatch(NamedTuple):
110
134
  """
111
135
  Batch of data for model inference.
@@ -116,10 +140,22 @@ class SasRecPredictionBatch(NamedTuple):
116
140
  padding_mask: torch.BoolTensor
117
141
  features: TensorMap
118
142
 
143
+ def convert_to_dict(self) -> dict:
144
+ return {
145
+ "query_id": self.query_id,
146
+ "feature_tensor": self.features,
147
+ "padding_mask": self.padding_mask,
148
+ }
149
+
119
150
 
151
+ @deprecated("`SasRecPredictionDataset` class is deprecated. Use `replay.data.nn.ParquetModule` instead.")
120
152
  class SasRecPredictionDataset(TorchDataset):
121
153
  """
122
- Dataset that generates samples to infer SasRec-like model
154
+ Dataset that generates samples to infer SasRec model.
155
+
156
+ As a result of the dataset iteration, a dictionary is formed.
157
+ The keys in the dictionary match the names of the arguments in the model's `forward` function.
158
+ The `query_id` key is required for possible debugging and calling additional lightning callbacks.
123
159
  """
124
160
 
125
161
  def __init__(
@@ -143,15 +179,19 @@ class SasRecPredictionDataset(TorchDataset):
143
179
  def __len__(self) -> int:
144
180
  return len(self._inner)
145
181
 
146
- def __getitem__(self, index: int) -> SasRecPredictionBatch:
182
+ def __getitem__(self, index: int) -> dict:
147
183
  query_id, padding_mask, features = self._inner[index]
148
- return SasRecPredictionBatch(
149
- query_id=query_id,
150
- padding_mask=padding_mask,
151
- features=features,
152
- )
184
+ return {
185
+ "query_id": query_id,
186
+ "padding_mask": padding_mask,
187
+ "feature_tensor": features,
188
+ }
153
189
 
154
190
 
191
+ @deprecated(
192
+ "`SasRecValidationBatch` class is deprecated.",
193
+ stacklevel=2,
194
+ )
155
195
  class SasRecValidationBatch(NamedTuple):
156
196
  """
157
197
  Batch of data for validation.
@@ -164,10 +204,25 @@ class SasRecValidationBatch(NamedTuple):
164
204
  ground_truth: torch.LongTensor
165
205
  train: torch.LongTensor
166
206
 
207
+ def convert_to_dict(self) -> dict:
208
+ return {
209
+ "query_id": self.query_id,
210
+ "feature_tensor": self.features,
211
+ "padding_mask": self.padding_mask,
212
+ "ground_truth": self.ground_truth,
213
+ "train": self.train,
214
+ }
167
215
 
216
+
217
+ @deprecated("`SasRecValidationDataset` class is deprecated. Use `replay.data.nn.ParquetModule` instead.")
168
218
  class SasRecValidationDataset(TorchDataset):
169
219
  """
170
- Dataset that generates samples to infer and validate SasRec-like model
220
+ Dataset that generates samples to infer and validate SasRec model.
221
+
222
+ As a result of the dataset iteration, a dictionary is formed.
223
+ The keys in the dictionary match the names of the arguments in the model's `forward` function.
224
+ The `query_id` key is required for possible debugging and calling additional lightning callbacks.
225
+ Keys `ground_truth` and `train` are required for metrics calculation on validation stage.
171
226
  """
172
227
 
173
228
  def __init__(
@@ -202,12 +257,12 @@ class SasRecValidationDataset(TorchDataset):
202
257
  def __len__(self) -> int:
203
258
  return len(self._inner)
204
259
 
205
- def __getitem__(self, index: int) -> SasRecValidationBatch:
260
+ def __getitem__(self, index: int) -> dict:
206
261
  query_id, padding_mask, features, ground_truth, train = self._inner[index]
207
- return SasRecValidationBatch(
208
- query_id=query_id,
209
- padding_mask=padding_mask,
210
- features=features,
211
- ground_truth=ground_truth,
212
- train=train,
213
- )
262
+ return {
263
+ "query_id": query_id,
264
+ "padding_mask": padding_mask,
265
+ "feature_tensor": features,
266
+ "ground_truth": ground_truth,
267
+ "train": train,
268
+ }
@@ -1,8 +1,10 @@
1
1
  import math
2
+ import warnings
2
3
  from typing import Any, Literal, Optional, Union, cast
3
4
 
4
5
  import lightning
5
6
  import torch
7
+ from typing_extensions import deprecated
6
8
 
7
9
  from replay.data.nn import TensorMap, TensorSchema
8
10
  from replay.models.nn.loss import ScalableCrossEntropyLoss, SCEParams
@@ -12,6 +14,11 @@ from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidatio
12
14
  from .model import SasRecModel
13
15
 
14
16
 
17
+ @deprecated(
18
+ "`SasRec` class is deprecated. "
19
+ "Use `replay.nn.sequential.SasRec` "
20
+ "and `replay.nn.lightning.LightningModule` instead."
21
+ )
15
22
  class SasRec(lightning.LightningModule):
16
23
  """
17
24
  SASRec Lightning module.
@@ -54,9 +61,9 @@ class SasRec(lightning.LightningModule):
54
61
  Default: ``False``.
55
62
  :param time_span: Time span value.
56
63
  Default: ``256``.
57
- :param loss_type: Loss type. Possible values: ``"CE"``, ``"BCE"``, ``"SCE"``.
64
+ :param loss_type: Loss type.
58
65
  Default: ``CE``.
59
- :param loss_sample_count (Optional[int]): Sample count to calculate loss.
66
+ :param loss_sample_count: Sample count to calculate loss.
60
67
  Suitable for ``"CE"`` and ``"BCE"`` loss functions.
61
68
  Default: ``None``.
62
69
  :param negative_sampling_strategy: Negative sampling strategy to calculate loss on sampled negatives.
@@ -74,6 +81,7 @@ class SasRec(lightning.LightningModule):
74
81
  """
75
82
  super().__init__()
76
83
  self.save_hyperparameters()
84
+
77
85
  self._model = SasRecModel(
78
86
  schema=tensor_schema,
79
87
  num_blocks=block_count,
@@ -102,7 +110,7 @@ class SasRec(lightning.LightningModule):
102
110
  self._vocab_size = item_count
103
111
  self.candidates_to_score = None
104
112
 
105
- def training_step(self, batch: SasRecTrainingBatch, batch_idx: int) -> torch.Tensor:
113
+ def training_step(self, batch: Union[SasRecTrainingBatch, dict], batch_idx: int) -> torch.Tensor:
106
114
  """
107
115
  :param batch (SasRecTrainingBatch): Batch of training data.
108
116
  :param batch_idx (int): Batch index.
@@ -117,7 +125,7 @@ class SasRec(lightning.LightningModule):
117
125
 
118
126
  def predict_step(
119
127
  self,
120
- batch: SasRecPredictionBatch,
128
+ batch: Union[SasRecPredictionBatch, dict],
121
129
  batch_idx: int, # noqa: ARG002
122
130
  dataloader_idx: int = 0, # noqa: ARG002
123
131
  ) -> torch.Tensor:
@@ -128,12 +136,23 @@ class SasRec(lightning.LightningModule):
128
136
 
129
137
  :returns: Calculated scores.
130
138
  """
139
+ if isinstance(batch, SasRecPredictionBatch):
140
+ warnings.warn(
141
+ "`SasRecPredictionBatch` class will be removed in future versions. "
142
+ "Instead, you should use simple dictionary",
143
+ DeprecationWarning,
144
+ stacklevel=2,
145
+ )
146
+ batch = batch.convert_to_dict()
131
147
  batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
132
- return self._model_predict(batch.features, batch.padding_mask)
148
+ return self._model_predict(
149
+ feature_tensors=batch["feature_tensor"],
150
+ padding_mask=batch["padding_mask"],
151
+ )
133
152
 
134
153
  def predict(
135
154
  self,
136
- batch: SasRecPredictionBatch,
155
+ batch: Union[SasRecPredictionBatch, dict],
137
156
  candidates_to_score: Optional[torch.LongTensor] = None,
138
157
  ) -> torch.Tensor:
139
158
  """
@@ -143,8 +162,20 @@ class SasRec(lightning.LightningModule):
143
162
 
144
163
  :returns: Calculated scores.
145
164
  """
165
+ if isinstance(batch, SasRecPredictionBatch):
166
+ warnings.warn(
167
+ "`SasRecPredictionBatch` class will be removed in future versions. "
168
+ "Instead, you should use simple dictionary",
169
+ DeprecationWarning,
170
+ stacklevel=2,
171
+ )
172
+ batch = batch.convert_to_dict()
146
173
  batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
147
- return self._model_predict(batch.features, batch.padding_mask, candidates_to_score)
174
+ return self._model_predict(
175
+ feature_tensors=batch["feature_tensor"],
176
+ padding_mask=batch["padding_mask"],
177
+ candidates_to_score=candidates_to_score,
178
+ )
148
179
 
149
180
  def forward(
150
181
  self,
@@ -164,7 +195,7 @@ class SasRec(lightning.LightningModule):
164
195
 
165
196
  def validation_step(
166
197
  self,
167
- batch: SasRecValidationBatch,
198
+ batch: Union[SasRecValidationBatch, dict],
168
199
  batch_idx: int, # noqa: ARG002
169
200
  dataloader_idx: int = 0, # noqa: ARG002
170
201
  ) -> torch.Tensor:
@@ -174,7 +205,19 @@ class SasRec(lightning.LightningModule):
174
205
 
175
206
  :returns: Calculated scores.
176
207
  """
177
- return self._model_predict(batch.features, batch.padding_mask)
208
+ if isinstance(batch, SasRecValidationBatch):
209
+ warnings.warn(
210
+ "`SasRecValidationBatch` class will be removed in future versions. "
211
+ "Instead, you should use simple dictionary",
212
+ DeprecationWarning,
213
+ stacklevel=2,
214
+ )
215
+ batch = batch.convert_to_dict()
216
+
217
+ return self._model_predict(
218
+ feature_tensors=batch["feature_tensor"],
219
+ padding_mask=batch["padding_mask"],
220
+ )
178
221
 
179
222
  def configure_optimizers(self) -> Any:
180
223
  """
@@ -197,10 +240,14 @@ class SasRec(lightning.LightningModule):
197
240
  model: SasRecModel
198
241
  model = cast(SasRecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
199
242
  candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
200
- scores = model.predict(feature_tensors, padding_mask, candidates_to_score)
243
+ scores = model.predict(
244
+ feature_tensor=feature_tensors,
245
+ padding_mask=padding_mask,
246
+ candidates_to_score=candidates_to_score,
247
+ )
201
248
  return scores
202
249
 
203
- def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor:
250
+ def _compute_loss(self, batch: Union[SasRecTrainingBatch, dict]) -> torch.Tensor:
204
251
  if self._loss_type == "BCE":
205
252
  loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
206
253
  elif self._loss_type == "CE":
@@ -211,11 +258,20 @@ class SasRec(lightning.LightningModule):
211
258
  msg = f"Not supported loss type: {self._loss_type}"
212
259
  raise ValueError(msg)
213
260
 
261
+ if isinstance(batch, SasRecTrainingBatch):
262
+ warnings.warn(
263
+ "`SasRecTrainingBatch` class will be removed in future versions. "
264
+ "Instead, you should use simple dictionary",
265
+ DeprecationWarning,
266
+ stacklevel=2,
267
+ )
268
+ batch = batch.convert_to_dict()
269
+
214
270
  loss = loss_func(
215
- batch.features,
216
- batch.labels,
217
- batch.padding_mask,
218
- batch.labels_padding_mask,
271
+ batch["feature_tensor"],
272
+ batch["positive_labels"],
273
+ batch["padding_mask"],
274
+ batch["target_padding_mask"],
219
275
  )
220
276
  return loss
221
277
 
@@ -258,7 +314,7 @@ class SasRec(lightning.LightningModule):
258
314
  padding_mask: torch.BoolTensor,
259
315
  target_padding_mask: torch.BoolTensor,
260
316
  ) -> torch.Tensor:
261
- (positive_logits, negative_logits, *_) = self._get_sampled_logits(
317
+ positive_logits, negative_logits, *_ = self._get_sampled_logits(
262
318
  feature_tensors, positive_labels, padding_mask, target_padding_mask
263
319
  )
264
320
 
@@ -306,7 +362,7 @@ class SasRec(lightning.LightningModule):
306
362
  target_padding_mask: torch.BoolTensor,
307
363
  ) -> torch.Tensor:
308
364
  assert self._loss_sample_count is not None
309
- (positive_logits, negative_logits, positive_labels, negative_labels, vocab_size) = self._get_sampled_logits(
365
+ positive_logits, negative_logits, positive_labels, negative_labels, vocab_size = self._get_sampled_logits(
310
366
  feature_tensors, positive_labels, padding_mask, target_padding_mask
311
367
  )
312
368
  n_negative_samples = min(self._loss_sample_count, vocab_size)
@@ -566,19 +622,23 @@ class SasRec(lightning.LightningModule):
566
622
 
567
623
 
568
624
  def _prepare_prediction_batch(
569
- schema: TensorSchema, max_len: int, batch: SasRecPredictionBatch
570
- ) -> SasRecPredictionBatch:
571
- if batch.padding_mask.shape[1] > max_len:
625
+ schema: TensorSchema,
626
+ max_len: int,
627
+ batch: dict,
628
+ ) -> dict:
629
+ seq_len = batch["padding_mask"].shape[1]
630
+ if seq_len > max_len:
572
631
  msg = (
573
632
  "The length of the submitted sequence "
574
633
  "must not exceed the maximum length of the sequence. "
575
- f"The length of the sequence is given {batch.padding_mask.shape[1]}, "
634
+ f"The length of the sequence is given {seq_len}, "
576
635
  f"while the maximum length is {max_len}"
577
636
  )
578
637
  raise ValueError(msg)
579
638
 
580
- if batch.padding_mask.shape[1] < max_len:
581
- query_id, padding_mask, features = batch
639
+ if seq_len < max_len:
640
+ padding_mask = batch["padding_mask"]
641
+ features = batch["feature_tensor"].copy()
582
642
  sequence_item_count = padding_mask.shape[1]
583
643
  for feature_name, feature_tensor in features.items():
584
644
  if schema[feature_name].is_cat:
@@ -592,5 +652,7 @@ def _prepare_prediction_batch(
592
652
  value=0,
593
653
  ).unsqueeze(-1)
594
654
  padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
595
- batch = SasRecPredictionBatch(query_id, padding_mask, features)
655
+ batch["padding_mask"] = padding_mask
656
+ batch["feature_tensor"] = features
657
+
596
658
  return batch