replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/ffn.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from replay.data.nn.schema import TensorMap
|
|
7
|
+
|
|
8
|
+
from .utils import create_activation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PointWiseFeedForward(torch.nn.Module):
    """
    Point wise feed forward network layer.

    Source paper: https://arxiv.org/pdf/1808.09781.pdf
    """

    def __init__(
        self,
        embedding_dim: int,
        dropout: float,
        activation: Literal["relu", "gelu"] = "gelu",
    ) -> None:
        """
        :param embedding_dim: Dimension of the input features.
        :param dropout: probability of an element to be zeroed.
        :param activation: the name of the activation function.
            Default: ``"gelu"``.
        """
        super().__init__()

        # kernel_size=1 convolutions act as per-position linear layers
        # over the embedding (channel) dimension.
        self.conv1 = torch.nn.Conv1d(embedding_dim, embedding_dim, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout)
        self.activation = create_activation(activation)
        self.conv2 = torch.nn.Conv1d(embedding_dim, embedding_dim, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout)

    def reset_parameters(self) -> None:
        """Re-initialize all weights with Xavier normal initialization."""
        for _, param in self.named_parameters():
            # xavier_normal_ raises ValueError for tensors with fewer than
            # 2 dimensions (biases); those are intentionally left as-is.
            with contextlib.suppress(ValueError):
                torch.nn.init.xavier_normal_(param.data)

    def forward(self, input_embeddings: torch.Tensor) -> torch.Tensor:
        """
        :param input_embeddings: Query feature tensor of shape
            ``(..., sequence_length, embedding_dim)``.

        :returns: Output tensor of the same shape with a residual connection applied.
        """
        # Conv1d expects the channel (embedding) dim before the sequence dim.
        x: torch.Tensor = self.conv1(input_embeddings.transpose(-1, -2))
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = x.transpose(-1, -2)
        # Out-of-place residual add: `x` is a transposed view here, so an
        # in-place `+=` could corrupt tensors needed for autograd.
        x = input_embeddings + x

        return x
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class SwiGLU(torch.nn.Module):
    """
    SwiGLU Activation Function.
    Combines the Swish activation with Gated Linear Units.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int):
        """
        :param embedding_dim: Dimension of the input features.
        :param hidden_dim: Dimension of hidden layer.
            According to the original source,
            it is recommended to set the size of the hidden layer as :math:`2 \\cdot \\text{embedding_dim}`.
        """
        super().__init__()
        # Two parallel projections into the hidden space: one is gated through
        # SiLU (Swish), the other stays linear; their element-wise product is
        # projected back to the embedding dimension.
        self.WG = torch.nn.Linear(embedding_dim, hidden_dim)
        self.W1 = torch.nn.Linear(embedding_dim, hidden_dim)
        self.W2 = torch.nn.Linear(hidden_dim, embedding_dim)

    def reset_parameters(self) -> None:
        """Re-initialize every weight with Xavier normal, skipping 1-d tensors (biases)."""
        for parameter in self.parameters():
            with contextlib.suppress(ValueError):
                torch.nn.init.xavier_normal_(parameter.data)

    def forward(
        self,
        input_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass for SwiGLU.

        :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.

        :returns: Output tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
        """
        gate = torch.nn.functional.silu(self.WG(input_embeddings))
        value = self.W1(input_embeddings)
        # Gate the linear branch element-wise, then project back down.
        return self.W2(gate * value)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class SwiGLUEncoder(torch.nn.Module):
    """
    MLP block consists of SwiGLU Feed-Forward network followed by a RMSNorm layer with skip connection.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int) -> None:
        """
        :param embedding_dim: Dimension of the input features.
        :param hidden_dim: Dimension of the hidden layer inside each SwiGLU sub-block.
        """
        super().__init__()
        self.sw1 = SwiGLU(embedding_dim, hidden_dim)
        self.norm1 = torch.nn.RMSNorm(embedding_dim)
        self.sw2 = SwiGLU(embedding_dim, hidden_dim)
        self.norm2 = torch.nn.RMSNorm(embedding_dim)

    def reset_parameters(self) -> None:
        """Re-initialize the parameters of every sub-module."""
        for submodule in (self.sw1, self.sw2, self.norm1, self.norm2):
            submodule.reset_parameters()

    def forward(
        self,
        feature_tensors: TensorMap,  # noqa: ARG002
        input_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        forward(input_embeddings)
        :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
        :returns: Output tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
        """
        # Two pre-residual SwiGLU sub-blocks, each normalized after the skip connection.
        hidden = self.norm1(self.sw1(input_embeddings) + input_embeddings)
        return self.norm2(self.sw2(hidden) + hidden)
