replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/sequential/twotower/model.py
@@ -0,0 +1,674 @@
+ from collections.abc import Sequence
+ from typing import Optional, Protocol, Union
+
+ import torch
+
+ from replay.data.nn import TensorMap, TensorSchema
+ from replay.nn.agg import AggregatorProto
+ from replay.nn.head import EmbeddingTyingHead
+ from replay.nn.loss import LossProto
+ from replay.nn.mask import AttentionMaskProto
+ from replay.nn.normalization import NormalizerProto
+ from replay.nn.output import InferenceOutput, TrainOutput
+ from replay.nn.utils import warning_is_not_none
+
+ from .reader import FeaturesReaderProtocol
+
+
+ class EmbedderProto(Protocol):
+     @property
+     def feature_names(self) -> Sequence[str]: ...
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         feature_names: Optional[Sequence[str]] = None,
+     ) -> TensorMap: ...
+
+     def reset_parameters(self) -> None: ...
+
+
+ class QueryEncoderProto(Protocol):
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         input_embeddings: torch.Tensor,
+         padding_mask: torch.BoolTensor,
+         attention_mask: torch.Tensor,
+     ) -> torch.Tensor: ...
+
+     def reset_parameters(self) -> None: ...
+
+
+ class ItemEncoderProto(Protocol):
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         input_embeddings: torch.Tensor,
+     ) -> torch.Tensor: ...
+
+     def reset_parameters(self) -> None: ...
+
+
+ class QueryTower(torch.nn.Module):
+     """
+     Query Tower of Two-Tower model.
+     """
+
+     def __init__(
+         self,
+         feature_names: Sequence[str],
+         embedder: EmbedderProto,
+         embedding_aggregator: AggregatorProto,
+         attn_mask_builder: AttentionMaskProto,
+         encoder: QueryEncoderProto,
+         output_normalization: NormalizerProto,
+     ):
+         """
+         :param feature_names: sequence of names used in query tower.
+         :param embedder: An object of a class that performs the logic of
+             generating embeddings from an input batch.
+         :param embedding_aggregator: An object of a class that performs
+             the logic of aggregating multiple embeddings of query tower.
+         :param attn_mask_builder: An object of a class that performs the logic of
+             generating an attention mask based on the features and padding mask given to the model.
+         :param encoder: An object of a class that performs the logic of generating
+             a query hidden embedding representation based on
+             features, padding masks, attention mask, and aggregated embedding of ``query_tower_feature_names``.
+             It's supposed to be a transformer.
+         :param output_normalization: An object of a class that performs the logic of
+             normalization of the hidden state obtained from the query encoder.\n
+             For example, it can be a ``torch.nn.LayerNorm`` or ``torch.nn.RMSNorm``.
+         """
+         super().__init__()
+         self.embedder = embedder
+         self.attn_mask_builder = attn_mask_builder
+         self.feature_names = feature_names
+         self.embedding_aggregator = embedding_aggregator
+         self.encoder = encoder
+         self.output_normalization = output_normalization
+
+     def reset_parameters(self) -> None:
+         self.embedding_aggregator.reset_parameters()
+         self.encoder.reset_parameters()
+         self.output_normalization.reset_parameters()
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         :param feature_tensors: a dictionary of tensors to generate embeddings.
+         :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
+             indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding").
+             ``False`` value indicates that the corresponding ``key`` value will be ignored.
+         :returns: The final hidden state.\n
+             Expected shape: ``(batch_size, sequence_length, embedding_dim)``
+         """
+         embeddings: TensorMap = self.embedder(feature_tensors, self.feature_names)
+         agg_emb: torch.Tensor = self.embedding_aggregator(embeddings)
+         assert agg_emb.dim() == 3
+
+         attn_mask = self.attn_mask_builder(feature_tensors, padding_mask)
+
+         hidden_state: torch.Tensor = self.encoder(
+             feature_tensors=feature_tensors,
+             input_embeddings=agg_emb,
+             padding_mask=padding_mask,
+             attention_mask=attn_mask,
+         )
+         assert agg_emb.size() == hidden_state.size()
+
+         hidden_state = self.output_normalization(hidden_state)
+         return hidden_state
+
+
+ class ItemTower(torch.nn.Module):
+     """
+     Item Tower of Two-Tower model.
+
+     **Note**: ItemTower loads feature tensors of all items into memory.
+     """
+
+     def __init__(
+         self,
+         schema: TensorSchema,
+         item_features_reader: FeaturesReaderProtocol,
+         feature_names: Sequence[str],
+         embedder: EmbedderProto,
+         embedding_aggregator: AggregatorProto,
+         encoder: ItemEncoderProto,
+     ):
+         """
+         :param schema: tensor schema object with metainformation about features.
+         :param item_features_reader: A class that implements reading features,
+             processing them, and converting them to ``torch.Tensor`` for ItemTower.
+             You can use ``replay.nn.sequential.twotower.FeaturesReader`` as a standard class.\n
+             But you can implement your own feature processing,
+             just follow the ``replay.nn.sequential.twotower.FeaturesReaderProtocol`` protocol.
+         :param feature_names: sequence of names used in item tower.
+         :param embedder: An object of a class that performs the logic of
+             generating embeddings from an input batch.
+         :param embedding_aggregator: An object of a class that performs
+             the logic of aggregating multiple embeddings of item tower.
+         :param encoder: An object of a class that performs the logic of generating
+             an item hidden embedding representation based on
+             features and aggregated embeddings of ``item_tower_feature_names``.
+             Item encoder uses item reference which is created based on ``item_features_reader``.
+         """
+         super().__init__()
+         self.embedder = embedder
+         self.feature_names = feature_names
+         self.embedding_aggregator = embedding_aggregator
+         self.encoder = encoder
+
+         for feature_name in schema:
+             if feature_name not in self.feature_names:
+                 continue
+
+             self.register_buffer(f"item_reference_{feature_name}", item_features_reader[feature_name])
+
+         self.cache = None
+
+     def reset_parameters(self) -> None:
+         self.embedding_aggregator.reset_parameters()
+         self.encoder.reset_parameters()
+
+     def get_feature_buffer(self, feature_name: str) -> torch.Tensor:
+         buffer_name = f"item_reference_{feature_name}"
+         return self.get_buffer(buffer_name)
+
+     def forward(
+         self,
+         candidates_to_score: Optional[torch.LongTensor] = None,
+     ):
+         """
+         :param candidates_to_score: IDs of items used for obtaining item embeddings from the item tower.
+             If set to ``None``, all item embeddings from the item tower will be returned.
+             Default: ``None``.
+         :return: item embeddings.\n
+             Expected shape:\n
+             - ``(candidates_to_score, embedding_dim)``,
+             - ``(items_num, embedding_dim)`` if ``candidates_to_score`` is ``None``.
+         """
+         if self.training:
+             self.cache = None
+
+         if not self.training and self.cache is not None:
+             if candidates_to_score is None:
+                 return self.cache
+             return self.cache[candidates_to_score]
+
+         if candidates_to_score is None:
+             feature_tensors = {
+                 feature_name: self.get_feature_buffer(feature_name) for feature_name in self.feature_names
+             }
+         else:
+             feature_tensors = {
+                 feature_name: self.get_feature_buffer(feature_name)[candidates_to_score]
+                 for feature_name in self.feature_names
+             }
+
+         embeddings: TensorMap = self.embedder(feature_tensors, self.feature_names)
+         agg_emb: torch.Tensor = self.embedding_aggregator(embeddings)
+
+         hidden_state: torch.Tensor = self.encoder(
+             feature_tensors=feature_tensors,
+             input_embeddings=agg_emb,
+         )
+         assert agg_emb.size() == hidden_state.size()
+
+         if not self.training and self.cache is None and candidates_to_score is None:
+             self.cache = hidden_state
+         return hidden_state
+
+
+ class TwoTowerBody(torch.nn.Module):
+     """
+     Foundation for Two-Tower model which creates query "tower" and item "tower".\n
+
+     To use the Two-Tower model, an instance of this class should be passed into `TwoTower`_ with any loss
+     from `Losses`_.
+     """
+
+     def __init__(
+         self,
+         schema: TensorSchema,
+         embedder: EmbedderProto,
+         attn_mask_builder: AttentionMaskProto,
+         query_tower_feature_names: Sequence[str],
+         item_tower_feature_names: Sequence[str],
+         query_embedding_aggregator: AggregatorProto,
+         item_embedding_aggregator: AggregatorProto,
+         query_encoder: QueryEncoderProto,
+         query_tower_output_normalization: NormalizerProto,
+         item_encoder: ItemEncoderProto,
+         item_features_reader: FeaturesReaderProtocol,
+     ):
+         """
+         :param schema: tensor schema object with metainformation about features.
+         :param embedder: An object of a class that performs the logic of
+             generating embeddings from an input batch.\n
+             The same object is used to generate embeddings in different towers.
+         :param query_tower_feature_names: sequence of names used in query tower.
+         :param item_tower_feature_names: sequence of names used in item tower.
+         :param query_embedding_aggregator: An object of a class that performs
+             the logic of aggregating multiple embeddings of query tower.
+         :param item_embedding_aggregator: An object of a class that performs
+             the logic of aggregating multiple embeddings of item tower.
+         :param query_encoder: An object of a class that performs the logic of generating
+             a query hidden embedding representation based on
+             features, padding masks, attention mask, and aggregated embedding of ``query_tower_feature_names``.
+             It's supposed to be a transformer.
+         :param query_tower_output_normalization: An object of a class that performs the logic of
+             normalization of the hidden state obtained from the query encoder.\n
+             For example, it can be a ``torch.nn.LayerNorm`` or ``torch.nn.RMSNorm``.
+         :param attn_mask_builder: An object of a class that performs the logic of
+             generating an attention mask based on the features and padding mask given to the model.
+         :param item_encoder: An object of a class that performs the logic of generating
+             an item hidden embedding representation based on
+             features and aggregated embeddings of ``item_tower_feature_names``.
+             Item encoder uses item reference which is created based on ``item_features_reader``.
+         :param item_features_reader: A class that implements reading features,
+             processing them, and converting them to ``torch.Tensor`` for ItemTower.
+             You can use ``replay.nn.sequential.twotower.FeaturesReader`` as a standard class.\n
+             But you can implement your own feature processing,
+             just follow the ``replay.nn.sequential.twotower.FeaturesReaderProtocol`` protocol.
+
+         """
+         super().__init__()
+         self.embedder = embedder
+         feature_names_union = set(query_tower_feature_names) | set(item_tower_feature_names)
+         feature_names_not_in_emb = feature_names_union - set(self.embedder.feature_names)
+         if len(feature_names_not_in_emb) != 0:
+             msg = f"Feature names found that embedder does not support {list(feature_names_not_in_emb)}"
+             raise ValueError(msg)
+
+         self.query_tower = QueryTower(
+             query_tower_feature_names,
+             embedder,
+             query_embedding_aggregator,
+             attn_mask_builder,
+             query_encoder,
+             query_tower_output_normalization,
+         )
+         self.item_tower = ItemTower(
+             schema,
+             item_features_reader,
+             item_tower_feature_names,
+             embedder,
+             item_embedding_aggregator,
+             item_encoder,
+         )
+
+     def reset_parameters(self) -> None:
+         self.embedder.reset_parameters()
+         self.query_tower.reset_parameters()
+         self.item_tower.reset_parameters()
+
+
+ class ContextMergerProto(Protocol):
+     def forward(
+         self,
+         model_hidden_state: torch.Tensor,
+         feature_tensors: TensorMap,
+     ) -> torch.Tensor: ...
+
+     def reset_parameters(self) -> None: ...
+
+
+ class TwoTower(torch.nn.Module):
+     """
+     Implementation of a generic Two-Tower architecture with two independent "towers" (encoders)
+     which encode separate inputs. In recommender systems these are typically a query tower and an item tower.
+     The output hidden states of each "tower" are fused via dot product in the model head.
+
+     Source paper: https://doi.org/10.1145/3366424.3386195
+
+     Example:
+
+     .. code-block:: python
+
+         from replay.data import FeatureHint, FeatureSource, FeatureType
+         from replay.data.nn import TensorFeatureInfo, TensorFeatureSource, TensorSchema
+         from replay.nn.agg import SumAggregator
+         from replay.nn.embedding import SequenceEmbedding
+         from replay.nn.ffn import SwiGLUEncoder
+         from replay.nn.mask import DefaultAttentionMask
+         from replay.nn.loss import CESampled
+         from replay.nn.sequential import PositionAwareAggregator, SasRecTransformerLayer
+         from replay.nn.sequential.twotower import FeaturesReader
+
+         tensor_schema = TensorSchema(
+             [
+                 TensorFeatureInfo(
+                     "item_id",
+                     is_seq=True,
+                     feature_type=FeatureType.CATEGORICAL,
+                     embedding_dim=256,
+                     padding_value=NUM_UNIQUE_ITEMS,
+                     cardinality=NUM_UNIQUE_ITEMS,
+                     feature_hint=FeatureHint.ITEM_ID,
+                     feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
+                 ),
+             ]
+         )
+
+         common_aggregator = SumAggregator(embedding_dim=256)
+
+         body = TwoTowerBody(
+             schema=tensor_schema,
+             embedder=SequenceEmbedding(schema=tensor_schema),
+             attn_mask_builder=DefaultAttentionMask(
+                 reference_feature_name=tensor_schema.item_id_feature_name,
+                 num_heads=2,
+             ),
+             query_tower_feature_names=tensor_schema.names,
+             item_tower_feature_names=tensor_schema.names,
+             query_embedding_aggregator=PositionAwareAggregator(
+                 embedding_aggregator=common_aggregator,
+                 max_sequence_length=100,
+                 dropout=0.2,
+             ),
+             item_embedding_aggregator=common_aggregator,
+             query_encoder=SasRecTransformerLayer(
+                 embedding_dim=256,
+                 num_heads=2,
+                 num_blocks=2,
+                 dropout=0.3,
+                 activation="relu",
+             ),
+             query_tower_output_normalization=torch.nn.LayerNorm(256),
+             item_encoder=SwiGLUEncoder(embedding_dim=256, hidden_dim=2*256),
+             item_features_reader=FeaturesReader(
+                 schema=tensor_schema,
+                 metadata={"item_id": {}},
+                 path="item_features.parquet",
+             ),
+         )
+         twotower = TwoTower(
+             body=body,
+             loss=CESampled(ignore_index=tensor_schema["item_id"].padding_value),
+         )
+
+     """
+
+     def __init__(
+         self,
+         body: TwoTowerBody,
+         loss: LossProto,
+         context_merger: Optional[ContextMergerProto] = None,
+     ):
+         """
+         :param body: An instance of TwoTowerBody.
+         :param loss: An object of a class that performs loss calculation
+             based on hidden states from the model, positive and optionally negative labels.
+         :param context_merger: An object of a class that fuses the query encoder hidden state
+             with input feature tensors.
+             Default: ``None``.
+         """
+         super().__init__()
+         self.body = body
+         self.head = EmbeddingTyingHead()
+         self.loss = loss
+         self.context_merger = context_merger
+         self.loss.logits_callback = self.get_logits
+
+         self.reset_parameters()
+
+     @classmethod
+     def from_params(
+         cls,
+         schema: TensorSchema,
+         item_features_reader: FeaturesReaderProtocol,
+         embedding_dim: int = 192,
+         num_heads: int = 4,
+         num_blocks: int = 2,
+         max_sequence_length: int = 50,
+         dropout: float = 0.3,
+         excluded_features: Optional[list[str]] = None,
+         categorical_list_feature_aggregation_method: str = "sum",
+     ) -> "TwoTower":
+         """
+         Class method for quickly creating an instance of TwoTower with typical types
+         of blocks and user provided parameters.\n
+         The item "tower" is a SwiGLU encoder (MLP with SwiGLU activation),\n
+         the user "tower" is a stack of SasRec transformer layers, and the loss is Cross-Entropy loss.\n
+         Embeddings of every feature in both "towers" are aggregated via sum.
+         The same features are used in both "towers",
+         that is, the features specified in the tensor schema with the exception of `excluded_features`.\n
+         To create an instance of TwoTower with other types of blocks, please use the class constructor.
+
+         :param schema: tensor schema object with metainformation about features.
+         :param item_features_reader: A class that implements reading features,
+             processing them, and converting them to ``torch.Tensor`` for ItemTower.
+             You can use ``replay.nn.sequential.twotower.FeaturesReader`` as a standard class.\n
+             But you can implement your own feature processing,
+             just follow the ``replay.nn.sequential.twotower.FeaturesReaderProtocol`` protocol.
+         :param embedding_dim: embeddings dimension in both towers. Default: ``192``.
+         :param num_heads: number of heads in user tower SasRec layers. Default: ``4``.
+         :param num_blocks: number of blocks in user tower SasRec layers. Default: ``2``.
+         :param max_sequence_length: maximum sequence length in user tower SasRec layers. Default: ``50``.
+         :param dropout: dropout value in both towers. Default: ``0.3``.
+         :param excluded_features: A list containing the names of features
+             for which you do not need to generate an embedding.
+             Names from this list are expected to be contained in ``schema``.
+             Default: ``None``.
+         :param categorical_list_feature_aggregation_method: Mode to aggregate tokens
+             in token item representation (categorical list only).
+             Default: ``"sum"``.
+         :return: an instance of TwoTower class.
+         """
+         from replay.nn.agg import SumAggregator
+         from replay.nn.embedding import SequenceEmbedding
+         from replay.nn.ffn import SwiGLUEncoder
+         from replay.nn.loss import CE
+         from replay.nn.mask import DefaultAttentionMask
+         from replay.nn.sequential import PositionAwareAggregator, SasRecTransformerLayer
+
+         excluded_features = [
+             schema.query_id_feature_name,
+             schema.timestamp_feature_name,
+             *(excluded_features or []),
+         ]
+         excluded_features = list(set(excluded_features))
+
+         feature_names = set(schema.names) - set(excluded_features)
+
+         common_aggregator = SumAggregator(embedding_dim=embedding_dim)
+         return cls(
+             TwoTowerBody(
+                 schema=schema,
+                 embedder=SequenceEmbedding(
+                     schema=schema,
+                     categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
+                     excluded_features=excluded_features,
+                 ),
+                 attn_mask_builder=DefaultAttentionMask(
+                     reference_feature_name=schema.item_id_feature_name,
+                     num_heads=num_heads,
+                 ),
+                 query_tower_feature_names=feature_names,
+                 item_tower_feature_names=feature_names,
+                 query_embedding_aggregator=PositionAwareAggregator(
+                     embedding_aggregator=common_aggregator,
+                     max_sequence_length=max_sequence_length,
+                     dropout=dropout,
+                 ),
+                 item_embedding_aggregator=common_aggregator,
+                 query_encoder=SasRecTransformerLayer(
+                     embedding_dim=embedding_dim,
+                     num_heads=num_heads,
+                     num_blocks=num_blocks,
+                     dropout=dropout,
+                     activation="relu",
+                 ),
+                 query_tower_output_normalization=torch.nn.LayerNorm(embedding_dim),
+                 item_encoder=SwiGLUEncoder(embedding_dim=embedding_dim, hidden_dim=2 * embedding_dim),
+                 item_features_reader=item_features_reader,
+             ),
+             loss=CE(ignore_index=schema.item_id_features.item().padding_value),
+             context_merger=None,
+         )
+
+     def reset_parameters(self) -> None:
+         self.body.reset_parameters()
+
+     def get_logits(
+         self,
+         model_embeddings: torch.Tensor,
+         candidates_to_score: Optional[torch.LongTensor] = None,
+     ) -> torch.Tensor:
+         """
+         Ties the last hidden states of the query "tower" with the set of item embeddings from the item "tower"
+         via dot product in the model head.
+
+         :param model_embeddings: last hidden state of query tower.
+         :param candidates_to_score: IDs of items to be scored.
+             These IDs are used for obtaining item embeddings from item tower.
+             If set to ``None``, all item embeddings from the item tower will be used.
+             Default: ``None``.
+         :return: logits.
+         """
+         item_embeddings: torch.Tensor = self.body.item_tower(candidates_to_score)
+         logits: torch.Tensor = self.head(model_embeddings, item_embeddings)
+         return logits
+
+     def forward_train(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,
+         target_padding_mask: torch.BoolTensor,
+     ) -> TrainOutput:
+         hidden_states = ()
+         query_hidden_states: torch.Tensor = self.body.query_tower(
+             feature_tensors,
+             padding_mask,
+         )
+         assert query_hidden_states.dim() == 3
+         hidden_states += (query_hidden_states,)
+
+         if self.context_merger is not None:
+             query_hidden_states: torch.Tensor = self.context_merger(
+                 model_hidden_state=query_hidden_states,
+                 feature_tensors=feature_tensors,
+             )
+             assert query_hidden_states.dim() == 3
+             hidden_states += (query_hidden_states,)
+
+         loss: torch.Tensor = self.loss(
+             model_embeddings=query_hidden_states,
+             feature_tensors=feature_tensors,
+             positive_labels=positive_labels,
+             negative_labels=negative_labels,
+             padding_mask=padding_mask,
+             target_padding_mask=target_padding_mask,
+         )
+
+         return TrainOutput(
+             loss=loss,
+             hidden_states=hidden_states,
+         )
+
+     def forward_inference(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+         candidates_to_score: Optional[torch.LongTensor] = None,
+     ) -> InferenceOutput:
+         hidden_states = ()
+         query_hidden_states: torch.Tensor = self.body.query_tower(
+             feature_tensors,
+             padding_mask,
+         )
+         assert query_hidden_states.dim() == 3
+
+         hidden_states += (query_hidden_states,)
+
+         if self.context_merger is not None:
+             query_hidden_states: torch.Tensor = self.context_merger(
+                 model_hidden_state=query_hidden_states,
+                 feature_tensors=feature_tensors,
+             )
+             assert query_hidden_states.dim() == 3
+             hidden_states += (query_hidden_states,)
+
+         last_hidden_state = query_hidden_states[:, -1, :].contiguous()
+         logits = self.get_logits(last_hidden_state, candidates_to_score)
+
+         return InferenceOutput(
+             logits=logits,
+             hidden_states=hidden_states,
+         )
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+         candidates_to_score: Optional[torch.LongTensor] = None,
+         positive_labels: Optional[torch.LongTensor] = None,
+         negative_labels: Optional[torch.LongTensor] = None,
+         target_padding_mask: Optional[torch.BoolTensor] = None,
+     ) -> Union[TrainOutput, InferenceOutput]:
+         """
+         :param feature_tensors: a dictionary of tensors to generate embeddings.
+         :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
+             indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding").
+             ``False`` value indicates that the corresponding ``key`` value will be ignored.
+         :param candidates_to_score: a tensor containing item IDs
+             for which you need to get logits at the inference stage.\n
+             **Note:** you must take into account the padding value when creating the tensor.\n
+             The tensor participates in calculations only at the inference stage.
+             You don't have to pass this argument at the training stage,
+             but if it is passed, it has no effect.\n
+             Default: ``None``.
+         :param positive_labels: a tensor containing positive labels for calculating the loss.\n
+             You don't have to pass this argument at the inference stage,
+             but if it is passed, it has no effect.\n
+             Default: ``None``.
+         :param negative_labels: a tensor containing negative labels for calculating the loss.\n
+             **Note:** Before running, make sure that your loss supports calculations with negative labels.\n
+             You don't have to pass this argument at the inference stage,
+             but if it is passed, it has no effect.\n
+             Default: ``None``.
+         :param target_padding_mask: A mask of shape ``(batch_size, sequence_length, num_positives)``
+             indicating elements from ``positive_labels`` to ignore during loss calculation.
+             ``False`` value indicates that the corresponding value will be ignored.\n
+             You don't have to pass this argument at the inference stage,
+             but if it is passed, it has no effect.\n
+             Default: ``None``.
+         :returns: During training, the model will return an object
+             of the ``TrainOutput`` container class.
+             At the inference stage, the ``InferenceOutput`` class will be returned.
+         """
+         if self.training:
+             all(
+                 map(
+                     warning_is_not_none("Variable `{}` is not None. This will have no effect at the training stage."),
+                     [(candidates_to_score, "candidates_to_score")],
+                 )
+             )
+             return self.forward_train(
+                 feature_tensors=feature_tensors,
+                 padding_mask=padding_mask,
+                 positive_labels=positive_labels,
+                 negative_labels=negative_labels,
+                 target_padding_mask=target_padding_mask,
+             )
+
+         all(
+             map(
+                 warning_is_not_none("Variable `{}` is not None. This will have no effect at the inference stage."),
+                 [
+                     (positive_labels, "positive_labels"),
+                     (negative_labels, "negative_labels"),
+                     (target_padding_mask, "target_padding_mask"),
+                 ],
+             )
+         )
+         return self.forward_inference(
+             feature_tensors=feature_tensors, padding_mask=padding_mask, candidates_to_score=candidates_to_score
+         )
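
Taken together, FeaturesReader, TwoTowerBody, and TwoTower above form the public surface of the new two-tower model, with TwoTower.from_params as the documented shortcut that wires up the default blocks (SequenceEmbedding, SasRec transformer layers, a SwiGLU item encoder, and CE loss). A minimal usage sketch of that shortcut, assuming TwoTower is exported from replay.nn.sequential.twotower alongside FeaturesReader, and reusing the placeholder NUM_UNIQUE_ITEMS and the hypothetical item_features.parquet path from the docstring example; this is illustrative only and not part of the package diff:

import torch

from replay.data import FeatureHint, FeatureSource, FeatureType
from replay.data.nn import TensorFeatureInfo, TensorFeatureSource, TensorSchema
from replay.nn.sequential.twotower import FeaturesReader, TwoTower

NUM_UNIQUE_ITEMS = 10_000  # placeholder cardinality, as in the docstring example above

tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            "item_id",
            is_seq=True,
            feature_type=FeatureType.CATEGORICAL,
            embedding_dim=192,
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS,
            feature_hint=FeatureHint.ITEM_ID,
            feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")],
        ),
    ]
)

# from_params assembles SequenceEmbedding, SasRecTransformerLayer, SwiGLUEncoder and CE loss
# with the defaults documented above (embedding_dim=192, num_heads=4, num_blocks=2, ...).
model = TwoTower.from_params(
    schema=tensor_schema,
    item_features_reader=FeaturesReader(
        schema=tensor_schema,
        metadata={"item_id": {}},
        path="item_features.parquet",  # hypothetical per-item feature file
    ),
    max_sequence_length=100,
)

# Inference: feature_tensors maps feature names to (batch_size, sequence_length) tensors,
# padding_mask marks real (non-padding) positions with True; candidates_to_score=None scores all items.
model.eval()
batch_size, seq_len = 2, 100
feature_tensors = {"item_id": torch.randint(0, NUM_UNIQUE_ITEMS, (batch_size, seq_len))}
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
with torch.no_grad():
    output = model(feature_tensors=feature_tensors, padding_mask=padding_mask)
# output is an InferenceOutput; per the docstrings, output.logits has shape (batch_size, items_num)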