replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/ffn.py ADDED
@@ -0,0 +1,135 @@
1
+ import contextlib
2
+ from typing import Literal
3
+
4
+ import torch
5
+
6
+ from replay.data.nn.schema import TensorMap
7
+
8
+ from .utils import create_activation
9
+
10
+
11
class PointWiseFeedForward(torch.nn.Module):
    """
    Point wise feed forward network layer.

    Two kernel-size-1 convolutions (equivalent to position-wise linear layers)
    with an activation and dropout between them, plus a residual connection.

    Source paper: https://arxiv.org/pdf/1808.09781.pdf
    """

    def __init__(
        self,
        embedding_dim: int,
        dropout: float,
        activation: Literal["relu", "gelu"] = "gelu",
    ) -> None:
        """
        :param embedding_dim: Dimension of the input features.
        :param dropout: probability of an element to be zeroed.
        :param activation: the name of the activation function.
            Default: ``"gelu"``.
        """
        super().__init__()

        # kernel_size=1 convolutions act as per-position linear projections
        self.conv1 = torch.nn.Conv1d(embedding_dim, embedding_dim, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout)
        self.activation = create_activation(activation)
        self.conv2 = torch.nn.Conv1d(embedding_dim, embedding_dim, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout)

    def reset_parameters(self) -> None:
        """Re-initialize every weight with Xavier normal initialization."""
        for _, param in self.named_parameters():
            # xavier_normal_ raises ValueError for 1-d tensors (biases); skip them
            with contextlib.suppress(ValueError):
                torch.nn.init.xavier_normal_(param.data)

    def forward(self, input_embeddings: torch.Tensor) -> torch.Tensor:
        """
        :param input_embeddings: Float tensor of shape
            ``(batch_size, sequence_length, embedding_dim)``.

        :returns: Output tensor of the same shape, with the residual connection applied.
        """
        # Conv1d expects channels-first, so swap the feature and sequence axes
        x: torch.Tensor = self.conv1(input_embeddings.transpose(-1, -2))
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = x.transpose(-1, -2)
        # Residual connection. Out-of-place add: the original in-place `x += ...`
        # mutated a transposed view of the dropout output, which is fragile for autograd.
        x = x + input_embeddings

        return x
58
+
59
+
60
+ class SwiGLU(torch.nn.Module):
61
+ """
62
+ SwiGLU Activation Function.
63
+ Combines the Swish activation with Gated Linear Units.
64
+ """
65
+
66
+ def __init__(self, embedding_dim: int, hidden_dim: int):
67
+ """
68
+ :param embedding_dim: Dimension of the input features.
69
+ :param hidden_dim: Dimension of hidden layer.
70
+ According to the original source,
71
+ it is recommended to set the size of the hidden layer as :math:`2 \\cdot \\text{embedding_dim}`.
72
+ """
73
+ super().__init__()
74
+ # Intermediate projection layers
75
+ # Typically, SwiGLU splits the computation into two parts
76
+ self.WG = torch.nn.Linear(embedding_dim, hidden_dim)
77
+ self.W1 = torch.nn.Linear(embedding_dim, hidden_dim)
78
+ self.W2 = torch.nn.Linear(hidden_dim, embedding_dim)
79
+
80
+ def reset_parameters(self) -> None:
81
+ for _, param in self.named_parameters():
82
+ with contextlib.suppress(ValueError):
83
+ torch.nn.init.xavier_normal_(param.data)
84
+
85
+ def forward(
86
+ self,
87
+ input_embeddings: torch.Tensor,
88
+ ) -> torch.Tensor:
89
+ """
90
+ Forward pass for SwiGLU.
91
+
92
+ :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
93
+
94
+ :returns: Output tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
95
+ """
96
+ # Apply the gates
97
+ activation = torch.nn.functional.silu(self.WG(input_embeddings)) # Activation part
98
+ linear = self.W1(input_embeddings) # Linear part
99
+ return self.W2(activation * linear) # Element-wise multiplication and projection
100
+
101
+
102
class SwiGLUEncoder(torch.nn.Module):
    """
    MLP block consists of SwiGLU Feed-Forward network followed by a RMSNorm layer with skip connection.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int) -> None:
        """
        :param embedding_dim: Dimension of the input features.
        :param hidden_dim: Dimension of the hidden layer inside each SwiGLU block.
        """
        super().__init__()
        self.sw1 = SwiGLU(embedding_dim, hidden_dim)
        self.norm1 = torch.nn.RMSNorm(embedding_dim)
        self.sw2 = SwiGLU(embedding_dim, hidden_dim)
        self.norm2 = torch.nn.RMSNorm(embedding_dim)

    def reset_parameters(self) -> None:
        """Reset parameters of both SwiGLU blocks and both RMSNorm layers."""
        # Order matters for RNG reproducibility: SwiGLU blocks first, norms after.
        for layer in (self.sw1, self.sw2, self.norm1, self.norm2):
            layer.reset_parameters()

    def forward(
        self,
        feature_tensors: TensorMap,  # noqa: ARG002
        input_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        forward(input_embeddings)

        :param feature_tensors: Unused; accepted to satisfy the shared encoder interface.
        :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
        :returns: Output tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
        """
        # Two pre-residual SwiGLU blocks, each followed by RMSNorm over the skip sum.
        first = self.norm1(input_embeddings + self.sw1(input_embeddings))
        return self.norm2(first + self.sw2(first))