|
replay/nn/head.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class EmbeddingTyingHead(torch.nn.Module):
    """
    The model head for calculating the output logits as a dot product
    between the model hidden state and the item embeddings.
    The module supports both 2-d and 3-d tensors for the hidden state and the item embeddings.

    As a result of the work, the scores for each item will be obtained.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        hidden_states: torch.Tensor,
        item_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param hidden_states: hidden state of shape
            ``(batch_size, embedding_dim)`` or ``(batch_size, sequence_length, embedding_dim)``.
        :param item_embeddings: item embeddings of shape
            ``(num_items, embedding_dim)`` or ``(batch_size, num_items, embedding_dim)``.
        :return: logits of shape ``(batch_size, num_items)``
            or ``(batch_size, sequence_length, num_items)``.
        """
        items_rank = item_embeddings.dim()
        if items_rank == 2:
            # Shared item catalog: [B, *, E] x [E, I] -> [B, *, I]
            catalog = item_embeddings.transpose(-1, -2).contiguous()
            return hidden_states.matmul(catalog)
        if items_rank == 3 and hidden_states.dim() == 2:
            # Per-query candidates: [B, 1, E] x [B, E, I] -> [B, 1, I] -> [B, I]
            candidates = item_embeddings.transpose(-1, -2).contiguous()
            return hidden_states.unsqueeze(-2).matmul(candidates).squeeze(-2)
        # Paired scoring — dot product between matching positions:
        # [*, 1, E] x [*, E, 1] -> [B, *]
        embedding_dim = hidden_states.size(-1)
        paired = torch.bmm(
            hidden_states.view(-1, 1, embedding_dim),
            item_embeddings.view(-1, item_embeddings.size(-1), 1),
        )
        return paired.view(hidden_states.size(0), *item_embeddings.shape[1:-1])
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .module import LightningModule
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
import lightning
|
|
4
|
+
import torch
|
|
5
|
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
|
6
|
+
from lightning.pytorch.utilities.rank_zero import rank_zero_only
|
|
7
|
+
|
|
8
|
+
from replay.metrics.torch_metrics_builder import (
|
|
9
|
+
MetricName,
|
|
10
|
+
TorchMetricsBuilder,
|
|
11
|
+
metrics_to_df,
|
|
12
|
+
)
|
|
13
|
+
from replay.nn.lightning import LightningModule
|
|
14
|
+
from replay.nn.lightning.postprocessor import PostprocessorBase
|
|
15
|
+
from replay.nn.output import InferenceOutput
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ComputeMetricsCallback(lightning.Callback):
    """
    Callback for validation and testing stages.

    If multiple validation/testing dataloaders are used,
    the suffix of the metric name will contain the serial number of the dataloader.

    For the correct calculation of metrics inside the callback,
    the batch must contain the ``ground_truth_column`` key - the padding value of this tensor can be any,
    the main condition is that the padding value does not overlap with the existing item ID values.
    For example, these can be negative values.

    To calculate the ``coverage`` and ``novelty`` metrics, the batch must additionally contain the ``train_column`` key.
    The padding value of this tensor can be any, the main condition is that the padding value does not overlap
    with the existing item ID values. For example, these can be negative values.
    """

    def __init__(
        self,
        metrics: Optional[list[MetricName]] = None,
        ks: Optional[list[int]] = None,
        postprocessors: Optional[list[PostprocessorBase]] = None,
        item_count: Optional[int] = None,
        ground_truth_column: str = "ground_truth",
        train_column: str = "train",
    ):
        """
        :param metrics: Sequence of metrics to calculate.\n
            Default: ``None``. This means that the default metrics will be used - ``Map``, ``NDCG``, ``Recall``.
        :param ks: highest k scores in ranking.\n
            Default: ``None``. This means that the default ``ks`` will be ``[1, 5, 10, 20]``.
        :param postprocessors: A list of postprocessors for modifying logits from the model.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        :param item_count: the total number of items in the dataset, required only for ``Coverage`` calculations.
            Default: ``None``.
        :param ground_truth_column: Name of key in batch that contains ground truth items.
        :param train_column: Name of key in batch that contains items on which the model is trained.
        """
        # Fix: initialize the lightning.Callback base class (was skipped),
        # consistent with the other callbacks in this package.
        super().__init__()
        self._metrics = metrics
        self._ks = ks
        self._item_count = item_count
        self._metrics_builders: list[TorchMetricsBuilder] = []
        self._dataloaders_size: list[int] = []
        self._postprocessors: list[PostprocessorBase] = postprocessors or []
        self._ground_truth_column = ground_truth_column
        self._train_column = train_column

    def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
        # A CombinedLoader wraps several dataloaders; report each size separately.
        if isinstance(dataloaders, CombinedLoader):
            return [len(dataloader) for dataloader in dataloaders.flattened]  # pragma: no cover
        return [len(dataloaders)]

    def _init_builders(self, dataloaders: Optional[Any]) -> None:
        # Shared setup for validation and test epochs:
        # one fresh metrics builder per dataloader.
        self._dataloaders_size = self._get_dataloaders_size(dataloaders)
        self._metrics_builders = [
            TorchMetricsBuilder(self._metrics, self._ks, self._item_count) for _ in self._dataloaders_size
        ]
        for builder in self._metrics_builders:
            builder.reset()

    def on_validation_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        self._init_builders(trainer.val_dataloaders)

    def on_test_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        self._init_builders(trainer.test_dataloaders)

    def _apply_postprocessors(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        # Each postprocessor may rewrite logits (e.g. mask seen items with -inf).
        for postprocessor in self._postprocessors:
            logits = postprocessor.on_validation(batch, logits)
        return logits

    def on_validation_batch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._batch_end(
            trainer,
            pl_module,
            outputs,
            batch,
            batch_idx,
            dataloader_idx,
        )

    def on_test_batch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:  # pragma: no cover
        self._batch_end(
            trainer,
            pl_module,
            outputs,
            batch,
            batch_idx,
            dataloader_idx,
        )

    def _batch_end(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        seen_scores = self._apply_postprocessors(batch, outputs["logits"])
        sampled_items = torch.topk(seen_scores, k=self._metrics_builders[dataloader_idx].max_k, dim=1).indices
        self._metrics_builders[dataloader_idx].add_prediction(
            sampled_items, batch[self._ground_truth_column], batch.get(self._train_column)
        )

        # Log once per dataloader, on its last batch.
        if batch_idx + 1 == self._dataloaders_size[dataloader_idx]:
            pl_module.log_dict(
                self._metrics_builders[dataloader_idx].get_metrics(),
                on_epoch=True,
                sync_dist=True,
                add_dataloader_idx=True,
            )

    def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:
        self._epoch_end(trainer, pl_module)

    def on_test_epoch_end(self, trainer: lightning.Trainer, pl_module: LightningModule) -> None:  # pragma: no cover
        self._epoch_end(trainer, pl_module)

    def _epoch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        # rank_zero_only ensures the table is printed once in distributed runs.
        @rank_zero_only
        def print_metrics() -> None:
            # Only ranking metrics are logged with "@" in their name (e.g. "ndcg@10").
            metrics = {name: value.item() for name, value in trainer.logged_metrics.items() if "@" in name}

            if metrics:
                metrics_df = metrics_to_df(metrics)

                print(metrics_df)  # noqa: T201
                print()  # noqa: T201

        print_metrics()
|
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Generic, Optional, TypeVar
|
|
3
|
+
|
|
4
|
+
import lightning
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from replay.nn.lightning import LightningModule
|
|
8
|
+
from replay.nn.lightning.postprocessor import PostprocessorBase
|
|
9
|
+
from replay.nn.output import InferenceOutput
|
|
10
|
+
from replay.utils import (
|
|
11
|
+
PYSPARK_AVAILABLE,
|
|
12
|
+
MissingImport,
|
|
13
|
+
PandasDataFrame,
|
|
14
|
+
PolarsDataFrame,
|
|
15
|
+
SparkDataFrame,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if PYSPARK_AVAILABLE: # pragma: no cover
|
|
19
|
+
import pyspark.sql.functions as sf
|
|
20
|
+
from pyspark.sql import SparkSession
|
|
21
|
+
from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
|
|
22
|
+
else: # pragma: no cover
|
|
23
|
+
SparkSession = MissingImport
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
_T = TypeVar("_T")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TopItemsCallbackBase(lightning.Callback, Generic[_T]):
    """
    The base class for a callback that records the result at the inference stage via ``LightningModule``.
    The result consists of top K the highest logit values, IDs of these top K logit values
    and corresponding query ids (encoded IDs of users named ``query_id``).

    For the callback to work correctly, the batch is expected to contain the ``query_id`` key.
    """

    def __init__(
        self,
        top_k: int,
        query_column: str,
        item_column: str,
        rating_column: str = "rating",
        postprocessors: Optional[list[PostprocessorBase]] = None,
    ) -> None:
        """
        :param top_k: Take the ``top_k`` IDs with the highest logit values.
        :param query_column: The name of the query column in the resulting dataframe.
        :param item_column: The name of the item column in the resulting dataframe.
        :param rating_column: The name of the rating column in the resulting dataframe.
            This column will contain the ``top_k`` items with the highest logit values.
        :param postprocessors: A list of postprocessors for modifying logits from the model
            before sorting and taking top K ones.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        """
        super().__init__()
        self.query_column = query_column
        self.item_column = item_column
        self.rating_column = rating_column
        self._top_k = top_k
        self._postprocessors: list[PostprocessorBase] = postprocessors or []
        # Per-batch accumulators, concatenated lazily in get_result().
        self._query_batches: list[torch.Tensor] = []
        self._item_batches: list[torch.Tensor] = []
        self._item_scores: list[torch.Tensor] = []

    def on_predict_epoch_start(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,
    ) -> None:
        # Drop any results from a previous prediction run.
        self._query_batches.clear()
        self._item_batches.clear()
        self._item_scores.clear()

        # Make the candidate subset (if any) visible to every postprocessor.
        candidates = pl_module.candidates_to_score
        for postprocessor in self._postprocessors:
            postprocessor.candidates = candidates

    def on_predict_batch_end(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,
        outputs: InferenceOutput,
        batch: dict,
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int = 0,  # noqa: ARG002
    ) -> None:
        scores = self._apply_postprocessors(batch, outputs["logits"])
        top_scores, top_item_ids = torch.topk(scores, k=self._top_k, dim=1)
        if pl_module.candidates_to_score is not None:
            # Map positions within the candidate subset back to real item IDs.
            top_item_ids = torch.take(pl_module.candidates_to_score, top_item_ids)

        self._query_batches.append(batch["query_id"])
        self._item_batches.append(top_item_ids)
        self._item_scores.append(top_scores)

    def get_result(self) -> _T:
        """
        :returns: prediction result
        """
        return self._ids_to_result(
            torch.cat(self._query_batches),
            torch.cat(self._item_batches),
            torch.cat(self._item_scores),
        )

    def _apply_postprocessors(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        for postprocessor in self._postprocessors:
            logits = postprocessor.on_prediction(batch, logits)
        return logits

    @abc.abstractmethod
    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> _T:  # pragma: no cover
        pass
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class PandasTopItemsCallback(TopItemsCallbackBase[PandasDataFrame]):
    """
    A callback that records the result of the model's forward function at the inference stage in a Pandas Dataframe.
    """

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> PandasDataFrame:
        # One row per query with list-valued item/rating columns,
        # then exploded into long format.
        columns = {
            self.query_column: query_ids.flatten().cpu().numpy(),
            self.item_column: list(item_ids.cpu().numpy()),
            self.rating_column: list(item_scores.cpu().numpy()),
        }
        frame = PandasDataFrame(columns)
        return frame.explode([self.item_column, self.rating_column])
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class PolarsTopItemsCallback(TopItemsCallbackBase[PolarsDataFrame]):
    """
    A callback that records the result of the model's forward function at the inference stage in a Polars Dataframe.
    """

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> PolarsDataFrame:
        # One row per query with list-valued item/rating columns,
        # then exploded into long format.
        columns = {
            self.query_column: query_ids.flatten().cpu().numpy(),
            self.item_column: list(item_ids.cpu().numpy()),
            self.rating_column: list(item_scores.cpu().numpy()),
        }
        frame = PolarsDataFrame(columns)
        return frame.explode([self.item_column, self.rating_column])
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class SparkTopItemsCallback(TopItemsCallbackBase[SparkDataFrame]):
    """
    A callback that records the result of the model's forward function at the inference stage in a Spark Dataframe.
    """

    def __init__(
        self,
        top_k: int,
        query_column: str,
        item_column: str,
        rating_column: str,
        spark_session: SparkSession,
        postprocessors: Optional[list[PostprocessorBase]] = None,
    ) -> None:
        """
        :param top_k: Take the ``top_k`` IDs with the highest logit values.
        :param query_column: The name of the query column in the resulting dataframe.
        :param item_column: The name of the item column in the resulting dataframe.
        :param rating_column: The name of the rating column in the resulting dataframe.
            This column will contain the ``top_k`` items with the highest logit values.
        :param spark_session: Spark session. Required to create a Spark DataFrame.
        :param postprocessors: A list of postprocessors for modifying logits from the model
            before sorting and taking top K ones.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        """
        super().__init__(
            top_k=top_k,
            query_column=query_column,
            item_column=item_column,
            rating_column=rating_column,
            postprocessors=postprocessors,
        )
        self.spark_session = spark_session

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> SparkDataFrame:
        # Wide schema: one row per query with array-valued item/rating columns.
        schema = (
            StructType()
            .add(self.query_column, IntegerType(), False)
            .add(self.item_column, ArrayType(IntegerType()), False)
            .add(self.rating_column, ArrayType(DoubleType()), False)
        )
        rows = list(
            zip(
                query_ids.flatten().cpu().numpy().tolist(),
                item_ids.cpu().numpy().tolist(),
                item_scores.cpu().numpy().tolist(),
            )
        )
        wide = self.spark_session.createDataFrame(data=rows, schema=schema)
        # Zip the two arrays element-wise and explode into long format.
        zipped = sf.explode(sf.arrays_zip(self.item_column, self.rating_column))
        prediction = wide.withColumn("exploded_columns", zipped).select(
            self.query_column,
            f"exploded_columns.{self.item_column}",
            f"exploded_columns.{self.rating_column}",
        )
        return prediction
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class TorchTopItemsCallback(TopItemsCallbackBase[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
    """
    A callback that records the result of the model's forward function at the inference stage in a PyTorch Tensors.
    """

    def __init__(
        self,
        top_k: int,
        postprocessors: Optional[list[PostprocessorBase]] = None,
    ) -> None:
        """
        :param top_k: Take the ``top_k`` IDs with the highest logit values.
        :param postprocessors: A list of postprocessors for modifying logits from the model
            before sorting and taking top K.
            For example, it can be a softmax operation to logits or set the ``-inf`` value for some IDs.
            Default: ``None``.
        """
        # Column names are fixed for the tensor-based result.
        super().__init__(
            top_k=top_k,
            query_column="query_id",
            item_column="item_id",
            rating_column="rating",
            postprocessors=postprocessors,
        )

    def _ids_to_result(
        self,
        query_ids: torch.Tensor,
        item_ids: torch.Tensor,
        item_scores: torch.Tensor,
    ) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
        queries = query_ids.flatten().cpu().long()
        items = item_ids.cpu().long()
        scores = item_scores.cpu()
        return (queries, items, scores)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class HiddenStatesCallback(lightning.Callback):
    """
    A callback for getting any hidden state from the model.

    When applying this callback,
    it is expected that the result of the model's forward function contains the ``hidden_states`` key.
    """

    def __init__(self, hidden_state_index: int):
        """
        :param hidden_state_index: It is expected that the result of the model's forward function
            contains the ``hidden_states`` key. ``hidden_states`` key contains Tuple of PyTorch Tensors.
            Therefore, to get a specific hidden state, you need to submit an index from this tuple.
        """
        # Fix: initialize the lightning.Callback base class (was skipped),
        # consistent with the other callbacks in this package.
        super().__init__()
        self._hidden_state_index = hidden_state_index
        self._embeddings_per_batch: list[torch.Tensor] = []

    def on_predict_epoch_start(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        # Drop any results from a previous prediction run.
        self._embeddings_per_batch.clear()

    def on_predict_batch_end(
        self,
        trainer: lightning.Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
        outputs: InferenceOutput,
        batch: dict,  # noqa: ARG002
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int = 0,  # noqa: ARG002
    ) -> None:
        # Detach and move to CPU immediately so GPU memory is not held across batches.
        self._embeddings_per_batch.append(outputs["hidden_states"][self._hidden_state_index].detach().cpu())

    def get_result(self) -> torch.Tensor:
        """
        :returns: Hidden states through all batches.
        """
        return torch.cat(self._embeddings_per_batch)
|