replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/sequential/twotower/model.py (new file)
@@ -0,0 +1,674 @@
from collections.abc import Sequence
from typing import Optional, Protocol, Union

import torch

from replay.data.nn import TensorMap, TensorSchema
from replay.nn.agg import AggregatorProto
from replay.nn.head import EmbeddingTyingHead
from replay.nn.loss import LossProto
from replay.nn.mask import AttentionMaskProto
from replay.nn.normalization import NormalizerProto
from replay.nn.output import InferenceOutput, TrainOutput
from replay.nn.utils import warning_is_not_none

from .reader import FeaturesReaderProtocol


class EmbedderProto(Protocol):
    @property
    def feature_names(self) -> Sequence[str]: ...

    def forward(
        self,
        feature_tensors: TensorMap,
        feature_names: Optional[Sequence[str]] = None,
    ) -> TensorMap: ...

    def reset_parameters(self) -> None: ...


class QueryEncoderProto(Protocol):
    def forward(
        self,
        feature_tensors: TensorMap,
        input_embeddings: torch.Tensor,
        padding_mask: torch.BoolTensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor: ...

    def reset_parameters(self) -> None: ...


class ItemEncoderProto(Protocol):
    def forward(
        self,
        feature_tensors: TensorMap,
        input_embeddings: torch.Tensor,
    ) -> torch.Tensor: ...

    def reset_parameters(self) -> None: ...


class QueryTower(torch.nn.Module):
    """
    Query Tower of the Two-Tower model.
    """

    def __init__(
        self,
        feature_names: Sequence[str],
        embedder: EmbedderProto,
        embedding_aggregator: AggregatorProto,
        attn_mask_builder: AttentionMaskProto,
        encoder: QueryEncoderProto,
        output_normalization: NormalizerProto,
    ):
        """
        :param feature_names: sequence of feature names used in the query tower.
        :param embedder: An object of a class that performs the logic of
            generating embeddings from an input batch.
        :param embedding_aggregator: An object of a class that performs
            the logic of aggregating multiple embeddings of the query tower.
        :param attn_mask_builder: An object of a class that performs the logic of
            generating an attention mask based on the features and padding mask given to the model.
        :param encoder: An object of a class that performs the logic of generating
            a query hidden embedding representation based on
            features, padding masks, attention mask, and aggregated embedding of ``query_tower_feature_names``.
            It's supposed to be a transformer.
        :param output_normalization: An object of a class that performs the logic of
            normalization of the hidden state obtained from the query encoder.\n
            For example, it can be a ``torch.nn.LayerNorm`` or ``torch.nn.RMSNorm``.
        """
        super().__init__()
        self.embedder = embedder
        self.attn_mask_builder = attn_mask_builder
        self.feature_names = feature_names
        self.embedding_aggregator = embedding_aggregator
        self.encoder = encoder
        self.output_normalization = output_normalization

    def reset_parameters(self) -> None:
        self.embedding_aggregator.reset_parameters()
        self.encoder.reset_parameters()
        self.output_normalization.reset_parameters()

    def forward(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        :param feature_tensors: a dictionary of tensors to generate embeddings.
        :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
            indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding").
            A ``False`` value indicates that the corresponding ``key`` value will be ignored.
        :returns: The final hidden state.\n
            Expected shape: ``(batch_size, sequence_length, embedding_dim)``
        """
        embeddings: TensorMap = self.embedder(feature_tensors, self.feature_names)
        agg_emb: torch.Tensor = self.embedding_aggregator(embeddings)
        assert agg_emb.dim() == 3

        attn_mask = self.attn_mask_builder(feature_tensors, padding_mask)

        hidden_state: torch.Tensor = self.encoder(
            feature_tensors=feature_tensors,
            input_embeddings=agg_emb,
            padding_mask=padding_mask,
            attention_mask=attn_mask,
        )
        assert agg_emb.size() == hidden_state.size()

        hidden_state = self.output_normalization(hidden_state)
        return hidden_state


class ItemTower(torch.nn.Module):
    """
    Item Tower of the Two-Tower model.

    **Note**: ItemTower loads feature tensors of all items into memory.
    """

    def __init__(
        self,
        schema: TensorSchema,
        item_features_reader: FeaturesReaderProtocol,
        feature_names: Sequence[str],
        embedder: EmbedderProto,
        embedding_aggregator: AggregatorProto,
        encoder: ItemEncoderProto,
    ):
        """
        :param schema: tensor schema object with metainformation about features.
        :param item_features_reader: A class that implements reading features,
            processing them, and converting them to ``torch.Tensor`` for ItemTower.
            You can use ``replay.nn.sequential.twotower.FeaturesReader`` as a standard class.\n
            But you can implement your own feature processing,
            just follow the ``replay.nn.sequential.twotower.FeaturesReaderProtocol`` protocol.
        :param feature_names: sequence of feature names used in the item tower.
        :param embedder: An object of a class that performs the logic of
            generating embeddings from an input batch.
        :param embedding_aggregator: An object of a class that performs
            the logic of aggregating multiple embeddings of the item tower.
        :param encoder: An object of a class that performs the logic of generating
            an item hidden embedding representation based on
            features and aggregated embeddings of ``item_tower_feature_names``.
            Item encoder uses item reference which is created based on ``item_features_path``.
        """
        super().__init__()
        self.embedder = embedder
        self.feature_names = feature_names
        self.embedding_aggregator = embedding_aggregator
        self.encoder = encoder

        for feature_name in schema:
            if feature_name not in self.feature_names:
                continue

            self.register_buffer(f"item_reference_{feature_name}", item_features_reader[feature_name])

        self.cache = None

    def reset_parameters(self) -> None:
        self.embedding_aggregator.reset_parameters()
        self.encoder.reset_parameters()

    def get_feature_buffer(self, feature_name: str) -> torch.Tensor:
        buffer_name = f"item_reference_{feature_name}"
        return self.get_buffer(buffer_name)

    def forward(
        self,
        candidates_to_score: Optional[torch.LongTensor] = None,
    ):
        """
        :param candidates_to_score: IDs of items used for obtaining item embeddings from the item tower.
            If set to ``None``, all item embeddings from the item tower will be returned.
            Default: ``None``.
        :return: item embeddings.\n
            Expected shape:\n
            - ``(candidates_to_score, embedding_dim)``,
            - ``(items_num, embedding_dim)`` if ``candidates_to_score`` is ``None``.
        """
        if self.training:
            self.cache = None

        if not self.training and self.cache is not None:
            if candidates_to_score is None:
                return self.cache
            return self.cache[candidates_to_score]

        if candidates_to_score is None:
            feature_tensors = {
                feature_name: self.get_feature_buffer(feature_name) for feature_name in self.feature_names
            }
        else:
            feature_tensors = {
                feature_name: self.get_feature_buffer(feature_name)[candidates_to_score]
                for feature_name in self.feature_names
            }

        embeddings: TensorMap = self.embedder(feature_tensors, self.feature_names)
        agg_emb: torch.Tensor = self.embedding_aggregator(embeddings)

        hidden_state: torch.Tensor = self.encoder(
            feature_tensors=feature_tensors,
            input_embeddings=agg_emb,
        )
        assert agg_emb.size() == hidden_state.size()

        if not self.training and self.cache is None and candidates_to_score is None:
            self.cache = hidden_state
        return hidden_state


class TwoTowerBody(torch.nn.Module):
    """
    Foundation for the Two-Tower model which creates the query "tower" and the item "tower".\n

    To use the two-tower model, an instance of this class should be passed into `TwoTower`_ with any loss
    from `Losses`_.
    """

    def __init__(
        self,
        schema: TensorSchema,
        embedder: EmbedderProto,
        attn_mask_builder: AttentionMaskProto,
        query_tower_feature_names: Sequence[str],
        item_tower_feature_names: Sequence[str],
        query_embedding_aggregator: AggregatorProto,
        item_embedding_aggregator: AggregatorProto,
        query_encoder: QueryEncoderProto,
        query_tower_output_normalization: NormalizerProto,
        item_encoder: ItemEncoderProto,
        item_features_reader: FeaturesReaderProtocol,
    ):
        """
        :param schema: tensor schema object with metainformation about features.
        :param embedder: An object of a class that performs the logic of
            generating embeddings from an input batch.\n
            The same object is used to generate embeddings in different towers.
        :param query_tower_feature_names: sequence of feature names used in the query tower.
        :param item_tower_feature_names: sequence of feature names used in the item tower.
        :param query_embedding_aggregator: An object of a class that performs
            the logic of aggregating multiple embeddings of the query tower.
        :param item_embedding_aggregator: An object of a class that performs
            the logic of aggregating multiple embeddings of the item tower.
        :param query_encoder: An object of a class that performs the logic of generating
            a query hidden embedding representation based on
            features, padding masks, attention mask, and aggregated embedding of ``query_tower_feature_names``.
            It's supposed to be a transformer.
        :param query_tower_output_normalization: An object of a class that performs the logic of
            normalization of the hidden state obtained from the query encoder.\n
            For example, it can be a ``torch.nn.LayerNorm`` or ``torch.nn.RMSNorm``.
        :param attn_mask_builder: An object of a class that performs the logic of
            generating an attention mask based on the features and padding mask given to the model.
        :param item_encoder: An object of a class that performs the logic of generating
            an item hidden embedding representation based on
            features and aggregated embeddings of ``item_tower_feature_names``.
            Item encoder uses item reference which is created based on ``item_features_path``.
        :param item_features_reader: A class that implements reading features,
            processing them, and converting them to ``torch.Tensor`` for ItemTower.
            You can use ``replay.nn.sequential.twotower.FeaturesReader`` as a standard class.\n
            But you can implement your own feature processing,
            just follow the ``replay.nn.sequential.twotower.FeaturesReaderProtocol`` protocol.
        """
        super().__init__()
        self.embedder = embedder
        feature_names_union = set(query_tower_feature_names) | set(item_tower_feature_names)
        feature_names_not_in_emb = feature_names_union - set(self.embedder.feature_names)
        if len(feature_names_not_in_emb) != 0:
            msg = f"Feature names found that embedder does not support {list(feature_names_not_in_emb)}"
            raise ValueError(msg)

        self.query_tower = QueryTower(
            query_tower_feature_names,
            embedder,
            query_embedding_aggregator,
            attn_mask_builder,
            query_encoder,
            query_tower_output_normalization,
        )
        self.item_tower = ItemTower(
            schema,
            item_features_reader,
            item_tower_feature_names,
            embedder,
            item_embedding_aggregator,
            item_encoder,
        )

    def reset_parameters(self) -> None:
        self.embedder.reset_parameters()
        self.query_tower.reset_parameters()
        self.item_tower.reset_parameters()


class ContextMergerProto(Protocol):
    def forward(
        self,
        model_hidden_state: torch.Tensor,
        feature_tensors: TensorMap,
    ) -> torch.Tensor: ...

    def reset_parameters(self) -> None: ...


class TwoTower(torch.nn.Module):
    """
    Implementation of a generic Two-Tower architecture with two independent "towers" (encoders)
    which encode separate inputs. In recommender systems they are typically a query tower and an item tower.
    The output hidden states of each "tower" are fused via dot product in the model head.

    Source paper: https://doi.org/10.1145/3366424.3386195

    Example:

    .. code-block:: python

        import torch

        from replay.data import FeatureHint, FeatureSource, FeatureType
        from replay.data.nn import TensorFeatureInfo, TensorFeatureSource, TensorSchema
        from replay.nn.agg import SumAggregator
        from replay.nn.embedding import SequenceEmbedding
        from replay.nn.ffn import SwiGLUEncoder
        from replay.nn.mask import DefaultAttentionMask
        from replay.nn.loss import CESampled
        from replay.nn.sequential import PositionAwareAggregator, SasRecTransformerLayer
        from replay.nn.sequential.twotower import FeaturesReader

        tensor_schema = TensorSchema(
            [
                TensorFeatureInfo(
                    "item_id",
                    is_seq=True,
                    feature_type=FeatureType.CATEGORICAL,
                    embedding_dim=256,
                    padding_value=NUM_UNIQUE_ITEMS,
                    cardinality=NUM_UNIQUE_ITEMS,
                    feature_hint=FeatureHint.ITEM_ID,
                    feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
                ),
            ]
        )

        common_aggregator = SumAggregator(embedding_dim=256)

        body = TwoTowerBody(
            schema=tensor_schema,
            embedder=SequenceEmbedding(schema=tensor_schema),
            attn_mask_builder=DefaultAttentionMask(
                reference_feature_name=tensor_schema.item_id_feature_name,
                num_heads=2,
            ),
            query_tower_feature_names=tensor_schema.names,
            item_tower_feature_names=tensor_schema.names,
            query_embedding_aggregator=PositionAwareAggregator(
                embedding_aggregator=common_aggregator,
                max_sequence_length=100,
                dropout=0.2,
            ),
            item_embedding_aggregator=common_aggregator,
            query_encoder=SasRecTransformerLayer(
                embedding_dim=256,
                num_heads=2,
                num_blocks=2,
                dropout=0.3,
                activation="relu",
            ),
            query_tower_output_normalization=torch.nn.LayerNorm(256),
            item_encoder=SwiGLUEncoder(embedding_dim=256, hidden_dim=2*256),
            item_features_reader=FeaturesReader(
                schema=tensor_schema,
                metadata={"item_id": {}},
                path="item_features.parquet",
            ),
        )
        twotower = TwoTower(
            body=body,
            loss=CESampled(ignore_index=tensor_schema["item_id"].padding_value),
        )

    """

    def __init__(
        self,
        body: TwoTowerBody,
        loss: LossProto,
        context_merger: Optional[ContextMergerProto] = None,
    ):
        """
        :param body: An instance of TwoTowerBody.
        :param loss: An object of a class that performs loss calculation
            based on hidden states from the model, positive and optionally negative labels.
        :param context_merger: An object of a class that performs fusion of the query encoder hidden state
            with input feature tensors.
            Default: ``None``.
        """
        super().__init__()
        self.body = body
        self.head = EmbeddingTyingHead()
        self.loss = loss
        self.context_merger = context_merger
        self.loss.logits_callback = self.get_logits

        self.reset_parameters()

    @classmethod
    def from_params(
        cls,
        schema: TensorSchema,
        item_features_reader: FeaturesReaderProtocol,
        embedding_dim: int = 192,
        num_heads: int = 4,
        num_blocks: int = 2,
        max_sequence_length: int = 50,
        dropout: float = 0.3,
        excluded_features: Optional[list[str]] = None,
        categorical_list_feature_aggregation_method: str = "sum",
    ) -> "TwoTower":
        """
        Class method for quickly creating an instance of TwoTower with typical types
        of blocks and user-provided parameters.\n
        The item "tower" is a SwiGLU encoder (an MLP with SwiGLU activation),\n
        the user "tower" is a stack of SasRec transformer layers, and the loss is cross-entropy loss.\n
        Embeddings of every feature in both "towers" are aggregated via sum.
        The same features are used in both "towers",
        that is, the features specified in the tensor schema with the exception of `excluded_features`.\n
        To create an instance of TwoTower with other types of blocks, please use the class constructor.

        :param schema: tensor schema object with metainformation about features.
        :param item_features_reader: A class that implements reading features,
            processing them, and converting them to ``torch.Tensor`` for ItemTower.
            You can use ``replay.nn.sequential.twotower.FeaturesReader`` as a standard class.\n
            But you can implement your own feature processing,
            just follow the ``replay.nn.sequential.twotower.FeaturesReaderProtocol`` protocol.
        :param embedding_dim: embedding dimension in both towers. Default: ``192``.
        :param num_heads: number of heads in the user tower SasRec layers. Default: ``4``.
        :param num_blocks: number of blocks in the user tower SasRec layers. Default: ``2``.
        :param max_sequence_length: maximum sequence length in the user tower SasRec layers. Default: ``50``.
        :param dropout: dropout value in both towers. Default: ``0.3``.
        :param excluded_features: A list containing the names of features
            for which you do not need to generate an embedding.
            Entries in this list are expected to be contained in ``schema``.
            Default: ``None``.
        :param categorical_list_feature_aggregation_method: Mode to aggregate tokens
            in the token item representation (categorical list features only).
            Default: ``"sum"``.
        :return: an instance of the TwoTower class.
        """
        from replay.nn.agg import SumAggregator
        from replay.nn.embedding import SequenceEmbedding
        from replay.nn.ffn import SwiGLUEncoder
        from replay.nn.loss import CE
        from replay.nn.mask import DefaultAttentionMask
        from replay.nn.sequential import PositionAwareAggregator, SasRecTransformerLayer

        excluded_features = [
            schema.query_id_feature_name,
            schema.timestamp_feature_name,
            *(excluded_features or []),
        ]
        excluded_features = list(set(excluded_features))

        feature_names = set(schema.names) - set(excluded_features)

        common_aggregator = SumAggregator(embedding_dim=embedding_dim)
        return cls(
            TwoTowerBody(
                schema=schema,
                embedder=SequenceEmbedding(
                    schema=schema,
                    categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
                    excluded_features=excluded_features,
                ),
                attn_mask_builder=DefaultAttentionMask(
                    reference_feature_name=schema.item_id_feature_name,
                    num_heads=num_heads,
                ),
                query_tower_feature_names=feature_names,
                item_tower_feature_names=feature_names,
                query_embedding_aggregator=PositionAwareAggregator(
                    embedding_aggregator=common_aggregator,
                    max_sequence_length=max_sequence_length,
                    dropout=dropout,
                ),
                item_embedding_aggregator=common_aggregator,
                query_encoder=SasRecTransformerLayer(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    num_blocks=num_blocks,
                    dropout=dropout,
                    activation="relu",
                ),
                query_tower_output_normalization=torch.nn.LayerNorm(embedding_dim),
                item_encoder=SwiGLUEncoder(embedding_dim=embedding_dim, hidden_dim=2 * embedding_dim),
                item_features_reader=item_features_reader,
            ),
            loss=CE(ignore_index=schema.item_id_features.item().padding_value),
            context_merger=None,
        )

    def reset_parameters(self) -> None:
        self.body.reset_parameters()

    def get_logits(
        self,
        model_embeddings: torch.Tensor,
        candidates_to_score: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Function for tying the last hidden states of the query "tower" and the set of item embeddings
        from the item "tower" via dot product in the model head.

        :param model_embeddings: last hidden state of the query tower.
        :param candidates_to_score: IDs of items to be scored.
            These IDs are used for obtaining item embeddings from the item tower.
            If set to ``None``, all item embeddings from the item tower will be used.
            Default: ``None``.
        :return: logits.
        """
        item_embeddings: torch.Tensor = self.body.item_tower(candidates_to_score)
        logits: torch.Tensor = self.head(model_embeddings, item_embeddings)
        return logits

    def forward_train(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,
        target_padding_mask: torch.BoolTensor,
    ) -> TrainOutput:
        hidden_states = ()
        query_hidden_states: torch.Tensor = self.body.query_tower(
            feature_tensors,
            padding_mask,
        )
        assert query_hidden_states.dim() == 3
        hidden_states += (query_hidden_states,)

        if self.context_merger is not None:
            query_hidden_states: torch.Tensor = self.context_merger(
                model_hidden_state=query_hidden_states,
                feature_tensors=feature_tensors,
            )
            assert query_hidden_states.dim() == 3
            hidden_states += (query_hidden_states,)

        loss: torch.Tensor = self.loss(
            model_embeddings=query_hidden_states,
            feature_tensors=feature_tensors,
            positive_labels=positive_labels,
            negative_labels=negative_labels,
            padding_mask=padding_mask,
            target_padding_mask=target_padding_mask,
        )

        return TrainOutput(
            loss=loss,
            hidden_states=hidden_states,
        )

    def forward_inference(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
        candidates_to_score: Optional[torch.LongTensor] = None,
    ) -> InferenceOutput:
        hidden_states = ()
        query_hidden_states: torch.Tensor = self.body.query_tower(
            feature_tensors,
            padding_mask,
        )
        assert query_hidden_states.dim() == 3

        hidden_states += (query_hidden_states,)

        if self.context_merger is not None:
            query_hidden_states: torch.Tensor = self.context_merger(
                model_hidden_state=query_hidden_states,
                feature_tensors=feature_tensors,
            )
            assert query_hidden_states.dim() == 3
            hidden_states += (query_hidden_states,)

        last_hidden_state = query_hidden_states[:, -1, :].contiguous()
        logits = self.get_logits(last_hidden_state, candidates_to_score)

        return InferenceOutput(
            logits=logits,
            hidden_states=hidden_states,
        )

    def forward(
        self,
        feature_tensors: TensorMap,
        padding_mask: torch.BoolTensor,
        candidates_to_score: Optional[torch.LongTensor] = None,
        positive_labels: Optional[torch.LongTensor] = None,
        negative_labels: Optional[torch.LongTensor] = None,
        target_padding_mask: Optional[torch.BoolTensor] = None,
    ) -> Union[TrainOutput, InferenceOutput]:
        """
        :param feature_tensors: a dictionary of tensors to generate embeddings.
        :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
            indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding").
            A ``False`` value indicates that the corresponding ``key`` value will be ignored.
        :param candidates_to_score: a tensor containing item IDs
            for which you need to get logits at the inference stage.\n
            **Note:** you must take into account the padding value when creating the tensor.\n
            The tensor participates in calculations only at the inference stage.
            It is not required at the training stage; if it is provided there, it has no effect.\n
            Default: ``None``.
        :param positive_labels: a tensor containing positive labels for calculating the loss.\n
            It is not required at the inference stage; if it is provided there, it has no effect.\n
            Default: ``None``.
        :param negative_labels: a tensor containing negative labels for calculating the loss.\n
            **Note:** Before running, make sure that your loss supports calculations with negative labels.\n
            It is not required at the inference stage; if it is provided there, it has no effect.\n
            Default: ``None``.
        :param target_padding_mask: A mask of shape ``(batch_size, sequence_length, num_positives)``
            indicating elements from ``positive_labels`` to ignore during loss calculation.
            A ``False`` value indicates that the corresponding value will be ignored.\n
            It is not required at the inference stage; if it is provided there, it has no effect.\n
            Default: ``None``.
        :returns: During training, the model returns an object
            of the ``TrainOutput`` container class.
            At the inference stage, an ``InferenceOutput`` object is returned.
        """
        if self.training:
            all(
                map(
                    warning_is_not_none("Variable `{}` is not None. This will have no effect at the training stage."),
                    [(candidates_to_score, "candidates_to_score")],
                )
            )
            return self.forward_train(
                feature_tensors=feature_tensors,
                padding_mask=padding_mask,
                positive_labels=positive_labels,
                negative_labels=negative_labels,
                target_padding_mask=target_padding_mask,
            )

        all(
            map(
                warning_is_not_none("Variable `{}` is not None. This will have no effect at the inference stage."),
                [
                    (positive_labels, "positive_labels"),
                    (negative_labels, "negative_labels"),
                    (target_padding_mask, "target_padding_mask"),
                ],
            )
        )
        return self.forward_inference(
            feature_tensors=feature_tensors, padding_mask=padding_mask, candidates_to_score=candidates_to_score
        )
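
For orientation, below is a minimal usage sketch (not part of the diff) based on the signatures in the new module. It assumes ``tensor_schema`` and an item ``FeaturesReader`` instance (``item_reader``) have been built as in the ``TwoTower`` docstring example, and that ``NUM_UNIQUE_ITEMS`` is the item cardinality; batch shapes and tensor names are illustrative.

    # Hedged sketch, not from the package docs: builds a TwoTower via the
    # from_params() shortcut and runs one inference pass.
    # `tensor_schema`, `item_reader`, and NUM_UNIQUE_ITEMS are assumed to be
    # defined as in the TwoTower class docstring example above. Note that
    # from_params() automatically excludes the schema's query-ID and timestamp
    # features, so a real schema would typically contain them as well.
    import torch

    from replay.nn.sequential.twotower.model import TwoTower

    model = TwoTower.from_params(
        schema=tensor_schema,
        item_features_reader=item_reader,
        embedding_dim=192,
        num_heads=4,
        num_blocks=2,
        max_sequence_length=50,
    )
    model.eval()

    batch_size, seq_len = 8, 50
    feature_tensors = {
        # illustrative batch: padded item-ID sequences
        "item_id": torch.randint(0, NUM_UNIQUE_ITEMS, (batch_size, seq_len)),
    }
    padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

    with torch.no_grad():
        # eval mode dispatches to forward_inference(); the result is an
        # InferenceOutput with logits over all items (or over
        # `candidates_to_score` if that argument is passed)
        output = model(feature_tensors, padding_mask)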