replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/models/nn/sequential/sasrec/model.py
@@ -3,10 +3,15 @@ import contextlib
 from typing import Any, Optional, Union, cast
 
 import torch
+from typing_extensions import deprecated
 
 from replay.data.nn import TensorMap, TensorSchema
 
 
+@deprecated(
+    "`SasRecModel` class is deprecated. Use `replay.nn.sequential.SasRec` instead.",
+    stacklevel=2,
+)
 class SasRecModel(torch.nn.Module):
     """
     SasRec model
@@ -110,7 +115,7 @@ class SasRecModel(torch.nn.Module):
     ) -> torch.Tensor:
         """
         :param feature_tensor: Batch of features.
-        :param padding_mask: Padding mask where 0 -
+        :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
 
         :returns: Calculated scores.
         """
@@ -127,11 +132,11 @@ class SasRecModel(torch.nn.Module):
     ) -> torch.Tensor:
         """
         :param feature_tensor: Batch of features.
-        :param padding_mask: Padding mask where 0 -
-        :param candidates_to_score: Item ids to calculate scores
-
+        :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
+        :param candidates_to_score: Item ids to calculate scores.\n
+            If ``None`` then predicts for all items. Default: ``None``.
 
-        :returns: Prediction among canditates_to_score items.
+        :returns: Prediction among ``canditates_to_score`` items.
         """
         # final_emb: [B x E]
         final_emb = self.get_query_embeddings(feature_tensor, padding_mask)
@@ -145,7 +150,7 @@ class SasRecModel(torch.nn.Module):
     ):
         """
         :param feature_tensor: Batch of features.
-        :param padding_mask: Padding mask where 0 -
+        :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
 
         :returns: Query embeddings.
         """
@@ -158,7 +163,7 @@ class SasRecModel(torch.nn.Module):
     ) -> torch.Tensor:
         """
         :param feature_tensor: Batch of features.
-        :param padding_mask: Padding mask where 0 -
+        :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
 
         :returns: Output embeddings.
         """
@@ -176,9 +181,9 @@ class SasRecModel(torch.nn.Module):
 
     def get_logits(self, out_embeddings: torch.Tensor, item_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
         """
-        Apply head to output embeddings of
+        Apply head to output embeddings of ``forward_step``.
 
-        :param out_embeddings: Embeddings after
+        :param out_embeddings: Embeddings after ``forward step``.
         :param item_ids: Item ids to calculate scores.
             Default: ``None``.
 
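Note: the old `SasRecModel` is now wrapped in `typing_extensions.deprecated` (PEP 702), so instantiating it emits a DeprecationWarning pointing at `replay.nn.sequential.SasRec`. A minimal sketch of that mechanism on a stand-in class (not the library's constructor; assumes `typing_extensions` is installed):

import warnings

from typing_extensions import deprecated


# Stand-in class decorated the same way as SasRecModel above (illustration only).
@deprecated("`OldThing` is deprecated. Use `NewThing` instead.", stacklevel=2)
class OldThing:
    pass


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OldThing()  # instantiation triggers the runtime DeprecationWarning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)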
replay/nn/__init__.py
ADDED
replay/nn/agg.py
ADDED
@@ -0,0 +1,109 @@
import contextlib
from typing import Protocol

import torch

from replay.data.nn.schema import TensorMap


class AggregatorProto(Protocol):
    """Class-protocol for working with embedding aggregation functions"""

    def forward(
        self,
        feature_tensors: TensorMap,
    ) -> torch.Tensor: ...

    @property
    def embedding_dim(self) -> int: ...

    def reset_parameters(self) -> None: ...


class SumAggregator(torch.nn.Module):
    """
    The class summarizes the incoming embeddings.
    Note that for successful aggregation, the dimensions of all embeddings must match.
    """

    def __init__(self, embedding_dim: int) -> None:
        """
        :param embedding_dim: The last dimension of incoming and outcoming embeddings.
        """
        super().__init__()
        self._embedding_dim = embedding_dim

    @property
    def embedding_dim(self) -> int:
        """The dimension of the output embedding"""
        return self._embedding_dim

    def reset_parameters(self) -> None:
        pass

    def forward(self, feature_tensors: TensorMap) -> torch.Tensor:
        """
        :param feature_tensors: a dictionary of tensors to sum up.
            The dimensions of all tensors in the dictionary must match.

        :returns: torch.Tensor. The last dimension of the tensor is ``embedding_dim``.
        """
        out = sum(feature_tensors.values())
        assert out.size(-1) == self.embedding_dim
        return out


class ConcatAggregator(torch.nn.Module):
    """
    The class concatenates incoming embeddings by the last dimension.

    If you need to concatenate several embeddings,
    then a linear layer will be applied to get the last dimension equal to ``embedding_dim``.\n
    If only one embedding comes to the input, then its last dimension is expected to be equal to ``embedding_dim``.
    """

    def __init__(
        self,
        input_embedding_dims: list[int],
        output_embedding_dim: int,
    ) -> None:
        """
        :param input_embedding_dims: Dimensions of incoming embeddings.
        :param output_embedding_dim: The dimension of the output embedding after concatenation.
        """
        super().__init__()
        self._embedding_dim = output_embedding_dim
        embedding_concat_size = sum(input_embedding_dims)
        self.feat_projection = None
        if len(input_embedding_dims) > 1:
            self.feat_projection = torch.nn.Linear(embedding_concat_size, self.embedding_dim)
        elif embedding_concat_size != self.embedding_dim:
            msg = f"Input embedding dim is not equal to embedding_dim ({embedding_concat_size} != {self.embedding_dim})"
            raise ValueError(msg)

    @property
    def embedding_dim(self) -> int:
        """The dimension of the output embedding"""
        return self._embedding_dim

    def reset_parameters(self) -> None:
        for _, param in self.named_parameters():
            with contextlib.suppress(ValueError):
                torch.nn.init.xavier_normal_(param.data)

    def forward(self, feature_tensors: TensorMap) -> torch.Tensor:
        """
        To ensure the deterministic nature of the result,
        the embeddings are concatenated in the ascending order of the keys in the dictionary.

        :param feature_tensors: a dictionary of tensors to concatenate.

        :returns: The last dimension of the tensor is ``embedding_dim``.
        """
        # To maintain determinism, we concatenate the tensors in sorted order by names.
        sorted_names = sorted(feature_tensors.keys())
        out = torch.cat([feature_tensors[name] for name in sorted_names], dim=-1)
        if self.feat_projection is not None:
            out = self.feat_projection(out)
        assert out.size(-1) == self.embedding_dim
        return out
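Note: a minimal usage sketch (not from the package) showing how the two aggregators differ. The import path `replay.nn.agg` and the feature names are assumptions for illustration.

import torch

from replay.nn.agg import ConcatAggregator, SumAggregator

# Hypothetical per-feature embeddings of shape (batch, seq_len, dim).
features = {
    "item_id": torch.randn(2, 5, 64),
    "category": torch.randn(2, 5, 32),
}

# ConcatAggregator concatenates in sorted key order and projects 64 + 32 -> 64.
concat_agg = ConcatAggregator(input_embedding_dims=[64, 32], output_embedding_dim=64)
assert concat_agg(features).shape == (2, 5, 64)

# SumAggregator requires every embedding to share the same last dimension.
same_dim = {name: torch.randn(2, 5, 64) for name in ("item_id", "category")}
sum_agg = SumAggregator(embedding_dim=64)
assert sum_agg(same_dim).shape == (2, 5, 64)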
replay/nn/attention.py
ADDED
@@ -0,0 +1,158 @@
import contextlib
import math
from typing import Optional

import torch


class MultiHeadDifferentialAttention(torch.nn.Module):
    """
    Multi-Head Differential Attention Mechanism.
    Replaces the conventional softmax attention with a differential attention.
    Incorporattes a causal mask (if other not specified) to ensure autoregressive behavior.

    Source paper: https://arxiv.org/pdf/2410.05258
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        lambda_init: float,
        bias: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
    ):
        """
        :param embedding_dim: Total dimension of the model. Must be divisible by ``num_heads``.
        :param num_heads: Number of parallel attention heads.
        :param lambda_init: Initial value for lambda.
        :param bias: If specified, adds bias to input / output projection layers. Default: ``False``.
        :param kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embedding_dim``).
        :param vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embedding_dim``).
        """
        super().__init__()
        kdim = kdim or embedding_dim
        vdim = vdim or embedding_dim
        assert kdim % num_heads == 0, "Query/Key embedding dim is not divisible by num_heads"
        assert vdim % num_heads == 0, "Value embedding dim is not divisible by num_heads"
        self.qk_e_head = kdim // num_heads
        self.v_e_head = vdim // num_heads
        self.num_heads = num_heads

        # Linear projections for queries, keys, and values
        # Project to 2 * d_head per head for differential attention
        self.W_q = torch.nn.Linear(embedding_dim, 2 * self.qk_e_head * num_heads, bias=bias)
        self.W_k = torch.nn.Linear(embedding_dim, 2 * self.qk_e_head * num_heads, bias=bias)
        self.W_v = torch.nn.Linear(embedding_dim, self.v_e_head * num_heads, bias=bias)
        self.W_o = torch.nn.Linear(self.v_e_head * num_heads, embedding_dim, bias=bias)

        # Learnable parameters for lambda reparameterization
        self.lambda_q1 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
        self.lambda_k1 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
        self.lambda_q2 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
        self.lambda_k2 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
        self.register_buffer("scaling", torch.asarray(1 / math.sqrt(self.qk_e_head), dtype=torch.float32))

        self.lambda_init = lambda_init

        # Scale parameter for RMSNorm
        self.rms_scale = torch.nn.Parameter(torch.ones(self.v_e_head))
        self.eps = 1e-5  # Epsilon for numerical stability

    def reset_parameters(self) -> None:
        for _, param in self.named_parameters():
            with contextlib.suppress(ValueError):
                torch.nn.init.xavier_normal_(param.data)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.FloatTensor,
    ) -> torch.Tensor:
        """
        Forward pass for Multi-Head Differential Attention.

        :param query: Query sequence of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param key: Key sequence of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param value: Value sequence of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param attn_mask: attention mask, where ``-inf`` for ``PAD``, ``0`` - otherwise.\n
            Possible shapes:\n
            1. ``(batch_size * num_heads, sequence_length, sequence_length)``
            2. ``(batch_size, num_heads, sequence_length, sequence_length)``
        :returns: torch.Tensor: Output tensor after applying differential attention.
        """
        batch_size, seq_len, _ = value.shape

        # Project inputs to queries, keys, and values
        query = self.W_q(query)  # Shape: (batch_size, seq_len, 2 * num_heads * qk_e_head)
        key = self.W_k(key)  # Shape: (batch_size, seq_len, 2 * num_heads * qk_e_head)
        value = self.W_v(value)  # Shape: (batch_size, seq_len, num_heads * v_e_head)

        # Reshape and permute for multi-head attention
        # New shape: (batch_size, num_heads, sequence_length, 2 * qk_e_head or v_e_head)
        query = query.view(batch_size, seq_len, self.num_heads, 2 * self.qk_e_head).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, 2 * self.qk_e_head).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.v_e_head).transpose(1, 2)

        # Split query and key into query1, query2 and key1, key2
        query1, query2 = query.chunk(2, dim=-1)  # Each of shape: (batch_size, num_heads, seq_len, d_head)
        key1, key2 = key.chunk(2, dim=-1)  # Each of shape: (batch_size, num_heads, seq_len, d_head)

        # Compute lambda using reparameterization
        # Compute dot products for each head
        # Shape of lambda_val: (num_heads,)
        lambda_q1_dot_k1 = torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()  # (num_heads,)
        lambda_q2_dot_k2 = torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()  # (num_heads,)
        lambda_val = torch.exp(lambda_q1_dot_k1) - torch.exp(lambda_q2_dot_k2) + self.lambda_init  # (num_heads,)

        # Expand lambda_val to match attention dimensions (batch_size, num_heads, 1, 1)
        lambda_val = lambda_val.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

        # Reshape attn_mask from 3D to 4D
        if len(attn_mask.shape) == 3:
            attn_mask = attn_mask.reshape(attn_mask.shape[0] // self.num_heads, self.num_heads, *attn_mask.shape[1:])

        # check shapes
        assert attn_mask.dim() == 4
        assert attn_mask.size() == (batch_size, self.num_heads, seq_len, seq_len)

        # Compute attention scores
        attention_scores1 = torch.matmul(query1, key1.transpose(-2, -1)) * self.get_buffer(
            "scaling"
        )  # (batch_size, num_heads, seq_len, seq_len)
        attention_scores2 = torch.matmul(query2, key2.transpose(-2, -1)) * self.get_buffer(
            "scaling"
        )  # (batch_size, num_heads, seq_len, seq_len)

        # Apply the causal mask
        attention_scores1 = attention_scores1 + attn_mask  # Mask out future positions
        attention_scores2 = attention_scores2 + attn_mask  # Mask out future positions

        # Apply softmax to get attention weights
        attention1 = torch.nn.functional.softmax(attention_scores1, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)
        attention2 = torch.nn.functional.softmax(attention_scores2, dim=-1)
        attention = attention1 - lambda_val * attention2

        # Apply attention weights to values
        output = torch.matmul(attention, value)  # (batch_size, num_heads, seq_len, v_e_head)

        # Normalize each head independently using RMSNorm
        # Compute RMSNorm
        rms_norm = torch.sqrt(output.pow(2).mean(dim=-1, keepdim=True) + self.eps)  # (batch_size*num_heads, seq_len, 1)
        output_normalized = (output / rms_norm) * self.rms_scale  # (batch*num_heads, seq_len, v_e_head)

        # Scale the normalized output
        output_normalized = output_normalized * (1 - self.lambda_init)  # Scalar scaling

        # Concatenate all heads
        # New shape: (batch_size, seq_len, num_heads * v_e_head)
        output_concat = (
            output_normalized.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.v_e_head)
        )

        # Final linear projection
        output_projection = self.W_o(output_concat)  # (batch_size, seq_len, embedding_dim)
        return output_projection
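Note: a minimal self-attention call sketch (not from the package) that builds the additive causal mask described in the docstring (``-inf`` above the diagonal, ``0`` elsewhere) and runs one forward pass. The import path and the ``lambda_init`` value are assumptions for illustration.

import torch

from replay.nn.attention import MultiHeadDifferentialAttention

batch_size, seq_len, embedding_dim, num_heads = 2, 6, 64, 4
attn = MultiHeadDifferentialAttention(embedding_dim=embedding_dim, num_heads=num_heads, lambda_init=0.8)

hidden = torch.randn(batch_size, seq_len, embedding_dim)

# Additive causal mask: -inf above the diagonal, 0 on and below it.
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
attn_mask = causal.expand(batch_size, num_heads, seq_len, seq_len)  # docstring shape option 2

out = attn(query=hidden, key=hidden, value=hidden, attn_mask=attn_mask)
assert out.shape == (batch_size, seq_len, embedding_dim)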
replay/nn/embedding.py
ADDED
@@ -0,0 +1,283 @@
import contextlib
import warnings
from collections.abc import Sequence
from typing import Literal, Optional, Union

import torch

from replay.data.nn.schema import TensorFeatureInfo, TensorMap, TensorSchema


class SequenceEmbedding(torch.nn.Module):
    """
    The embedding generation class for all types of features given into the sequential models.

    The embedding size for each feature will be taken from ``TensorSchema`` (from field named ``embedding_dim``).
    For numerical features, it is expected that the last dimension of the tensor will be equal
    to ``tensor_dim`` field in ``TensorSchema``.

    Keep in mind that the first dimension of the every categorical embedding (the size of embedding table)
    will equal to the ``cardinality`` + 1. This is necessary to take into account the padding value.
    """

    def __init__(
        self,
        schema: TensorSchema,
        excluded_features: Optional[list[str]] = None,
        categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
    ):
        """
        :param schema: TensorSchema containing meta information about all the features
            for which you need to generate an embedding.
        :param excluded_features: A list containing the names of features
            for which you do not need to generate an embedding.
            Fragments from this list are expected to be contained in ``schema``.
            Default: ``None``.
        :param categorical_list_feature_aggregation_method: Mode to aggregate tokens
            in token item representation (categorical list only).
            Default: ``"sum"``.
        """
        super().__init__()
        self.excluded_features = excluded_features or []
        feature_embedders = {}

        for feature_name, tensor_info in schema.items():
            if feature_name in self.excluded_features:
                continue
            if not tensor_info.is_seq:
                msg = f"Non-sequential features is not yet supported. Got {feature_name}"
                raise NotImplementedError(msg)
            if tensor_info.is_cat:
                feature_embedders[feature_name] = CategoricalEmbedding(
                    tensor_info,
                    categorical_list_feature_aggregation_method,
                )
            else:
                feature_embedders[feature_name] = NumericalEmbedding(tensor_info)

        self.feature_names = list(feature_embedders.keys())
        if not feature_embedders:
            msg = "Expected to have at least one feature name to generate embedding."
            raise ValueError(msg)
        self.feature_embedders: dict[str, Union[CategoricalEmbedding, NumericalEmbedding]] = torch.nn.ModuleDict(
            feature_embedders
        )
        self._item_feature_name = schema.item_id_feature_name

    def reset_parameters(self) -> None:
        for _, param in self.named_parameters():
            with contextlib.suppress(ValueError):
                torch.nn.init.xavier_normal_(param.data)

    def forward(self, feature_tensor: TensorMap, feature_names: Optional[Sequence[str]] = None) -> TensorMap:
        """
        :param feature_tensor: a dictionary of tensors to generate embedding.
            It is expected that the keys from this dictionary match the names of the features in the given ``schema``.
        :param feature_names: A custom list of features for which embeddings need to be generated.
            It is expected that the values from this list match the names of the features in the given ``schema``.\n
            Default: ``None``. This means that the names of the features from the ``schema`` will be used.

        :returns: a dictionary with tensors that contains embeddings.
        """
        return {
            feature_name: self.feature_embedders[feature_name](feature_tensor[feature_name])
            for feature_name in (feature_names or self.feature_names)
        }

    @property
    def embeddings_dim(self) -> dict[str, int]:
        """
        Returns the embedding dimensions for each of the features in the `schema`.
        """
        return {name: emb.embedding_dim for name, emb in self.feature_embedders.items()}

    def get_item_weights(self, indices: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """
        Getting the embedding weights for a feature that matches the item id feature
        with the name specified in the ``schema``.
        It is expected that embeddings for this feature will definitely exist.
        **Note**: the row corresponding to the padding will be excluded from the returned weights.
        This logic will work if given ``indices`` is ``None``.

        :param indices: Items indices.
        :returns: Embeddings for specific items.
        """
        if indices is None:
            return self.feature_embedders[self._item_feature_name].weight
        return self.feature_embedders[self._item_feature_name](indices)


class CategoricalEmbedding(torch.nn.Module):
    """
    The embedding generation class for categorical features.
    It supports working with single features for each event in sequence, as well as several (categorical list).

    When using this class, keep in mind that
    the first dimension of the embedding (the size of embedding table) will equal to the ``cardinality`` + 1.
    This is necessary to take into account the padding value.
    """

    def __init__(
        self,
        feature_info: TensorFeatureInfo,
        categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
    ) -> None:
        """
        :param feature_info: Meta information about the feature.
        :param categorical_list_feature_aggregation_method: Mode to aggregate tokens
            in token item representation (categorical list only). One of {`sum`, `mean`, `max`}
            Default: ``"sum"``.
        """
        super().__init__()
        assert feature_info.cardinality
        assert feature_info.embedding_dim

        self._expect_padding_value_setted = True
        if feature_info.cardinality != feature_info.padding_value:
            self._expect_padding_value_setted = False
            msg = (
                f"The padding value={feature_info.padding_value} is set for the feature={feature_info.name}. "
                f"The expected padding value for this feature should be {feature_info.cardinality}. "
                "Keep this in mind when getting the weights via the `weight` property, "
                "because the weights are returned there without padding row. "
                "Therefore, during the IDs scores generating, "
                "all the IDs that greater than the padding value should be increased by 1."
            )
            warnings.warn(msg, stacklevel=2)

        if feature_info.is_list:
            self.emb = torch.nn.EmbeddingBag(
                feature_info.cardinality + 1,
                feature_info.embedding_dim,
                padding_idx=feature_info.padding_value,
                mode=categorical_list_feature_aggregation_method,
            )
            self._get_embeddings = self._get_cat_list_embeddings
        else:
            self.emb = torch.nn.Embedding(
                feature_info.cardinality + 1,
                feature_info.embedding_dim,
                padding_idx=feature_info.padding_value,
            )
            self._get_embeddings = self._get_cat_embeddings

    @property
    def weight(self) -> torch.Tensor:
        """
        Returns the weights of the embedding layer,
        excluding the row that corresponds to the padding.
        """
        if not self._expect_padding_value_setted:
            msg = (
                "The weights are returned there do not contain padding row. "
                "Therefore, during the IDs scores generating, "
                "all the IDs that greater than the padding value should be increased by 1."
            )
            warnings.warn(msg, stacklevel=2)

        mask_without_padding = torch.ones(
            size=(self.emb.weight.size(0),),
            dtype=torch.bool,
            device=self.emb.weight.device,
        )
        mask_without_padding[self.emb.padding_idx].zero_()
        return self.emb.weight[mask_without_padding]

    def forward(self, indices: torch.LongTensor) -> torch.Tensor:
        """
        :param indices: Items indices.

        :returns: Embeddings for specific items.
        """
        return self._get_embeddings(indices)

    @property
    def embedding_dim(self) -> int:
        """Embedding dimension after applying the layer"""
        return self.emb.embedding_dim

    def _get_cat_embeddings(self, indices: torch.LongTensor) -> torch.Tensor:
        """
        :param indices: Items indices.

        :returns: Embeddings for specific items.
        """
        return self.emb(indices)

    def _get_cat_list_embeddings(self, indices: torch.LongTensor) -> torch.Tensor:
        """
        :param indices: Items indices.

        :returns: Embeddings for specific items.
        """
        assert indices.dim() >= 2
        if indices.dim() == 2:
            embeddings: torch.Tensor = self.emb(indices)
        else:
            source_size = indices.size()
            indices = indices.view(-1, source_size[-1])
            embeddings = self.emb(indices)
            embeddings = embeddings.view(*source_size[:-1], -1)
        return embeddings


class NumericalEmbedding(torch.nn.Module):
    """
    The embedding generation class for numerical features.
    It supports working with single features for each event in sequence, as well as several (numerical list).

    **Note**: if the ``embedding_dim`` field in ``TensorFeatureInfo`` for an incoming feature matches its last dimension
    (``tensor_dim`` field in ``TensorFeatureInfo``), then transformation will not be applied.
    """

    def __init__(self, feature_info: TensorFeatureInfo) -> None:
        """
        :param feature_info: Meta information about the feature.
        """
        super().__init__()
        assert feature_info.tensor_dim
        assert feature_info.embedding_dim
        self._tensor_dim = feature_info.tensor_dim
        self._embedding_dim = feature_info.embedding_dim
        self.linear = torch.nn.Linear(feature_info.tensor_dim, self.embedding_dim)

        if feature_info.is_list:
            if self.embedding_dim == feature_info.tensor_dim:
                torch.nn.init.eye_(self.linear.weight.data)
                torch.nn.init.zeros_(self.linear.bias.data)

                self.linear.weight.requires_grad = False
                self.linear.bias.requires_grad = False
        else:
            assert feature_info.tensor_dim == 1
            self.linear = torch.nn.Linear(feature_info.tensor_dim, self.embedding_dim)

    @property
    def weight(self) -> torch.Tensor:
        """
        Returns the weight of the applied layer.
        If ``embedding_dim`` matches ``tensor_dim``, then the identity matrix will be returned.
        """
        return self.linear.weight

    def forward(self, values: torch.FloatTensor) -> torch.Tensor:
        """
        Numerical embedding forward pass.\n
        **Note**: if the ``embedding_dim`` for an incoming feature matches its last dimension (``tensor_dim``),
        then transformation will not be applied.

        :param values: feature values.
        :returns: Embeddings for specific items.
        """
        if values.dim() <= 2 and self._tensor_dim == 1:
            values = values.unsqueeze(-1).contiguous()

        assert values.size(-1) == self._tensor_dim
        if self._tensor_dim != self.embedding_dim:
            return self.linear(values)
        return values

    @property
    def embedding_dim(self) -> int:
        """Embedding dimension after applying the layer"""
        return self._embedding_dim
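Note: a plain-torch illustration (not package code) of the padding convention described in the ``CategoricalEmbedding`` docstring: a feature with cardinality N gets an (N + 1)-row table whose last row is the padding row, and the ``weight`` property drops that row before item scores are computed. All names and values below are illustrative.

import torch

cardinality, embedding_dim = 5, 8
padding_value = cardinality  # the expected padding id, as checked in CategoricalEmbedding.__init__

table = torch.nn.Embedding(cardinality + 1, embedding_dim, padding_idx=padding_value)

# Padded sequence positions map to the zero row and do not contribute gradients.
batch = torch.tensor([[0, 3, padding_value, padding_value]])
assert torch.all(table(batch)[0, 2:] == 0)

# Equivalent of CategoricalEmbedding.weight: every row except the padding row.
keep = torch.ones(cardinality + 1, dtype=torch.bool)
keep[padding_value] = False
item_weights = table.weight[keep]  # shape: (cardinality, embedding_dim)
assert item_weights.shape == (cardinality, embedding_dim)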