replay/nn/head.py ADDED
@@ -0,0 +1,49 @@
1
+ import torch
2
+
3
+
4
class EmbeddingTyingHead(torch.nn.Module):
    """
    The model head for calculating the output logits as a dot product
    between the model hidden state and the item embeddings.
    The module supports both 2-d and 3-d tensors for the hidden state and the item embeddings.

    As a result of the work, the scores for each item will be obtained.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        hidden_states: torch.Tensor,
        item_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param hidden_states: hidden state of shape
            ``(batch_size, embedding_dim)`` or ``(batch_size, sequence_length, embedding_dim)``.
        :param item_embeddings: item embeddings of shape
            ``(num_items, embedding_dim)`` or ``(batch_size, num_items, embedding_dim)``.
        :return: logits of shape ``(batch_size, num_items)``
            or ``(batch_size, sequence_length, num_items)``.
        """
        embeddings_rank = item_embeddings.dim()

        if embeddings_rank == 2:
            # Shared catalogue of items:
            # [B, *, E] x [E, I] -> [B, *, I]
            return hidden_states.matmul(item_embeddings.transpose(-1, -2).contiguous())

        if embeddings_rank == 3 and hidden_states.dim() == 2:
            # Per-query candidate lists:
            # [B, 1, E] x [B, E, I] -> [B, 1, I] -> [B, I]
            scores = hidden_states.unsqueeze(-2).matmul(item_embeddings.transpose(-1, -2).contiguous())
            return scores.squeeze(-2)

        # Position-wise pairing of equally shaped tensors:
        # [B*., 1, E] x [B*., E, 1] -> [B, *]
        flat_hidden = hidden_states.view(-1, 1, hidden_states.size(-1))
        flat_items = item_embeddings.view(-1, item_embeddings.size(-1), 1)
        return torch.bmm(flat_hidden, flat_items).view(hidden_states.size(0), *item_embeddings.shape[1:-1])
