replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/loss/login_ce.py ADDED
@@ -0,0 +1,373 @@
+ from typing import Callable, Optional, TypedDict
+
+ import torch
+
+ from replay.data.nn import TensorMap
+
+ from .base import SampledLossBase, mask_negative_logits
+
+
+ class LogInCESampledOutput(TypedDict):
+     positive_logits: torch.Tensor
+     negative_logits: torch.Tensor
+     positive_labels: torch.LongTensor
+     negative_labels: torch.LongTensor
+     target_padding_mask: torch.BoolTensor
+
+
+ class LogInCEBase(SampledLossBase):
+     def get_sampled_logits(
+         self,
+         model_embeddings: torch.Tensor,
+         positive_labels: torch.LongTensor,  # [batch_size, seq_len, num_positives]
+         negative_labels: torch.LongTensor,  # [num_negatives] or [batch_size, seq_len, num_negatives]
+         target_padding_mask: torch.BoolTensor,  # [batch_size, seq_len, num_positives]
+     ) -> LogInCESampledOutput:
+         """
+         The function of calculating positive and negative logits in LogInCE losses.
+         Based on the embeddings from the model, positive and negative labels.
+
+         The function supports the calculation of logits for the case of multi-positive labels
+         (there are several labels for each position in the sequence).
+
+         :param model_embeddings: Embeddings from the model. This is usually the last hidden state.
+             Expected shape: ``(batch_size, sequence_length, embedding_dim)``
+         :param positive_labels: a tensor containing labels with positive events.
+             Expected shape: ``(batch_size, sequence_length, num_positives)``
+         :param negative_labels: a tensor containing labels with negative events.
+             Expected shape:
+             - ``(batch_size, sequence_length, num_negatives)``
+             - ``(batch_size, num_negatives)``
+             - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+         :param target_padding_mask: Padding mask for ``positive_labels`` (targets).
+             ``False`` value indicates that the corresponding ``key`` value will be ignored.
+             Expected shape: ``(batch_size, sequence_length, num_positives)``
+
+         :returns: LogInCESampledOutput. A dictionary containing positive and negative logits with labels.
+         """
+         ################## SHAPE CHECKING STAGE START ##################
+         batch_size, seq_len, num_positives = positive_labels.size()
+         assert target_padding_mask.size() == (batch_size, seq_len, num_positives)
+         num_negatives = negative_labels.size(-1)
+
+         if negative_labels.size() == (batch_size, num_negatives):
+             # [batch_size, num_negatives] -> [batch_size, 1, num_negatives]
+             negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)
+
+         assert negative_labels.size() == (batch_size, seq_len, num_negatives) or negative_labels.dim() == 1
+         ################## SHAPE CHECKING STAGE END ##################
+
+         # Get output embedding for every user event
+         embedding_dim = model_embeddings.size(-1)
+         assert model_embeddings.size() == (batch_size, seq_len, embedding_dim)
+
+         # [batch_size, seq_len, num_positives] -> [batch_size, seq_len]
+         masked_target_padding_mask: torch.BoolTensor = target_padding_mask.sum(-1).bool()
+         masked_batch_size = masked_target_padding_mask.sum().item()
+
+         # Apply target mask
+         # [batch_size, seq_len, emb_dim] -> [masked_batch_size, emb_dim]
+         model_embeddings = model_embeddings[masked_target_padding_mask]
+         assert model_embeddings.size() == (masked_batch_size, embedding_dim)
+
+         # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+         positive_labels = positive_labels[masked_target_padding_mask]
+         assert positive_labels.size() == (masked_batch_size, num_positives)
+
+         if negative_labels.dim() > 1:  # pragma: no cover
+             # [batch_size, seq_len, num_negatives] -> [masked_batch_size, num_negatives]
+             negative_labels = negative_labels[masked_target_padding_mask]
+             assert negative_labels.size() == (masked_batch_size, num_negatives)
+
+         positive_logits = self.logits_callback(model_embeddings, positive_labels)
+         assert positive_logits.size() == (masked_batch_size, num_positives)
+
+         negative_logits = self.logits_callback(model_embeddings, negative_labels)
+         assert negative_logits.size() == (masked_batch_size, num_negatives)
+
+         # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+         target_padding_mask = target_padding_mask[masked_target_padding_mask]
+         assert target_padding_mask.size() == (masked_batch_size, num_positives)
+
+         return {
+             "positive_logits": positive_logits,
+             "negative_logits": negative_logits,
+             "positive_labels": positive_labels,
+             "negative_labels": negative_labels,
+             "target_padding_mask": target_padding_mask,
+         }
+
+
+ class LogInCE(LogInCEBase):
+     """
+     LogInCE loss.
+
+     .. math::
+
+         L_{\\text{InfoNCE}} = -\\log \\frac{\\sum_{p \\in P}
+         \\exp(\\mathrm{sim}(q, p))}{\\sum_{p \\in P}
+         \\exp(\\mathrm{sim}(q, p)) + \\sum_{n \\in N} \\exp(\\mathrm{sim}(q, n))},
+
+     where q -- query embedding, P -- set of positive logits, N -- set of negative logits,
+     :math:`sim(\\cdot, \\cdot)` -- similarity function.
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(
+         self,
+         cardinality: int,
+         log_epsilon: float = 1e-6,
+         clamp_border: float = 100.0,
+         negative_labels_ignore_index: int = -100,
+     ):
+         """
+         :param cardinality: number of unique items in vocabulary (catalog).
+             The specified cardinality value must not take into account the padding value.
+         :param log_epsilon: correction to avoid zero in the logarithm during loss calculation.
+             Default: ``1e-6``.
+         :param clamp_border: upper bound for clamping the loss tensor; the lower bound will be set to ``-clamp_border``.
+             Default: ``100.0``.
+         :param negative_labels_ignore_index: padding value for negative labels.
+             This may be the case when negative labels
+             are formed at the preprocessing level, rather than by the negative sampler.
+             The index is ignored and does not contribute to the loss.
+             Default: ``-100``.
+         """
+         super().__init__()
+         self.cardinality = cardinality
+         self.log_epsilon = log_epsilon
+         self.clamp_border = clamp_border
+         self.negative_labels_ignore_index = negative_labels_ignore_index
+         self._logits_callback = None
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         """
+         Property for calling a function for the logits computation.\n
+
+         This function is expected to receive the model's last hidden state
+         and optionally item IDs, and return a logits tensor.
+
+         It is expected that the corresponding head model method will be used as this function,
+         for example, the ``get_logits`` method of the ``SasRec`` class.
+
+         :return: callable function.
+         """
+         if self._logits_callback is None:
+             msg = "The callback for getting logits is not defined"
+             raise AttributeError(msg)
+         return self._logits_callback
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None:
+         self._logits_callback = func
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,  # noqa: ARG002
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, positive_labels, target_padding_mask)
+         **Note**: During the forward pass, the whole catalog of items is used as negatives.
+         Then, negative logits corresponding to positions where negative labels
+         coincide with positive ones are masked.
+
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: ground truth labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+         """
+         all_negative_labels = torch.arange(
+             self.cardinality,
+             dtype=torch.long,
+             device=positive_labels.device,
+         )
+         sampled = self.get_sampled_logits(
+             model_embeddings,
+             positive_labels,
+             all_negative_labels,
+             target_padding_mask,
+         )
+         positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
+         negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
+         positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
+         all_negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]
+         target_padding_mask = sampled["target_padding_mask"]  # [masked_batch_size, num_positives]
+
+         # [masked_batch_size, num_negatives] - assign low values to some negative logits
+         negative_logits = mask_negative_logits(
+             negative_logits,
+             all_negative_labels,
+             positive_labels,
+             self.negative_labels_ignore_index,
+         )
+
+         max_values = torch.max(
+             positive_logits.max(-1, keepdim=True).values,
+             negative_logits.max(-1, keepdim=True).values,
+         )  # [masked_batch_size, 1]
+         positive_logits = positive_logits - max_values
+         negative_logits = negative_logits - max_values
+
+         positive_logits = torch.exp(positive_logits)
+         positive_logits = positive_logits * target_padding_mask
+         # [masked_batch_size, num_positives] -> [masked_batch_size]
+         positive_logits = positive_logits.sum(-1)
+
+         negative_logits = torch.exp(negative_logits)
+         # [masked_batch_size, num_negatives] -> [masked_batch_size]
+         negative_logits = negative_logits.sum(-1)
+
+         probabilities = positive_logits / (positive_logits + negative_logits)
+         loss = -torch.clamp(
+             torch.log(probabilities + self.log_epsilon),
+             -self.clamp_border,
+             self.clamp_border,
+         )
+         return loss.mean()
+
+
+ class LogInCESampled(LogInCEBase):
+     """
+     Sampled version of LogInCE (Log InfoNCE) loss (with sampled negative items).
+
+     .. math::
+
+         L_{\\text{InfoNCE}} = -\\log \\frac{\\sum_{p \\in P} \\exp(\\mathrm{sim}(q, p))}{\\sum_{p \\in P}
+         \\exp(\\mathrm{sim}(q, p)) + \\sum_{n \\in N_{\\text{sampled}}} \\exp(\\mathrm{sim}(q, n))},
+
+     where q -- query embedding, P -- set of positive logits, :math:`N_{\\text{sampled}}` -- set of negative logits,
+     :math:`sim(\\cdot, \\cdot)` -- similarity function.\n
+     Same as ``LogInCE``; the difference is in the set of negatives.
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(
+         self,
+         log_epsilon: float = 1e-6,
+         clamp_border: float = 100.0,
+         negative_labels_ignore_index: int = -100,
+     ):
+         """
+         :param log_epsilon: correction to avoid zero in the logarithm during loss calculation.
+             Default: ``1e-6``.
+         :param clamp_border: upper bound for clamping the loss tensor; the lower bound will be set to ``-clamp_border``.
+             Default: ``100.0``.
+         :param negative_labels_ignore_index: padding value for negative labels.
+             This may be the case when negative labels
+             are formed at the preprocessing level, rather than by the negative sampler.
+             The index is ignored and does not contribute to the loss.
+             Default: ``-100``.
+         """
+         super().__init__()
+         self.log_epsilon = log_epsilon
+         self.clamp_border = clamp_border
+         self.negative_labels_ignore_index = negative_labels_ignore_index
+         self._logits_callback = None
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         """
+         Property for calling a function for the logits computation.\n
+
+         This function is expected to receive the model's last hidden state
+         and optionally item IDs, and return a logits tensor.
+
+         It is expected that the corresponding head model method will be used as this function,
+         for example, the ``get_logits`` method of the ``SasRec`` class.
+
+         :return: callable function.
+         """
+         if self._logits_callback is None:
+             msg = "The callback for getting logits is not defined"
+             raise AttributeError(msg)
+         return self._logits_callback
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None:
+         self._logits_callback = func
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)
+
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param negative_labels: labels of sampled negative events.
+
+             Expected shape:
+             - ``(batch_size, sequence_length, num_negatives)``
+             - ``(batch_size, num_negatives)``
+             - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+
+         :return: computed loss value.
+         """
+         sampled = self.get_sampled_logits(
+             model_embeddings,
+             positive_labels,
+             negative_labels,
+             target_padding_mask,
+         )
+         positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
+         negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
+         positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
+         negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]
+         target_padding_mask = sampled["target_padding_mask"]  # [masked_batch_size, num_positives]
+
+         # [masked_batch_size, num_negatives] - assign low values to some negative logits
+         negative_logits = mask_negative_logits(
+             negative_logits,
+             negative_labels,
+             positive_labels,
+             self.negative_labels_ignore_index,
+         )
+
+         max_values = torch.max(
+             positive_logits.max(-1, keepdim=True).values,
+             negative_logits.max(-1, keepdim=True).values,
+         )  # [masked_batch_size, 1]
+         positive_logits = positive_logits - max_values
+         negative_logits = negative_logits - max_values
+
+         positive_logits = torch.exp(positive_logits)
+         positive_logits = positive_logits * target_padding_mask
+         # [masked_batch_size, num_positives] -> [masked_batch_size]
+         positive_logits = positive_logits.sum(-1)
+
+         negative_logits = torch.exp(negative_logits)
+         # [masked_batch_size, num_negatives] -> [masked_batch_size]
+         negative_logits = negative_logits.sum(-1)
+
+         probabilities = positive_logits / (positive_logits + negative_logits)
+         loss = -torch.clamp(
+             torch.log(probabilities + self.log_epsilon),
+             -self.clamp_border,
+             self.clamp_border,
+         )
+         return loss.mean()
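The losses above rely on an external callback that maps (masked) hidden states and candidate item IDs to logits. As a minimal sketch of how ``LogInCESampled`` might be wired up, assuming a plain dot-product head over an item embedding table; ``item_table``, ``dot_product_logits``, and all sizes are illustrative, not part of the package API (in practice the callback would typically be a model method such as ``get_logits``):

import torch

from replay.nn.loss.login_ce import LogInCESampled

num_items, emb_dim = 100, 16
item_table = torch.nn.Embedding(num_items + 1, emb_dim, padding_idx=0)  # hypothetical item embeddings, 0 = <PAD>


def dot_product_logits(embeddings: torch.Tensor, item_ids: torch.Tensor) -> torch.Tensor:
    # embeddings: (masked_batch_size, emb_dim);
    # item_ids: (masked_batch_size, k) per-position candidates, or (k,) candidates shared by the whole batch.
    candidate_emb = item_table(item_ids)
    if item_ids.dim() == 1:
        return embeddings @ candidate_emb.T  # (masked_batch_size, k)
    return torch.einsum("bd,bkd->bk", embeddings, candidate_emb)  # (masked_batch_size, k)


loss_fn = LogInCESampled()
loss_fn.logits_callback = dot_product_logits

batch_size, seq_len, num_positives, num_negatives = 2, 5, 1, 4
model_embeddings = torch.randn(batch_size, seq_len, emb_dim)
positive_labels = torch.randint(1, num_items + 1, (batch_size, seq_len, num_positives))
negative_labels = torch.randint(1, num_items + 1, (num_negatives,))  # same negatives for the entire batch
target_padding_mask = torch.ones(batch_size, seq_len, num_positives, dtype=torch.bool)

loss = loss_fn(
    model_embeddings,
    {},  # feature_tensors: not used by this loss
    positive_labels,
    negative_labels,
    torch.ones(batch_size, seq_len, dtype=torch.bool),  # padding_mask: not used by this loss
    target_padding_mask,
)

``LogInCE`` is set up the same way, except that it takes ``cardinality`` in its constructor and builds the negative set from the whole catalog itself.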
replay/nn/loss/logout_ce.py ADDED
@@ -0,0 +1,230 @@
+ from typing import Callable, Optional
+
+ import torch
+
+ from replay.data.nn import TensorMap
+
+ from .base import mask_negative_logits
+
+
+ class LogOutCE(torch.nn.Module):
+     """
+     LogOutCE loss.
+
+     .. math::
+
+         L_{\\text{InfoNCE}} = - \\sum_{p \\in P} \\log \\frac{ \\exp(\\mathrm{sim}(q, p))}
+         {\\exp(\\mathrm{sim}(q, p))
+         + \\sum_{n \\in N} \\exp(\\mathrm{sim}(q, n))},
+
+     where q -- query embedding, P -- set of positive logits, N -- set of negative logits,
+     :math:`sim(\\cdot, \\cdot)` -- similarity function.\n
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(
+         self,
+         cardinality: int,
+         negative_labels_ignore_index: int = -100,
+         **kwargs,
+     ):
+         """
+         To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
+         You can pass all parameters for initializing that object via ``kwargs``.
+
+         :param cardinality: number of unique items in vocabulary (catalog).
+             The specified cardinality value must not take into account the padding value.
+         :param negative_labels_ignore_index: padding value for negative labels.
+             This may be the case when negative labels
+             are formed at the preprocessing level, rather than by the negative sampler.
+             The index is ignored and does not contribute to the loss.
+             Default: ``-100``.
+         """
+         super().__init__()
+         self.cardinality = cardinality
+         self.negative_labels_ignore_index = negative_labels_ignore_index
+         self._loss = torch.nn.CrossEntropyLoss(**kwargs)
+         self._logits_callback = None
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         """
+         Property for calling a function for the logits computation.\n
+
+         This function is expected to receive the model's last hidden state
+         and optionally item IDs, and return a logits tensor.
+
+         It is expected that the corresponding head model method will be used as this function,
+         for example, the ``get_logits`` method of the ``SasRec`` class.
+
+         :return: callable function.
+         """
+         if self._logits_callback is None:
+             msg = "The callback for getting logits is not defined"
+             raise AttributeError(msg)
+         return self._logits_callback
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None:
+         self._logits_callback = func
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,  # noqa: ARG002
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, positive_labels, target_padding_mask)
+         **Note**: During the forward pass, the whole catalog of items is used as negatives.
+         Then, negative logits corresponding to positions where negative labels
+         coincide with positive ones are masked.
+
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: ground truth labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+         """
+         initial_target_padding_mask = target_padding_mask
+         num_positives = target_padding_mask.size(2)
+         # [batch_size, seq_len, num_positives] -> [batch_size * seq_len]
+         if num_positives == 1:
+             target_padding_mask = target_padding_mask.squeeze(-1)
+         else:
+             target_padding_mask = target_padding_mask.sum(-1).bool()
+         masked_batch_size = target_padding_mask.sum().item()
+
+         logits: torch.Tensor = self.logits_callback(model_embeddings)  # [batch_size, seq_len, vocab_size]
+         all_negative_labels = torch.arange(self.cardinality, dtype=torch.long, device=positive_labels.device)
+
+         # [batch_size, seq_len, vocab_size] -> [masked_batch_size, vocab_size]
+         logits = logits[target_padding_mask]
+
+         # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+         positive_labels = positive_labels[target_padding_mask]
+
+         # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+         target_padding_mask = initial_target_padding_mask[target_padding_mask]
+
+         positive_ids = torch.arange(masked_batch_size, dtype=torch.long, device=positive_labels.device)
+         # [masked_batch_size, vocab_size] -> [masked_batch_size, num_positives]
+         positive_logits = logits[positive_ids.unsqueeze(-1), positive_labels]
+
+         # [masked_batch_size, vocab_size] - assign low values to some negative logits
+         negative_logits = mask_negative_logits(
+             logits,
+             all_negative_labels,
+             positive_labels,
+             self.negative_labels_ignore_index,
+         )
+
+         # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
+         negative_logits = negative_logits.unsqueeze(-2)
+         # [masked_batch_size, 1, num_negatives] -> [masked_batch_size, num_positives, num_negatives]
+         negative_logits = negative_logits.repeat(1, target_padding_mask.size(-1), 1)
+         # [masked_batch_size, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
+         negative_logits = negative_logits[target_padding_mask]
+         # [masked_batch_size, num_positives] -> [masked_batch_size]
+         positive_logits = positive_logits[target_padding_mask]
+         # [masked_batch_size] -> [masked_batch_size, 1]
+         positive_logits = positive_logits.unsqueeze(-1)
+
+         # [masked_batch_size, 1 + num_negatives] - all logits
+         logits = torch.cat((positive_logits, negative_logits), dim=-1)
+         # [masked_batch_size] - positives are always at position 0 for all recommendation points
+         target = torch.zeros(logits.size(0), dtype=torch.long, device=positive_labels.device)
+         # [masked_batch_size] - loss for all recommendation points
+         loss = self._loss(logits, target)
+         return loss
+
+
+ class LogOutCEWeighted(LogOutCE):
+     """
+     LogOutCE loss with support for sample weights.
+
+     .. math::
+
+         L_{\\text{InfoNCE}} = - \\sum_{p \\in P} \\log \\frac{ \\exp(\\mathrm{sim}(q, p))}
+         {\\exp(\\mathrm{sim}(q, p))
+         + \\sum_{n \\in N} \\exp(\\mathrm{sim}(q, n))},
+
+     where q -- query embedding, P -- set of positive logits, N -- set of negative logits,
+     :math:`sim(\\cdot, \\cdot)` -- similarity function.\n
+
+     In addition to calculating the standard loss,
+     a weight is applied to each sample.
+     Therefore, the sample weights are expected to be present in the generated batch
+     that is fed into the model.
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(
+         self,
+         cardinality: int,
+         feature_name: str,
+         negative_labels_ignore_index: int = -100,
+         **kwargs,
+     ):
+         """
+         To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
+         You can pass all other parameters for initializing that object via ``kwargs``.
+
+         :param cardinality: number of unique items in vocabulary (catalog).
+             The specified cardinality value must not take into account the padding value.
+         :param feature_name: the name of the key in the batch.
+             The corresponding tensor is expected to contain sample weights.
+         :param negative_labels_ignore_index: padding value for negative labels.
+             This may be the case when negative labels
+             are formed at the preprocessing level, rather than by the negative sampler.
+             The index is ignored and does not contribute to the loss.
+             Default: ``-100``.
+         """
+         super().__init__(
+             cardinality=cardinality,
+             negative_labels_ignore_index=negative_labels_ignore_index,
+         )
+         self.feature_name = feature_name
+         self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,  # noqa: ARG002
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask)
+         **Note**: During the forward pass, the whole catalog of items is used as negatives.
+         Then, negative logits corresponding to positions where negative labels
+         coincide with positive ones are masked.
+
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param feature_tensors: a dictionary of tensors from the dataloader.
+             This dictionary is expected to contain a key with the name ``feature_name``,
+             which is specified in the constructor.
+             Expected shape of the tensor: ``(batch_size, sequence_length, num_positives)``.
+         :param positive_labels: ground truth labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+         """
+         loss: torch.Tensor = super().forward(model_embeddings, None, positive_labels, None, None, target_padding_mask)
+         sample_weight = feature_tensors[self.feature_name]
+         sample_weight = sample_weight[target_padding_mask]
+         loss = (loss * sample_weight).mean()
+         return loss
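For ``LogOutCE`` the callback is invoked with the hidden states only and must return full-catalog logits. A rough, self-contained sketch follows; the linear ``head``, the label range, and the shapes are assumptions for illustration, and a real setup would plug in the model's own scoring head:

import torch

from replay.nn.loss.logout_ce import LogOutCE

emb_dim, cardinality = 16, 50
head = torch.nn.Linear(emb_dim, cardinality)  # hypothetical full-catalog scoring head

loss_fn = LogOutCE(cardinality=cardinality)
# The callback may also accept item IDs, but LogOutCE calls it with the embeddings only.
loss_fn.logits_callback = lambda embeddings, item_ids=None: head(embeddings)

batch_size, seq_len, num_positives = 2, 5, 1
model_embeddings = torch.randn(batch_size, seq_len, emb_dim)
positive_labels = torch.randint(0, cardinality, (batch_size, seq_len, num_positives))  # must index into the head's output
target_padding_mask = torch.ones(batch_size, seq_len, num_positives, dtype=torch.bool)

loss = loss_fn(
    model_embeddings,
    {},    # feature_tensors: unused here; LogOutCEWeighted reads per-sample weights from it
    positive_labels,
    None,  # negative_labels: ignored, the whole catalog acts as negatives
    None,  # padding_mask: unused by this loss
    target_padding_mask,
)

``LogOutCEWeighted`` is used the same way, except that ``feature_tensors[feature_name]`` must hold a weight tensor with the same shape as ``positive_labels``.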
replay/nn/mask.py ADDED
@@ -0,0 +1,87 @@
+ from abc import ABC, abstractmethod
+ from typing import Protocol
+
+ import torch
+
+ from replay.data.nn.schema import TensorMap
+
+
+ class AttentionMaskProto(Protocol):
+     def __call__(self, feature_tensor: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor: ...
+
+
+ class AttentionMaskBase(ABC):
+     def __init__(
+         self,
+         num_heads: int,
+     ) -> None:
+         self.num_heads = num_heads
+
+     def __call__(
+         self,
+         feature_tensor: TensorMap,
+         padding_mask: torch.BoolTensor,
+     ) -> torch.FloatTensor:
+         """
+         :param feature_tensor: dict of feature tensors.
+         :param padding_mask: Padding mask where ``0`` - ``<PAD>``, ``1`` - otherwise.
+         :returns: Float attention mask of shape ``(B * num_heads, L, L)``,
+             where ``-inf`` for ``<PAD>``, ``0`` - otherwise.
+         """
+         attention_mask = self._get_attention_mask(feature_tensor)
+
+         diagonal_attention_mask = torch.diag(
+             torch.ones(padding_mask.size(1), dtype=torch.bool, device=padding_mask.device)
+         )
+         # (B, L) -> (B, 1, 1, L)
+         key_padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)
+         # (B, 1, 1, L) -> (B, 1, L, L), where 0 - PAD, 1 - otherwise
+         key_padding_mask = key_padding_mask | diagonal_attention_mask
+
+         attention_mask = (attention_mask & key_padding_mask).float()
+         attention_mask = attention_mask.masked_fill(attention_mask == 0, float("-inf")).masked_fill(
+             attention_mask == 1, 0.0
+         )
+         if attention_mask.size(1) != self.num_heads and attention_mask.shape[1] == 1:
+             # for the default attention_mask of shape (L, L) it becomes (B, 1, L, L)
+             # (B, 1, L, L) -> (B, num_heads, L, L)
+             attention_mask = attention_mask.repeat(1, self.num_heads, 1, 1)
+         # (B, num_heads, L, L) -> (B * num_heads, L, L)
+         attention_mask = attention_mask.reshape(-1, *attention_mask.shape[-2:])
+         return attention_mask
+
+     @abstractmethod
+     def _get_attention_mask(self, feature_tensor: TensorMap) -> torch.Tensor:
+         raise NotImplementedError()  # pragma: no cover
+
+
+ class DefaultAttentionMask(AttentionMaskBase):
+     """
+     Constructs a float lower-triangular attention mask
+     of shape ``(batch_size * num_heads, sequence_length, sequence_length)``,
+     where ``-inf`` for ``<PAD>``, ``0`` - otherwise.
+     """
+
+     def __init__(
+         self,
+         reference_feature_name: str,
+         num_heads: int,
+     ) -> None:
+         """
+         :param reference_feature_name: A reference tensor is needed to build the mask,
+             so pass the name of a tensor that is guaranteed to be present in the dictionary of feature tensors.
+             The second dimension (1 in zero-based indexing) of this tensor is used to construct the attention mask.
+         :param num_heads: Number of attention heads.
+         """
+         super().__init__(num_heads)
+         self._feature_name = reference_feature_name
+
+     def _get_attention_mask(self, feature_tensor: TensorMap) -> torch.Tensor:
+         input_sequence = feature_tensor[self._feature_name]
+         return torch.tril(
+             torch.ones(
+                 (input_sequence.size(1), input_sequence.size(1)),
+                 dtype=torch.bool,
+                 device=input_sequence.device,
+             )
+         )
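A small usage sketch for ``DefaultAttentionMask``; the ``item_id`` feature name and the toy batch below are illustrative assumptions, not a convention fixed by the package:

import torch

from replay.nn.mask import DefaultAttentionMask

num_heads = 2
mask_builder = DefaultAttentionMask(reference_feature_name="item_id", num_heads=num_heads)

# Toy batch where 0 plays the role of the padding token.
item_id = torch.tensor([
    [0, 0, 12, 7, 3, 9],
    [0, 5, 8, 2, 11, 4],
    [21, 6, 17, 13, 1, 30],
])
padding_mask = item_id != 0  # True for real events, False for <PAD>

attention_mask = mask_builder({"item_id": item_id}, padding_mask)
# Causal (lower-triangular) float mask with -inf on padded keys, repeated per attention head.
assert attention_mask.shape == (item_id.size(0) * num_heads, item_id.size(1), item_id.size(1))

The resulting tensor has the ``(batch_size * num_heads, L, L)`` float layout accepted as ``attn_mask`` by ``torch.nn.MultiheadAttention``-style layers.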