replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/output.py
ADDED
@@ -0,0 +1,37 @@
from typing import TypedDict

import torch
from typing_extensions import NotRequired


class TrainOutput(TypedDict):
    """
    Stores outputs from the model's training stage.

    :param loss: a tensor containing the calculated loss.\n
        It is important that the tensor carries a gradient so that backpropagation can be called from outside.
    :param hidden_states: Tuple of `torch.Tensor`.\n
        One for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer.\n
        Expected shape: ``(batch_size, sequence_length, hidden_size)``.
    """

    loss: torch.Tensor
    hidden_states: NotRequired[tuple[torch.Tensor, ...]]


class InferenceOutput(TypedDict):
    """
    Stores outputs from the model's inference stage.

    :param logits:
        Sequence of hidden-states at the output of the last layer of the model.\n
        Expected shape: ``(batch_size, sequence_length, hidden_size)``.
    :param hidden_states: Tuple of `torch.Tensor`
        (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer).\n
        Expected shape: ``(batch_size, sequence_length, hidden_size)``.
    """

    logits: torch.Tensor
    hidden_states: NotRequired[tuple[torch.Tensor, ...]]
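
A minimal usage sketch of the two containers defined above (not part of the package itself; the import path mirrors the file path and the tensor shapes are illustrative):

import torch

from replay.nn.output import InferenceOutput, TrainOutput

# TrainOutput: "hidden_states" is NotRequired and may be omitted.
loss = torch.tensor(1.23, requires_grad=True)
train_out: TrainOutput = {"loss": loss}
train_out["loss"].backward()  # the caller is expected to drive backpropagation

# InferenceOutput: logits plus an optional tuple of hidden states.
infer_out: InferenceOutput = {
    "logits": torch.randn(2, 10, 128),
    "hidden_states": (torch.randn(2, 10, 128),),
}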

replay/nn/sequential/sasrec/agg.py
ADDED
@@ -0,0 +1,53 @@
import contextlib

import torch

from replay.data.nn.schema import TensorMap
from replay.nn.agg import AggregatorProto


class PositionAwareAggregator(torch.nn.Module):
    """
    The layer for aggregating embeddings and adding positional encoding.
    """

    def __init__(
        self,
        embedding_aggregator: AggregatorProto,
        max_sequence_length: int,
        dropout: float,
    ) -> None:
        """
        :param embedding_aggregator: An object of a class that performs the logic of aggregating multiple embeddings.\n
            For example, it can be a ``sum``, a ``mean``, or a ``concatenation``.
        :param max_sequence_length: Max length of sequence.
        :param dropout: probability of an element to be zeroed.
        """
        super().__init__()
        self.embedding_aggregator = embedding_aggregator
        self.pe = torch.nn.Embedding(max_sequence_length, self.embedding_aggregator.embedding_dim)
        self.dropout = torch.nn.Dropout(p=dropout)

    def reset_parameters(self) -> None:
        self.embedding_aggregator.reset_parameters()
        for _, param in self.pe.named_parameters():
            with contextlib.suppress(ValueError):
                torch.nn.init.xavier_normal_(param.data)

    def forward(self, feature_tensors: TensorMap) -> torch.Tensor:
        """
        :param feature_tensors: a dictionary of tensors to pass into ``embedding_aggregator``.

        :returns: Aggregated embeddings with positional encoding.
        """
        seqs: torch.Tensor = self.embedding_aggregator(feature_tensors)
        assert seqs.dim() == 3
        batch_size, seq_len, embedding_dim = seqs.size()
        assert (
            seq_len <= self.pe.num_embeddings
        ), f"Sequence length = {seq_len} is greater than positional embedding num = {self.pe.num_embeddings}"

        seqs *= embedding_dim**0.5
        seqs += self.pe.weight[:seq_len].unsqueeze(0).repeat(batch_size, 1, 1)
        seqs = self.dropout(seqs)
        return seqs
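
A hedged usage sketch for PositionAwareAggregator. ToySumAggregator below is a made-up stand-in that only mimics the AggregatorProto surface the layer relies on (an ``embedding_dim`` attribute, ``reset_parameters()``, and a call over a dict of embeddings); the import of PositionAwareAggregator from ``replay.nn.sequential`` follows the package's own docstring example further below:

import torch

from replay.nn.sequential import PositionAwareAggregator


class ToySumAggregator(torch.nn.Module):
    # Hypothetical stand-in: sums per-feature embeddings elementwise.
    def __init__(self, embedding_dim: int) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim

    def reset_parameters(self) -> None:
        pass  # nothing to initialize in this toy aggregator

    def forward(self, feature_tensors: dict) -> torch.Tensor:
        return torch.stack(list(feature_tensors.values()), dim=0).sum(dim=0)


agg = PositionAwareAggregator(
    embedding_aggregator=ToySumAggregator(embedding_dim=64),
    max_sequence_length=100,
    dropout=0.1,
)
embeddings = {"item_id": torch.randn(8, 20, 64)}  # (batch, seq, emb)
out = agg(embeddings)  # scaled, position-encoded, dropout-regularized
assert out.shape == (8, 20, 64)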

replay/nn/sequential/sasrec/diff_transformer.py
ADDED
@@ -0,0 +1,125 @@
import math

import torch

from replay.data.nn import TensorMap
from replay.nn.attention import MultiHeadDifferentialAttention
from replay.nn.ffn import SwiGLU


class DiffTransformerBlock(torch.nn.Module):
    """
    Single Block of the DiffTransformer Architecture.
    Consists of Multi-Head Differential Attention followed by a SwiGLU Feed-Forward Network.

    Source paper: https://arxiv.org/pdf/2410.05258
    """

    def __init__(self, embedding_dim: int, num_heads: int, lambda_init: float):
        """
        :param embedding_dim: Total dimension of the model. Must be divisible by ``num_heads``.
        :param num_heads: Number of parallel attention heads.
        :param lambda_init: Initial value for lambda.
        """
        super().__init__()
        self.attn_norm = torch.nn.RMSNorm(embedding_dim)
        self.attn = MultiHeadDifferentialAttention(embedding_dim, num_heads, lambda_init, vdim=2 * embedding_dim)
        self.ff_norm = torch.nn.RMSNorm(embedding_dim)
        self.ff = SwiGLU(embedding_dim, 2 * embedding_dim)

    def reset_parameters(self) -> None:
        self.attn_norm.reset_parameters()
        self.attn.reset_parameters()
        self.ff_norm.reset_parameters()
        self.ff.reset_parameters()

    def forward(
        self,
        input_embeddings: torch.Tensor,
        attention_mask: torch.FloatTensor,
    ) -> torch.Tensor:
        """
        Forward pass for a single differential transformer block.

        :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param attention_mask: Causal-like mask for the attention pattern, with ``-inf`` for ``PAD`` positions and ``0`` otherwise.\n
            Possible shapes:\n
            1. ``(batch_size * num_heads, sequence_length, sequence_length)``
            2. ``(batch_size, num_heads, sequence_length, sequence_length)``
        :returns: Output tensor after processing through the block.
        """
        # Apply Multi-Head Differential Attention with residual connection
        attent_emb = self.attn(
            input_embeddings,
            input_embeddings,
            input_embeddings,
            attention_mask,
        )
        attention_block_out = self.attn_norm(attent_emb + input_embeddings)

        # Apply SwiGLU Feed-Forward Network with residual connection
        ff_out = self.ff(input_embeddings=attention_block_out)
        feedforward_block_out = self.ff_norm(ff_out + attention_block_out)
        return feedforward_block_out


class DiffTransformerLayer(torch.nn.Module):
    """
    Stacked blocks of the DiffTransformer Architecture.
    Single block consists of Multi-Head Differential Attention followed by a SwiGLU Feed-Forward Network.

    Source paper: https://arxiv.org/pdf/2410.05258\n
    Reference: https://github.com/nanowell/Differential-Transformer-PyTorch/blob/main/DiffTransformer.py
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        num_blocks: int,
    ) -> None:
        """
        :param embedding_dim: Total dimension of the model. Must be divisible by num_heads.
        :param num_heads: Number of parallel attention heads.
        :param num_blocks: Number of Transformer blocks.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                DiffTransformerBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    lambda_init=0.8 - 0.6 * math.exp(-0.3 * block_num),
                )
                for block_num in range(num_blocks)
            ]
        )

    def reset_parameters(self) -> None:
        for layer in self.layers:
            layer.reset_parameters()

    def forward(
        self,
        feature_tensors: TensorMap,  # noqa: ARG002
        input_embeddings: torch.Tensor,
        padding_mask: torch.BoolTensor,  # noqa: ARG002
        attention_mask: torch.FloatTensor,
    ) -> torch.Tensor:
        """
        forward(input_embeddings, attention_mask)

        :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param attention_mask: Causal-like mask for the attention pattern, with ``-inf`` for ``PAD`` positions and ``0`` otherwise.\n
            Possible shapes:\n
            1. ``(batch_size * num_heads, sequence_length, sequence_length)``
            2. ``(batch_size, num_heads, sequence_length, sequence_length)``
        :returns: Output tensor after processing through the layer.
        """
        seqs = input_embeddings
        for layer in self.layers:
            seqs = layer(
                input_embeddings=seqs,
                attention_mask=attention_mask,
            )
        return seqs
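
A small side calculation (not package code) showing the per-block ``lambda_init`` schedule used above when the blocks are stacked, following the formula from the Differential Transformer paper:

import math

for block_num in range(4):
    lambda_init = 0.8 - 0.6 * math.exp(-0.3 * block_num)
    print(f"block {block_num}: lambda_init = {lambda_init:.3f}")
# block 0: 0.200, block 1: 0.356, block 2: 0.471, block 3: 0.556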

replay/nn/sequential/sasrec/model.py
ADDED
@@ -0,0 +1,377 @@
from collections.abc import Sequence
from typing import Literal, Optional, Protocol, Union

import torch

from replay.data.nn import TensorMap, TensorSchema
from replay.nn.agg import AggregatorProto
from replay.nn.head import EmbeddingTyingHead
from replay.nn.loss import LossProto
from replay.nn.mask import AttentionMaskProto
from replay.nn.normalization import NormalizerProto
from replay.nn.output import InferenceOutput, TrainOutput
from replay.nn.utils import warning_is_not_none


class EmbedderProto(Protocol):
    def get_item_weights(
        self,
        indices: Optional[torch.LongTensor],
    ) -> torch.Tensor: ...

    def forward(
        self,
        feature_tensors: TensorMap,
        feature_names: Optional[Sequence[str]] = None,
    ) -> TensorMap: ...

    def reset_parameters(self) -> None: ...


class EncoderProto(Protocol):
    def forward(
        self,
        feature_tensors: TensorMap,
        input_embeddings: torch.Tensor,
        padding_mask: torch.BoolTensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor: ...

    def reset_parameters(self) -> None: ...


class SasRecBody(torch.nn.Module):
    """
    Implementation of the architecture of the SasRec model.\n
    It can be assembled from various custom blocks that modify the model,
    but the order in which the layers are applied is fixed in accordance with the original architecture.
    """

    def __init__(
        self,
        embedder: EmbedderProto,
        embedding_aggregator: AggregatorProto,
        attn_mask_builder: AttentionMaskProto,
        encoder: EncoderProto,
        output_normalization: NormalizerProto,
    ):
        """
        :param embedder: An object of a class that performs the logic of
            generating embeddings from an input set of tensors.
        :param embedding_aggregator: An object of a class that performs the logic of aggregating multiple embeddings.\n
            For example, it can be a ``sum``, a ``mean``, or a ``concatenation``.
        :param attn_mask_builder: An object of a class that performs the logic of
            generating an attention mask based on the features and padding mask given to the model.
        :param encoder: An object of a class that performs the logic of generating
            a hidden embedding representation based on
            features, padding masks, attention mask, and aggregated embedding.
        :param output_normalization: An object of a class that performs the logic of
            normalization of the hidden state obtained from the encoder.\n
            For example, it may be a ``torch.nn.LayerNorm`` or ``torch.nn.RMSNorm``.
        """
        super().__init__()
        self.embedder = embedder
        self.attn_mask_builder = attn_mask_builder
        self.embedding_aggregator = embedding_aggregator
        self.encoder = encoder
        self.output_normalization = output_normalization

    def reset_parameters(self) -> None:
        self.embedder.reset_parameters()
        self.embedding_aggregator.reset_parameters()
        self.encoder.reset_parameters()
        self.output_normalization.reset_parameters()

    def forward(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        :param feature_tensors: a dictionary of tensors to generate embeddings.
        :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
            indicating which elements of the sequence to ignore for the purpose of attention (i.e. treat as "padding").
            ``False`` indicates that the corresponding value will be ignored.
        :returns: The final hidden state.\n
            Expected shape: ``(batch_size, sequence_length, embedding_dim)``
        """
        embeddings = self.embedder(feature_tensors)
        agg_emb: torch.Tensor = self.embedding_aggregator(embeddings)
        assert agg_emb.dim() == 3

        attn_mask = self.attn_mask_builder(feature_tensors, padding_mask)

        hidden_state: torch.Tensor = self.encoder(
            feature_tensors=feature_tensors,
            input_embeddings=agg_emb,
            padding_mask=padding_mask,
            attention_mask=attn_mask,
        )
        assert agg_emb.size() == hidden_state.size()

        hidden_state = self.output_normalization(hidden_state)
        return hidden_state


class SasRec(torch.nn.Module):
    """
    A model using the SasRec architecture as a hidden state generator.
    The hidden states are multiplied by the item embeddings,
    resulting in logits for each of the items.

    Source paper: https://arxiv.org/pdf/1808.09781.

    Example:

    .. code-block:: python

        from replay.data import FeatureHint, FeatureSource, FeatureType
        from replay.data.nn import TensorFeatureInfo, TensorFeatureSource, TensorSchema
        from replay.nn.agg import SumAggregator
        from replay.nn.embedding import SequenceEmbedding
        from replay.nn.mask import DefaultAttentionMask
        from replay.nn.loss import CESampled
        from replay.nn.sequential import PositionAwareAggregator, SasRecTransformerLayer

        tensor_schema = TensorSchema(
            [
                TensorFeatureInfo(
                    "item_id",
                    is_seq=True,
                    feature_type=FeatureType.CATEGORICAL,
                    embedding_dim=256,
                    padding_value=NUM_UNIQUE_ITEMS,
                    cardinality=NUM_UNIQUE_ITEMS+1,
                    feature_hint=FeatureHint.ITEM_ID,
                    feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
                ),
            ]
        )

        body = SasRecBody(
            embedder=SequenceEmbedding(
                schema=tensor_schema,
            ),
            embedding_aggregator=PositionAwareAggregator(
                embedding_aggregator=SumAggregator(embedding_dim=256),
                max_sequence_length=100,
                dropout=0.2,
            ),
            attn_mask_builder=DefaultAttentionMask(
                reference_feature_name=tensor_schema.item_id_feature_name,
                num_heads=2,
            ),
            encoder=SasRecTransformerLayer(
                embedding_dim=256,
                num_heads=2,
                num_blocks=2,
                dropout=0.3,
                activation="relu",
            ),
            output_normalization=torch.nn.LayerNorm(256),
        )
        sasrec = SasRec(
            body=body,
            loss=CESampled(padding_idx=tensor_schema.item_id_features.item().padding_value)
        )

    """

    def __init__(
        self,
        body: SasRecBody,
        loss: LossProto,
    ):
        """
        :param body: An instance of SasRecBody.
        :param loss: An object of a class that performs loss calculation
            based on hidden states from the model, positive labels and, optionally, negative labels.
        """
        super().__init__()
        self.body = body
        self.head = EmbeddingTyingHead()
        self.loss = loss
        self.loss.logits_callback = self.get_logits

        self.reset_parameters()

    @classmethod
    def from_params(
        cls,
        schema: TensorSchema,
        embedding_dim: int = 192,
        num_heads: int = 4,
        num_blocks: int = 2,
        max_sequence_length: int = 50,
        dropout: float = 0.3,
        excluded_features: Optional[list[str]] = None,
        categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
    ) -> "SasRec":
        from replay.nn.agg import SumAggregator
        from replay.nn.embedding import SequenceEmbedding
        from replay.nn.loss import CE
        from replay.nn.mask import DefaultAttentionMask

        from .agg import PositionAwareAggregator
        from .transformer import SasRecTransformerLayer

        excluded_features = [
            schema.query_id_feature_name,
            schema.timestamp_feature_name,
            *(excluded_features or []),
        ]
        excluded_features = list(set(excluded_features))

        body = SasRecBody(
            embedder=SequenceEmbedding(
                schema=schema,
                categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
                excluded_features=excluded_features,
            ),
            embedding_aggregator=PositionAwareAggregator(
                embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),
                max_sequence_length=max_sequence_length,
                dropout=dropout,
            ),
            attn_mask_builder=DefaultAttentionMask(
                reference_feature_name=schema.item_id_feature_name,
                num_heads=num_heads,
            ),
            encoder=SasRecTransformerLayer(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                num_blocks=num_blocks,
                dropout=dropout,
                activation="relu",
            ),
            output_normalization=torch.nn.LayerNorm(embedding_dim),
        )
        return cls(
            body=body,
            loss=CE(ignore_index=schema.item_id_features.item().padding_value),
        )

    def reset_parameters(self) -> None:
        self.body.reset_parameters()

    def get_logits(
        self,
        model_embeddings: torch.Tensor,
        candidates_to_score: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        item_embeddings: torch.Tensor = self.body.embedder.get_item_weights(candidates_to_score)
        logits: torch.Tensor = self.head(model_embeddings, item_embeddings)
        return logits

    def forward_train(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,
        target_padding_mask: torch.BoolTensor,
    ) -> TrainOutput:
        hidden_states: torch.Tensor = self.body(feature_tensors, padding_mask)
        assert hidden_states.dim() == 3

        loss: torch.Tensor = self.loss(
            model_embeddings=hidden_states,
            feature_tensors=feature_tensors,
            positive_labels=positive_labels,
            negative_labels=negative_labels,
            padding_mask=padding_mask,
            target_padding_mask=target_padding_mask,
        )

        return {
            "loss": loss,
            "hidden_states": (hidden_states,),
        }

    def forward_inference(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
        candidates_to_score: Optional[torch.LongTensor] = None,
    ) -> InferenceOutput:
        hidden_states: torch.Tensor = self.body(feature_tensors, padding_mask)
        assert hidden_states.dim() == 3

        last_hidden_state = hidden_states[:, -1, :].contiguous()
        logits = self.get_logits(last_hidden_state, candidates_to_score)

        return {
            "logits": logits,
            "hidden_states": (hidden_states,),
        }

    def forward(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
        candidates_to_score: Optional[torch.LongTensor] = None,
        positive_labels: Optional[torch.LongTensor] = None,
        negative_labels: Optional[torch.LongTensor] = None,
        target_padding_mask: Optional[torch.BoolTensor] = None,
    ) -> Union[TrainOutput, InferenceOutput]:
        """
        :param feature_tensors: a dictionary of tensors to generate embeddings.
        :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
            indicating which elements of the sequence to ignore for the purpose of attention (i.e. treat as "padding").
            ``False`` indicates that the corresponding value will be ignored.
        :param candidates_to_score: a tensor containing item IDs
            for which logits should be computed at the inference stage.\n
            **Note:** you must take into account the padding value when creating the tensor.\n
            The tensor participates in calculations only at the inference stage.
            It does not need to be passed at the training stage; if it is passed, it has no effect.\n
            Default: ``None``.
        :param positive_labels: a tensor containing positive labels for calculating the loss.\n
            It does not need to be passed at the inference stage; if it is passed, it has no effect.\n
            Default: ``None``.
        :param negative_labels: a tensor containing negative labels for calculating the loss.\n
            **Note:** Before running, make sure that your loss supports calculations with negative labels.\n
            It does not need to be passed at the inference stage; if it is passed, it has no effect.\n
            Default: ``None``.
        :param target_padding_mask: A mask of shape ``(batch_size, sequence_length, num_positives)``
            indicating elements from ``positive_labels`` to ignore during loss calculation.
            ``False`` indicates that the corresponding value will be ignored.\n
            It does not need to be passed at the inference stage; if it is passed, it has no effect.\n
            Default: ``None``.
        :returns: During training, the model returns an object
            of the ``TrainOutput`` container class.
            At the inference stage, an ``InferenceOutput`` object is returned.
        """
        if self.training:
            all(
                map(
                    warning_is_not_none("Variable `{}` is not None. This will have no effect at the training stage."),
                    [(candidates_to_score, "candidates_to_score")],
                )
            )
            return self.forward_train(
                feature_tensors=feature_tensors,
                padding_mask=padding_mask,
                positive_labels=positive_labels,
                negative_labels=negative_labels,
                target_padding_mask=target_padding_mask,
            )

        all(
            map(
                warning_is_not_none("Variable `{}` is not None. This will have no effect at the inference stage."),
                [
                    (positive_labels, "positive_labels"),
                    (negative_labels, "negative_labels"),
                    (target_padding_mask, "target_padding_mask"),
                ],
            )
        )
        return self.forward_inference(
            feature_tensors=feature_tensors,
            padding_mask=padding_mask,
            candidates_to_score=candidates_to_score,
        )
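
To make the role of ``get_logits`` concrete, here is a self-contained sketch of the logit computation it delegates to ``EmbeddingTyingHead``, as described in the ``SasRec`` docstring (hidden states multiplied by item embeddings). The plain matmul below is an assumption about the head's behavior, not a copy of its implementation:

import torch

batch_size, embedding_dim, num_items = 4, 64, 1000
last_hidden_state = torch.randn(batch_size, embedding_dim)  # from hidden_states[:, -1, :]
item_embeddings = torch.randn(num_items, embedding_dim)     # stand-in for embedder.get_item_weights(None)

# Score every item.
all_item_logits = last_hidden_state @ item_embeddings.T     # (batch_size, num_items)

# Score only a restricted candidate set, mirroring candidates_to_score.
candidates_to_score = torch.tensor([7, 42, 256])
candidate_logits = last_hidden_state @ item_embeddings[candidates_to_score].T  # (batch_size, 3)

assert all_item_logits.shape == (batch_size, num_items)
assert candidate_logits.shape == (batch_size, 3)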