@@ -0,0 +1 @@
1
+ from .module import LightningModule
@@ -0,0 +1,9 @@
1
+ from .metrics_callback import ComputeMetricsCallback
2
+ from .predictions_callback import (
3
+ HiddenStatesCallback,
4
+ PandasTopItemsCallback,
5
+ PolarsTopItemsCallback,
6
+ SparkTopItemsCallback,
7
+ TopItemsCallbackBase,
8
+ TorchTopItemsCallback,
9
+ )
@@ -0,0 +1,183 @@
1
+ from typing import Any, Optional
2
+
3
+ import lightning
4
+ import torch
5
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
6
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only
7
+
8
+ from replay.metrics.torch_metrics_builder import (
9
+ MetricName,
10
+ TorchMetricsBuilder,
11
+ metrics_to_df,
12
+ )
13
+ from replay.nn.lightning import LightningModule
14
+ from replay.nn.lightning.postprocessor import PostprocessorBase
15
+ from replay.nn.output import InferenceOutput
16
+
17
+
18
class ComputeMetricsCallback(lightning.Callback):
    """
    Callback for validation and testing stages.

    If multiple validation/testing dataloaders are used,
    the suffix of the metric name will contain the serial number of the dataloader.

    For the correct calculation of metrics inside the callback,
    the batch must contain the ``ground_truth_column`` key - the padding value of this tensor can be any,
    the main condition is that the padding value does not overlap with the existing item ID values.
    For example, these can be negative values.

    To calculate the ``coverage`` and ``novelty`` metrics, the batch must additionally contain the ``train_column`` key.
    The padding value of this tensor can be any, the main condition is that the padding value does not overlap
    with the existing item ID values. For example, these can be negative values.
    """

    def __init__(
        self,
        metrics: Optional[list[MetricName]] = None,
        ks: Optional[list[int]] = None,
        postprocessors: Optional[list[PostprocessorBase]] = None,
        item_count: Optional[int] = None,
        ground_truth_column: str = "ground_truth",
        train_column: str = "train",
    ):
        """
        :param metrics: Sequence of metrics to calculate.\n
            Default: ``None``. This means that the default metrics will be used - ``Map``, ``NDCG``, ``Recall``.
        :param ks: highest k scores in ranking.\n
            Default: ``None``. This means that the default ``ks`` will be ``[1, 5, 10, 20]``.
        :param postprocessors: A list of postprocessors for modifying logits from the model.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        :param item_count: the total number of items in the dataset, required only for ``Coverage`` calculations.
            Default: ``None``.
        :param ground_truth_column: Name of key in batch that contains ground truth items.
        :param train_column: Name of key in batch that contains items on which the model is trained.
        """
        self._metrics = metrics
        self._ks = ks
        self._item_count = item_count
        # One builder and one size entry per dataloader; (re)built at epoch start.
        self._metrics_builders: list[TorchMetricsBuilder] = []
        self._dataloaders_size: list[int] = []
        self._postprocessors: list[PostprocessorBase] = postprocessors or []
        self._ground_truth_column = ground_truth_column
        self._train_column = train_column

    def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
        # Return the number of batches per dataloader. CombinedLoader exposes the
        # underlying loaders via `.flattened`; a single loader yields one entry.
        if isinstance(dataloaders, CombinedLoader):
            return [len(dataloader) for dataloader in dataloaders.flattened]  # pragma: no cover
        return [len(dataloaders)]

    def on_validation_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        # Rebuild one fresh metrics builder per validation dataloader.
        self._dataloaders_size = self._get_dataloaders_size(trainer.val_dataloaders)
        self._metrics_builders = [
            TorchMetricsBuilder(self._metrics, self._ks, self._item_count) for _ in self._dataloaders_size
        ]
        for builder in self._metrics_builders:
            builder.reset()

    def on_test_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        # Same setup as validation, but over the test dataloaders.
        self._dataloaders_size = self._get_dataloaders_size(trainer.test_dataloaders)
        self._metrics_builders = [
            TorchMetricsBuilder(self._metrics, self._ks, self._item_count) for _ in self._dataloaders_size
        ]
        for builder in self._metrics_builders:
            builder.reset()

    def _apply_postproccesors(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        # Run the postprocessor chain in order; each step may replace the logits.
        for postprocessor in self._postprocessors:
            logits = postprocessor.on_validation(batch, logits)
        return logits

    def on_validation_batch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._batch_end(
            trainer,
            pl_module,
            outputs,
            batch,
            batch_idx,
            dataloader_idx,
        )

    def on_test_batch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:  # pragma: no cover
        self._batch_end(
            trainer,
            pl_module,
            outputs,
            batch,
            batch_idx,
            dataloader_idx,
        )

    def _batch_end(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        # Shared handler for validation and test batches: postprocess logits,
        # take the top-k predictions and feed them to this dataloader's builder.
        seen_scores = self._apply_postproccesors(batch, outputs["logits"])
        sampled_items = torch.topk(seen_scores, k=self._metrics_builders[dataloader_idx].max_k, dim=1).indices
        self._metrics_builders[dataloader_idx].add_prediction(
            sampled_items, batch[self._ground_truth_column], batch.get(self._train_column)
        )

        # On the final batch of this dataloader, log the aggregated metrics once.
        if batch_idx + 1 == self._dataloaders_size[dataloader_idx]:
            pl_module.log_dict(
                self._metrics_builders[dataloader_idx].get_metrics(),
                on_epoch=True,
                sync_dist=True,
                add_dataloader_idx=True,
            )

    def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
        self._epoch_end(trainer, pl_module)

    def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:  # pragma: no cover
        self._epoch_end(trainer, pl_module)

    def _epoch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        # Print the metrics table on rank zero only to avoid duplicate output
        # in distributed runs.
        @rank_zero_only
        def print_metrics() -> None:
            metrics = {}
            for name, value in trainer.logged_metrics.items():
                # Keep only top-k metric entries (names like "NDCG@10");
                # presumably other logged values such as losses are skipped.
                if "@" in name:
                    metrics[name] = value.item()

            if metrics:
                metrics_df = metrics_to_df(metrics)

                print(metrics_df)  # noqa: T201
                print()  # noqa: T201

        print_metrics()
@@ -0,0 +1,314 @@
1
+ import abc
2
+ from typing import Generic, Optional, TypeVar
3
+
4
+ import lightning
5
+ import torch
6
+
7
+ from replay.nn.lightning import LightningModule
8
+ from replay.nn.lightning.postprocessor import PostprocessorBase
9
+ from replay.nn.output import InferenceOutput
10
+ from replay.utils import (
11
+ PYSPARK_AVAILABLE,
12
+ MissingImport,
13
+ PandasDataFrame,
14
+ PolarsDataFrame,
15
+ SparkDataFrame,
16
+ )
17
+
18
# PySpark is optional: the Spark-backed callback is only usable when pyspark is
# installed. Otherwise SparkSession degrades to a MissingImport placeholder so
# that merely importing this module does not fail.
if PYSPARK_AVAILABLE:  # pragma: no cover
    import pyspark.sql.functions as sf
    from pyspark.sql import SparkSession
    from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
else:  # pragma: no cover
    SparkSession = MissingImport


# Result type produced by a concrete TopItemsCallbackBase subclass
# (e.g. a pandas/polars/Spark DataFrame or a tuple of tensors).
_T = TypeVar("_T")
27
+
28
+
29
class TopItemsCallbackBase(lightning.Callback, Generic[_T]):
    """
    The base class for a callback that records the result at the inference stage via ``LightningModule``.
    The result consists of top K the highest logit values, IDs of these top K logit values
    and corresponding query ids (encoded IDs of users named ``query_id``).

    For the callback to work correctly, the batch is expected to contain the ``query_id`` key.
    """

    def __init__(
        self,
        top_k: int,
        query_column: str,
        item_column: str,
        rating_column: str = "rating",
        postprocessors: Optional[list[PostprocessorBase]] = None,
    ) -> None:
        """
        :param top_k: Take the ``top_k`` IDs with the highest logit values.
        :param query_column: The name of the query column in the resulting dataframe.
        :param item_column: The name of the item column in the resulting dataframe.
        :param rating_column: The name of the rating column in the resulting dataframe.
            This column will contain the ``top_k`` items with the highest logit values.
        :param postprocessors: A list of postprocessors for modifying logits from the model
            before sorting and taking top K ones.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        """
        super().__init__()
        self.query_column = query_column
        self.item_column = item_column
        self.rating_column = rating_column
        self._top_k = top_k
        self._postprocessors: list[PostprocessorBase] = postprocessors or []
        # Per-batch accumulators; concatenated once in get_result().
        self._query_batches: list[torch.Tensor] = []
        self._item_batches: list[torch.Tensor] = []
        self._item_scores: list[torch.Tensor] = []

    def on_predict_epoch_start(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,
    ) -> None:
        # Start each prediction epoch with empty accumulators.
        self._query_batches.clear()
        self._item_batches.clear()
        self._item_scores.clear()

        # Propagate the module's candidate subset to every postprocessor.
        candidates = pl_module.candidates_to_score
        for postprocessor in self._postprocessors:
            postprocessor.candidates = candidates

    def on_predict_batch_end(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int = 0,  # noqa: ARG002
    ) -> None:
        logits = self._apply_postproccesors(batch, outputs["logits"])
        top_scores, top_item_ids = torch.topk(logits, k=self._top_k, dim=1)
        if pl_module.candidates_to_score is not None:
            # topk returns positions inside the candidate subset;
            # map them back to the real item IDs.
            top_item_ids = torch.take(pl_module.candidates_to_score, top_item_ids)

        self._query_batches.append(batch["query_id"])
        self._item_batches.append(top_item_ids)
        self._item_scores.append(top_scores)

    def get_result(self) -> _T:
        """
        :returns: prediction result
        """
        # Concatenate all accumulated batches and convert to the subclass's
        # output format (pandas/polars/Spark DataFrame or tensors).
        prediction = self._ids_to_result(
            torch.cat(self._query_batches),
            torch.cat(self._item_batches),
            torch.cat(self._item_scores),
        )
        return prediction

    def _apply_postproccesors(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        # Run the postprocessor chain in order; each step may replace the logits.
        for postprocessor in self._postprocessors:
            logits = postprocessor.on_prediction(batch, logits)
        return logits

    @abc.abstractmethod
    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> _T:  # pragma: no cover
        pass
122
+
123
+
124
class PandasTopItemsCallback(TopItemsCallbackBase[PandasDataFrame]):
    """
    A callback that records the result of the model's forward function at the inference stage in a Pandas Dataframe.
    """

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> PandasDataFrame:
        # One row per query; the item/rating cells hold top-k arrays that are
        # then exploded into one row per (query, item) pair.
        columns = {
            self.query_column: query_ids.flatten().cpu().numpy(),
            self.item_column: list(item_ids.cpu().numpy()),
            self.rating_column: list(item_scores.cpu().numpy()),
        }
        return PandasDataFrame(columns).explode([self.item_column, self.rating_column])
143
+
144
+
145
class PolarsTopItemsCallback(TopItemsCallbackBase[PolarsDataFrame]):
    """
    A callback that records the result of the model's forward function at the inference stage in a Polars Dataframe.
    """

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> PolarsDataFrame:
        # One row per query; the item/rating cells hold top-k arrays that are
        # then exploded into one row per (query, item) pair.
        columns = {
            self.query_column: query_ids.flatten().cpu().numpy(),
            self.item_column: list(item_ids.cpu().numpy()),
            self.rating_column: list(item_scores.cpu().numpy()),
        }
        return PolarsDataFrame(columns).explode([self.item_column, self.rating_column])
164
+
165
+
166
class SparkTopItemsCallback(TopItemsCallbackBase[SparkDataFrame]):
    """
    A callback that records the result of the model's forward function at the inference stage in a Spark Dataframe.
    """

    def __init__(
        self,
        top_k: int,
        query_column: str,
        item_column: str,
        rating_column: str,
        spark_session: SparkSession,
        postprocessors: Optional[list[PostprocessorBase]] = None,
    ) -> None:
        """
        :param top_k: Take the ``top_k`` IDs with the highest logit values.
        :param query_column: The name of the query column in the resulting dataframe.
        :param item_column: The name of the item column in the resulting dataframe.
        :param rating_column: The name of the rating column in the resulting dataframe.
            This column will contain the ``top_k`` items with the highest logit values.
        :param spark_session: Spark session. Required to create a Spark DataFrame.
        :param postprocessors: A list of postprocessors for modifying logits from the model
            before sorting and taking top K ones.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        """
        super().__init__(
            top_k=top_k,
            query_column=query_column,
            item_column=item_column,
            rating_column=rating_column,
            postprocessors=postprocessors,
        )
        self.spark_session = spark_session

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> SparkDataFrame:
        # Declare the row layout: one scalar query ID plus array columns
        # holding the top-k items and their scores.
        schema = (
            StructType()
            .add(self.query_column, IntegerType(), False)
            .add(self.item_column, ArrayType(IntegerType()), False)
            .add(self.rating_column, ArrayType(DoubleType()), False)
        )
        rows = list(
            zip(
                query_ids.flatten().cpu().numpy().tolist(),
                item_ids.cpu().numpy().tolist(),
                item_scores.cpu().numpy().tolist(),
            )
        )
        frame = self.spark_session.createDataFrame(data=rows, schema=schema)
        # Zip the parallel item/score arrays and explode into one row per pair.
        zipped_pairs = sf.explode(sf.arrays_zip(self.item_column, self.rating_column))
        frame = frame.withColumn("exploded_columns", zipped_pairs)
        return frame.select(
            self.query_column,
            f"exploded_columns.{self.item_column}",
            f"exploded_columns.{self.rating_column}",
        )
235
+
236
+
237
class TorchTopItemsCallback(TopItemsCallbackBase[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
    """
    A callback that records the result of the model's forward function at the inference stage in a PyTorch Tensors.
    """

    def __init__(
        self,
        top_k: int,
        postprocessors: Optional[list[PostprocessorBase]] = None,
    ) -> None:
        """
        :param top_k: Take the ``top_k`` IDs with the highest logit values.
        :param postprocessors: A list of postprocessors for modifying logits from the model
            before sorting and taking top K.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        """
        # Column names are fixed for the tensor output; they only label the
        # positions in the returned (query, item, rating) tuple.
        super().__init__(
            top_k=top_k,
            query_column="query_id",
            item_column="item_id",
            rating_column="rating",
            postprocessors=postprocessors,
        )

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
        queries = query_ids.flatten().cpu().long()
        items = item_ids.cpu().long()
        scores = item_scores.cpu()
        return queries, items, scores
273
+
274
+
275
class HiddenStatesCallback(lightning.Callback):
    """
    A callback for getting any hidden state from the model.

    When applying this callback,
    it is expected that the result of the model's forward function contains the ``hidden_states`` key.
    """

    def __init__(self, hidden_state_index: int):
        """
        :param hidden_state_index: It is expected that the result of the model's forward function
            contains the ``hidden_states`` key. ``hidden_states`` key contains Tuple of PyTorch Tensors.
            Therefore, to get a specific hidden state, you need to submit an index from this tuple.
        """
        self._hidden_state_index = hidden_state_index
        # One CPU tensor per prediction batch; concatenated in get_result().
        self._embeddings_per_batch: list[torch.Tensor] = []

    def on_predict_epoch_start(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        # Drop results accumulated by any previous prediction run.
        self._embeddings_per_batch.clear()

    def on_predict_batch_end(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
        outputs: InferenceOutput,
        batch: dict,  # noqa: ARG002
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int = 0,  # noqa: ARG002
    ) -> None:
        # Detach and move to CPU so accumulated states do not hold GPU memory
        # or the autograd graph.
        selected = outputs["hidden_states"][self._hidden_state_index]
        self._embeddings_per_batch.append(selected.detach().cpu())

    def get_result(self):
        """
        :returns: Hidden states through all batches.
        """
        return torch.cat(self._embeddings_per_batch)