replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/loss/bce.py ADDED
@@ -0,0 +1,216 @@
+ from typing import Callable, Optional
+
+ import torch
+
+ from replay.data.nn import TensorMap
+
+ from .base import SampledLossBase, mask_negative_logits
+
+
+ class BCE(torch.nn.Module):
+     """
+     Pointwise Binary Cross-Entropy loss.
+     Calculates loss over the entire item catalog.
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(self, **kwargs):
+         """
+         To calculate the loss, ``torch.nn.BCEWithLogitsLoss`` is used with the parameter ``reduction="sum"``.
+         You can pass all other parameters for initializing the object via kwargs.
+         """
+         super().__init__()
+         self._loss = torch.nn.BCEWithLogitsLoss(reduction="sum", **kwargs)
+         self._logits_callback = None
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         """
+         Property for calling a function for the logits computation.\n
+
+         This function is expected to receive the model's last hidden state
+         and optionally item IDs, and return a logits tensor.
+
+         It is expected that the corresponding head model method will be used as this function,
+         for example, the ``get_logits`` method of the ``SasRec`` class.
+
+         :return: callable function.
+         """
+         if self._logits_callback is None:
+             msg = "The callback for getting logits is not defined"
+             raise AttributeError(msg)
+         return self._logits_callback
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None:
+         self._logits_callback = func
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,  # noqa: ARG002
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, positive_labels, target_padding_mask)
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+         """
+         logits = self.logits_callback(model_embeddings)
+
+         # [batch_size, seq_len, num_positives] -> [batch_size, seq_len]
+         if target_padding_mask.size(-1) == 1:
+             target_padding_mask.squeeze_(-1)
+         else:
+             target_padding_mask = target_padding_mask.sum(-1).bool()
+
+         # Take only logits which correspond to non-padded tokens
+         # [batch_size, seq_len, vocab_size] -> [masked_batch_size, vocab_size]
+         logits = logits[target_padding_mask]
+
+         # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+         labels = positive_labels[target_padding_mask]
+
+         bce_labels = torch.zeros_like(logits)
+
+         # Fill positives with ones, all negatives are zeros
+         bce_labels.scatter_(
+             dim=-1,
+             index=labels,
+             value=1,
+         )
+
+         loss = self._loss(logits, bce_labels) / logits.size(0)
+         return loss
+
+
+ class BCESampled(SampledLossBase):
+     """
+     Sampled Pointwise Binary Cross-Entropy loss (BCE with negative sampling).
+     Calculates loss between one positive item and K negatively sampled items.
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(
+         self,
+         log_epsilon: float = 1e-6,
+         clamp_border: float = 100.0,
+         negative_labels_ignore_index: int = -100,
+     ):
+         """
+         :param log_epsilon: correction to avoid zero in the logarithm during loss calculation.
+             Default: ``1e-6``.
+         :param clamp_border: upper bound for clamping the loss tensor; the lower bound is set to ``-clamp_border``.
+             Default: ``100.0``.
+         :param negative_labels_ignore_index: padding value for negative labels.
+             This may be the case when negative labels
+             are formed at the preprocessing level rather than by the negative sampler.
+             The index is ignored and does not contribute to the loss.
+             Default: ``-100``.
+         """
+         super().__init__()
+         self.log_epsilon = log_epsilon
+         self.clamp_border = clamp_border
+         self.negative_labels_ignore_index = negative_labels_ignore_index
+         self._logits_callback = None
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         """
+         Property for calling a function for the logits computation.\n
+
+         This function is expected to receive the model's last hidden state
+         and optionally item IDs, and return a logits tensor.
+
+         It is expected that the corresponding head model method will be used as this function,
+         for example, the ``get_logits`` method of the ``SasRec`` class.
+
+         :return: callable function.
+         """
+         if self._logits_callback is None:
+             msg = "The callback for getting logits is not defined"
+             raise AttributeError(msg)
+         return self._logits_callback
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None:
+         self._logits_callback = func
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)
+
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param negative_labels: labels of sampled negative events.
+             Expected shape:
+
+             - ``(batch_size, sequence_length, num_negatives)``
+             - ``(batch_size, num_negatives)``
+             - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+
+         """
+         sampled = self.get_sampled_logits(
+             model_embeddings,
+             positive_labels,
+             negative_labels,
+             target_padding_mask,
+         )
+         positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
+         negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
+         positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
+         negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]
+
+         # Reject negative samples matching the target label & correct for remaining samples
+         negative_logits = mask_negative_logits(
+             negative_logits,
+             negative_labels,
+             positive_labels,
+             self.negative_labels_ignore_index,
+         )
+
+         positive_prob = torch.sigmoid(positive_logits)
+         negative_prob = torch.sigmoid(negative_logits)
+
+         positive_loss = torch.clamp(
+             torch.log(positive_prob + self.log_epsilon),
+             -self.clamp_border,
+             self.clamp_border,
+         ).sum()
+         negative_loss = torch.clamp(
+             torch.log((1 - negative_prob) + self.log_epsilon),
+             -self.clamp_border,
+             self.clamp_border,
+         ).sum()
+
+         loss = -(positive_loss + negative_loss)
+         loss /= positive_logits.size(0)
+
+         return loss
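Note on usage: the docstrings above describe a callback-based wiring in which the loss does not own an output head; it asks ``logits_callback`` to turn the model's hidden states into logits over the catalog. The snippet below is an illustrative sketch, not part of the released package: it substitutes a plain ``torch.nn.Linear`` head for a real model head such as ``SasRec.get_logits`` and uses synthetic tensors, only to show which arguments ``BCE.forward`` actually consumes.

# Hedged usage sketch: a stand-in linear head plays the role of the model head.
import torch

from replay.nn.loss.bce import BCE

batch_size, seq_len, emb_dim, vocab_size = 2, 5, 8, 100
head = torch.nn.Linear(emb_dim, vocab_size)  # hypothetical stand-in for e.g. SasRec.get_logits

loss_fn = BCE()
# The loss calls this callback for logits instead of owning an output layer itself.
loss_fn.logits_callback = lambda hidden, item_ids=None: head(hidden)

model_embeddings = torch.randn(batch_size, seq_len, emb_dim)
positive_labels = torch.randint(0, vocab_size, (batch_size, seq_len, 1))
target_padding_mask = torch.ones(batch_size, seq_len, 1, dtype=torch.bool)

# feature_tensors, negative_labels and padding_mask are unused by BCE (see the noqa: ARG002 marks above).
loss = loss_fn(
    model_embeddings,
    None,
    positive_labels,
    None,
    None,
    target_padding_mask,
)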
replay/nn/loss/ce.py ADDED
@@ -0,0 +1,317 @@
+ from typing import Callable, Optional
+
+ import torch
+
+ from replay.data.nn import TensorMap
+
+ from .base import SampledLossBase, mask_negative_logits
+
+
+ class CE(torch.nn.Module):
+     """
+     Full Cross-Entropy loss.
+     Calculates loss over the entire item catalog.
+     """
+
+     def __init__(self, **kwargs):
+         """
+         To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
+         You can pass all parameters for initializing the object via kwargs.
+         """
+         super().__init__()
+         self._loss = torch.nn.CrossEntropyLoss(**kwargs)
+         self._logits_callback = None
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         """
+         Property for calling a function for the logits computation.\n
+
+         This function is expected to receive the model's last hidden state
+         and optionally item IDs, and return a logits tensor.
+
+         It is expected that the corresponding head model method will be used as this function,
+         for example, the ``get_logits`` method of the ``SasRec`` class.
+
+         :return: callable function.
+         """
+         if self._logits_callback is None:
+             msg = "The callback for getting logits is not defined"
+             raise AttributeError(msg)
+         return self._logits_callback
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None:
+         self._logits_callback = func
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,  # noqa: ARG002
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, positive_labels, target_padding_mask)
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+         """
+         if positive_labels.size(-1) != 1:
+             msg = "The case of multi-positive labels is not supported in the CE loss"
+             raise NotImplementedError(msg)
+         logits: torch.Tensor = self.logits_callback(model_embeddings)  # [batch_size, seq_len, vocab_size]
+         labels = positive_labels.masked_fill(
+             mask=(~target_padding_mask),
+             value=self._loss.ignore_index,
+         )  # [batch_size, seq_len, 1]
+
+         # [batch_size, seq_len, vocab_size] -> [batch_size * seq_len, vocab_size]
+         logits_flat = logits.view(-1, logits.size(-1))
+         # [batch_size, seq_len, 1] -> [batch_size * seq_len]
+         labels_flat: torch.LongTensor = labels.view(-1)
+         loss = self._loss(logits_flat, labels_flat)
+         return loss
+
+
+ class CEWeighted(CE):
+     """
+     Full Cross-Entropy loss.
+     Calculates loss over the entire item catalog.
+
+     In addition to calculating the standard loss,
+     weights are applied to each sample.
+     Therefore, the sample weights are expected to be present in the generated batch
+     that is fed into the model.
+     """
+
+     def __init__(
+         self,
+         feature_name: str,
+         **kwargs,
+     ):
+         """
+         To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
+         You can pass all other parameters for initializing the object via kwargs.
+
+         :param feature_name: the name of the key in the batch.
+             The tensor is expected to contain sample weights.
+         """
+         super().__init__()
+         self.feature_name = feature_name
+         self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,  # noqa: ARG002
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask)
+         :param feature_tensors: a dictionary of tensors from the dataloader.
+             This dictionary is expected to contain a key with the name ``feature_name``,
+             which is specified in the constructor.
+             Expected shape of the tensor is ``(batch_size, sequence_length, num_positives)``.
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+         """
+         loss: torch.Tensor = super().forward(
+             model_embeddings,
+             None,
+             positive_labels,
+             None,
+             None,
+             target_padding_mask,
+         )
+         sample_weight = feature_tensors[self.feature_name]
+         loss = (loss * sample_weight).mean()
+         return loss
+
+
+ class CESampled(SampledLossBase):
+     """
+     Sampled Cross-Entropy loss (Cross-Entropy with negative sampling).
+     Calculates loss between one positive item and K negatively sampled items.
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(
+         self,
+         negative_labels_ignore_index: int = -100,
+         **kwargs,
+     ):
+         """
+         To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
+         You can pass all parameters for initializing the object via kwargs.
+
+         :param negative_labels_ignore_index: padding value for negative labels.
+             This may be the case when negative labels
+             are formed at the preprocessing level rather than by the negative sampler.
+             The index is ignored and does not contribute to the loss.
+             Default: ``-100``.
+         """
+         super().__init__()
+         self.negative_labels_ignore_index = negative_labels_ignore_index
+         self._loss = torch.nn.CrossEntropyLoss(**kwargs)
+         self._logits_callback = None
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         """
+         Property for calling a function for the logits computation.\n
+
+         This function is expected to receive the model's last hidden state
+         and optionally item IDs, and return a logits tensor.
+
+         It is expected that the corresponding head model method will be used as this function,
+         for example, the ``get_logits`` method of the ``SasRec`` class.
+
+         :return: callable function.
+         """
+         if self._logits_callback is None:
+             msg = "The callback for getting logits is not defined"
+             raise AttributeError(msg)
+         return self._logits_callback
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None:
+         self._logits_callback = func
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)
+
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param positive_labels: labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param negative_labels: labels of sampled negative events.
+
+             Expected shape:
+             - ``(batch_size, sequence_length, num_negatives)``
+             - ``(batch_size, num_negatives)``
+             - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+
+         :return: computed loss value.
+         """
+         sampled = self.get_sampled_logits(
+             model_embeddings,
+             positive_labels,
+             negative_labels,
+             target_padding_mask,
+         )
+         positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
+         negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
+         positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
+         negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]
+
+         # [masked_batch_size, num_negatives] - assign low values to some negative logits
+         negative_logits = mask_negative_logits(
+             negative_logits,
+             negative_labels,
+             positive_labels,
+             self.negative_labels_ignore_index,
+         )
+         # [masked_batch_size, 1 + num_negatives] - all logits
+         logits = torch.cat((positive_logits, negative_logits), dim=-1)
+         # [masked_batch_size] - positives are always at position 0 for all recommendation points
+         target = torch.zeros(positive_logits.size(0), dtype=torch.long, device=logits.device)
+         # [masked_batch_size] - loss for all recommendation points
+         loss = self._loss(logits, target)
+         return loss
+
+
+ class CESampledWeighted(CESampled):
+     """
+     Sampled Cross-Entropy loss (Cross-Entropy with negative sampling).
+     Calculates loss between one positive item and K negatively sampled items.
+
+     In addition to calculating the standard loss,
+     weights are applied to each sample.
+     Therefore, the sample weights are expected to be present in the generated batch
+     that is fed into the model.
+
+     The loss supports the calculation of logits for the case of multi-positive labels
+     (there are several labels for each position in the sequence).
+     """
+
+     def __init__(
+         self,
+         feature_name: str,
+         negative_labels_ignore_index: int = -100,
+         **kwargs,
+     ):
+         """
+         To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
+         You can pass all other parameters for initializing the object via kwargs.
+
+         :param feature_name: the name of the key in the batch.
+             The tensor is expected to contain sample weights.
+         :param negative_labels_ignore_index: padding value for negative labels.
+             This may be the case when negative labels
+             are formed at the preprocessing level rather than by the negative sampler.
+             The index is ignored and does not contribute to the loss.
+             Default: ``-100``.
+         """
+         super().__init__(negative_labels_ignore_index=negative_labels_ignore_index)
+         self.feature_name = feature_name
+         self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         forward(model_embeddings, feature_tensors, positive_labels, negative_labels, target_padding_mask)
+         :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param feature_tensors: a dictionary of tensors from the dataloader.
+             This dictionary is expected to contain a key with the name ``feature_name``,
+             which is specified in the constructor.
+             Expected shape of the tensor is ``(batch_size, sequence_length, num_positives)``.
+         :param positive_labels: labels of positive events
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :param negative_labels: labels of sampled negative events of shape ``(num_negatives)``.
+         :param target_padding_mask: padding mask corresponding to ``positive_labels``
+             of shape ``(batch_size, sequence_length, num_positives)``.
+         :return: computed loss value.
+         """
+         loss: torch.Tensor = super().forward(
+             model_embeddings, None, positive_labels, negative_labels, None, target_padding_mask
+         )
+         sample_weight = feature_tensors[self.feature_name]
+         sample_weight = sample_weight[target_padding_mask]
+         loss = (loss * sample_weight).mean()
+         return loss
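Note on the sampled objective: after ``get_sampled_logits``, ``CESampled.forward`` concatenates the positive logit in front of the K negative logits and takes cross-entropy against a target of index 0. The sketch below only restates that final step with synthetic logits; it does not call the package, since ``get_sampled_logits`` and the negative sampler live in ``SampledLossBase`` and the transforms elsewhere in this release.

# Illustrative sketch of the sampled cross-entropy objective with synthetic logits.
import torch

masked_batch_size, num_negatives = 4, 3
positive_logits = torch.randn(masked_batch_size, 1)              # one positive per recommendation point
negative_logits = torch.randn(masked_batch_size, num_negatives)  # K sampled negatives

# Positives sit at index 0, so the class target is all zeros.
logits = torch.cat((positive_logits, negative_logits), dim=-1)   # [masked_batch_size, 1 + num_negatives]
target = torch.zeros(masked_batch_size, dtype=torch.long)
loss = torch.nn.CrossEntropyLoss()(logits, target)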