replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/loss/login_ce.py
ADDED
@@ -0,0 +1,373 @@
+from typing import Callable, Optional, TypedDict
+
+import torch
+
+from replay.data.nn import TensorMap
+
+from .base import SampledLossBase, mask_negative_logits
+
+
+class LogInCESampledOutput(TypedDict):
+    positive_logits: torch.Tensor
+    negative_logits: torch.Tensor
+    positive_labels: torch.LongTensor
+    negative_labels: torch.LongTensor
+    target_padding_mask: torch.BoolTensor
+
+
+class LogInCEBase(SampledLossBase):
+    def get_sampled_logits(
+        self,
+        model_embeddings: torch.Tensor,
+        positive_labels: torch.LongTensor,  # [batch_size, seq_len, num_positives]
+        negative_labels: torch.LongTensor,  # [num_negatives] or [batch_size, seq_len, num_negatives]
+        target_padding_mask: torch.BoolTensor,  # [batch_size, seq_len, num_positives]
+    ) -> LogInCESampledOutput:
+        """
+        The function of calculating positive and negative logits in LogInCE losses.
+        Based on the embeddings from the model, positive and negative labels.
+
+        The function supports the calculation of logits for the case of multi-positive labels
+        (there are several labels for each position in the sequence).
+
+        :param model_embeddings: Embeddings from the model. This is usually the last hidden state.
+            Expected shape: ``(batch_size, sequence_length, embedding_dim)``
+        :param positive_labels: a tensor containing labels with positive events.
+            Expected shape: ``(batch_size, sequence_length, num_positives)``
+        :param negative_labels: a tensor containing labels with negative events.
+            Expected shape:
+            - ``(batch_size, sequence_length, num_negatives)``
+            - ``(batch_size, num_negatives)``
+            - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+        :param target_padding_mask: Padding mask for ``positive_labels`` (targets).
+            ``False`` value indicates that the corresponding ``key`` value will be ignored.
+            Expected shape: ``(batch_size, sequence_length, num_positives)``
+
+        :returns: LogInCESampledOutput. A dictionary containing positive and negative logits with labels.
+        """
+        ################## SHAPE CHECKING STAGE START ##################
+        batch_size, seq_len, num_positives = positive_labels.size()
+        assert target_padding_mask.size() == (batch_size, seq_len, num_positives)
+        num_negatives = negative_labels.size(-1)
+
+        if negative_labels.size() == (batch_size, num_negatives):
+            # [batch_size, num_negatives] -> [batch_size, 1, num_negatives]
+            negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)
+
+        assert negative_labels.size() == (batch_size, seq_len, num_negatives) or negative_labels.dim() == 1
+        ################## SHAPE CHECKING STAGE END ##################
+
+        # Get output embedding for every user event
+        embedding_dim = model_embeddings.size(-1)
+        assert model_embeddings.size() == (batch_size, seq_len, embedding_dim)
+
+        # [batch_size, seq_len, num_positives] -> [batch_size, seq_len]
+        masked_target_padding_mask: torch.BoolTensor = target_padding_mask.sum(-1).bool()
+        masked_batch_size = masked_target_padding_mask.sum().item()
+
+        # Apply target mask
+        # [batch_size, seq_len, emb_dim] -> [masked_batch_size, emb_dim]
+        model_embeddings = model_embeddings[masked_target_padding_mask]
+        assert model_embeddings.size() == (masked_batch_size, embedding_dim)
+
+        # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+        positive_labels = positive_labels[masked_target_padding_mask]
+        assert positive_labels.size() == (masked_batch_size, num_positives)
+
+        if negative_labels.dim() > 1:  # pragma: no cover
+            # [batch_size, seq_len, num_negatives] -> [masked_batch_size, num_negatives]
+            negative_labels = negative_labels[masked_target_padding_mask]
+            assert negative_labels.size() == (masked_batch_size, num_negatives)
+
+        positive_logits = self.logits_callback(model_embeddings, positive_labels)
+        assert positive_logits.size() == (masked_batch_size, num_positives)
+
+        negative_logits = self.logits_callback(model_embeddings, negative_labels)
+        assert negative_logits.size() == (masked_batch_size, num_negatives)
+
+        # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+        target_padding_mask = target_padding_mask[masked_target_padding_mask]
+        assert target_padding_mask.size() == (masked_batch_size, num_positives)
+
+        return {
+            "positive_logits": positive_logits,
+            "negative_logits": negative_logits,
+            "positive_labels": positive_labels,
+            "negative_labels": negative_labels,
+            "target_padding_mask": target_padding_mask,
+        }
+
+
+class LogInCE(LogInCEBase):
+    """
+    LogInCE loss.
+
+    .. math::
+
+        L_{\\text{InfoNCE}} = -\\log \\frac{\\sum_{p \\in P}
+        \\exp(\\mathrm{sim}(q, p))}{\\sum_{p \\in P}
+        \\exp(\\mathrm{sim}(q, p)) + \\sum_{n \\in N} \\exp(\\mathrm{sim}(q, n))},
+
+    where q -- query embedding, P -- set of positive logits, N -- set of negative logits,
+    :math:`sim(\\cdot, \\cdot)` -- similaruty function.
+
+    The loss supports the calculation of logits for the case of multi-positive labels
+    (there are several labels for each position in the sequence).
+    """
+
+    def __init__(
+        self,
+        cardinality: int,
+        log_epsilon: float = 1e-6,
+        clamp_border: float = 100.0,
+        negative_labels_ignore_index: int = -100,
+    ):
+        """
+        :param cardinality: number of unique items in vocabulary (catalog).
+            The specified cardinality value must not take into account the padding value.
+        :param log_epsilon: correction to avoid zero in the logarithm during loss calculating.
+            Default: ``1e-6``.
+        :param clamp_border: upper bound for clamping loss tensor, lower bound will be setted to ``-clamp_border``.
+            Default: ``100.0``.
+        :param negative_labels_ignore_index: padding value for negative labels.
+            This may be the case when negative labels
+            are formed at the preprocessing level, rather than the negative sampler.
+            The index is ignored and does not contribute to the loss.
+            Default: ``-100``.
+        """
+        super().__init__()
+        self.cardinality = cardinality
+        self.log_epsilon = log_epsilon
+        self.clamp_border = clamp_border
+        self.negative_labels_ignore_index = negative_labels_ignore_index
+        self._logits_callback = None
+
+    @property
+    def logits_callback(
+        self,
+    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+        """
+        Property for calling a function for the logits computation.\n
+
+        This function is expected to receive model's last hidden state
+        and optionally item IDs, and return a logits tensor.
+
+        It is expected that the corresponding head model method will be used as this function,
+        for example, the ``get_logits`` method of the ``SasRec`` class.
+
+        :return: callable function.
+        """
+        if self._logits_callback is None:
+            msg = "The callback for getting logits is not defined"
+            raise AttributeError(msg)
+        return self._logits_callback
+
+    @logits_callback.setter
+    def logits_callback(self, func: Optional[Callable]) -> None:
+        self._logits_callback = func
+
+    def forward(
+        self,
+        model_embeddings: torch.Tensor,
+        feature_tensors: TensorMap,  # noqa: ARG002
+        positive_labels: torch.LongTensor,
+        negative_labels: torch.LongTensor,  # noqa: ARG002
+        padding_mask: torch.BoolTensor,  # noqa: ARG002
+        target_padding_mask: torch.BoolTensor,
+    ) -> torch.Tensor:
+        """
+        forward(model_embeddings, positive_labels, target_padding_mask)
+        **Note**: At forward pass, the whole catalog of items is used as negatives.
+        Next, negative logits, corresponding to positions where negative labels
+        coincide with positive ones, are masked.
+
+        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+        :param positive_labels: ground truth labels of positive events
+            of shape (batch_size, sequence_length, num_positives).
+        :param target_padding_mask: padding mask corresponding for ``positive_labels``
+            of shape (batch_size, sequence_length, num_positives).
+        :return: computed loss value.
+        """
+        all_negative_labels = torch.arange(
+            self.cardinality,
+            dtype=torch.long,
+            device=positive_labels.device,
+        )
+        sampled = self.get_sampled_logits(
+            model_embeddings,
+            positive_labels,
+            all_negative_labels,
+            target_padding_mask,
+        )
+        positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
+        negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
+        positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
+        all_negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]
+        target_padding_mask = sampled["target_padding_mask"]  # [masked_batch_size, num_positives]
+
+        # [masked_batch_size, num_negatives] - assign low values to some negative logits
+        negative_logits = mask_negative_logits(
+            negative_logits,
+            all_negative_labels,
+            positive_labels,
+            self.negative_labels_ignore_index,
+        )
+
+        max_values = torch.max(
+            positive_logits.max(-1, keepdim=True).values,
+            negative_logits.max(-1, keepdim=True).values,
+        )  # [masked_batch_size, 1]
+        positive_logits = positive_logits - max_values
+        negative_logits = negative_logits - max_values
+
+        positive_logits = torch.exp(positive_logits)
+        positive_logits = positive_logits * target_padding_mask
+        # [masked_batch_size, num_positives] -> [masked_batch_size]
+        positive_logits = positive_logits.sum(-1)
+
+        negative_logits = torch.exp(negative_logits)
+        # [masked_batch_size, num_negatives] -> [masked_batch_size]
+        negative_logits = negative_logits.sum(-1)
+
+        probabilities = positive_logits / (positive_logits + negative_logits)
+        loss = -torch.clamp(
+            torch.log(probabilities + self.log_epsilon),
+            -self.clamp_border,
+            self.clamp_border,
+        )
+        return loss.mean()
+
+
+class LogInCESampled(LogInCEBase):
+    """
+    Sampled version of LogInCE (Log InfoNCE) loss (with negative sampling items).
+
+    .. math::
+
+        L_{\\text{InfoNCE}} = -\\log \\frac{\\sum_{p \\in P} \\exp(\\mathrm{sim}(q, p))}{\\sum_{p \\in P}
+        \\exp(\\mathrm{sim}(q, p)) + \\sum_{n \\in N_{\\text{sampled}}} \\exp(\\mathrm{sim}(q, n))},
+
+    where q -- query embedding, P -- set of positive logits, :math:`N_sampled` -- set of negative logits,
+    :math:`sim(\\cdot, \\cdot)` -- similaruty function.\n
+    Same as ``LogInCE``, the difference in the set of negatives.
+
+    The loss supports the calculation of logits for the case of multi-positive labels
+    (there are several labels for each position in the sequence).
+    """
+
+    def __init__(
+        self,
+        log_epsilon: float = 1e-6,
+        clamp_border: float = 100.0,
+        negative_labels_ignore_index: int = -100,
+    ):
+        """
+        :param log_epsilon: correction to avoid zero in the logarithm during loss calculating.
+            Default: 1e-6.
+        :param clamp_border: upper bound for clamping loss tensor, lower bound will be setted to -`clamp_border`.
+            Default: 100.0.
+        :param negative_labels_ignore_index: padding value for negative labels.
+            This may be the case when negative labels
+            are formed at the preprocessing level, rather than the negative sampler.
+            The index is ignored and does not contribute to the loss.
+            Default: ``-100``.
+        """
+        super().__init__()
+        self.log_epsilon = log_epsilon
+        self.clamp_border = clamp_border
+        self.negative_labels_ignore_index = negative_labels_ignore_index
+        self._logits_callback = None
+
+    @property
+    def logits_callback(
+        self,
+    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+        """
+        Property for calling a function for the logits computation.\n
+
+        This function is expected to receive model's last hidden state
+        and optionally item IDs, and return a logits tensor.
+
+        It is expected that the corresponding head model method will be used as this function,
+        for example, the ``get_logits`` method of the ``SasRec`` class.
+
+        :return: callable function.
+        """
+        if self._logits_callback is None:
+            msg = "The callback for getting logits is not defined"
+            raise AttributeError(msg)
+        return self._logits_callback
+
+    @logits_callback.setter
+    def logits_callback(self, func: Optional[Callable]) -> None:
+        self._logits_callback = func
+
+    def forward(
+        self,
+        model_embeddings: torch.Tensor,
+        feature_tensors: TensorMap,  # noqa: ARG002
+        positive_labels: torch.LongTensor,
+        negative_labels: torch.LongTensor,
+        padding_mask: torch.BoolTensor,  # noqa: ARG002
+        target_padding_mask: torch.BoolTensor,
+    ) -> torch.Tensor:
+        """
+        forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)
+
+        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+        :param positive_labels: labels of positive events
+            of shape ``(batch_size, sequence_length, num_positives)``.
+        :param negative_labels: labels of sampled negative events.
+
+            Expected shape:
+            - ``(batch_size, sequence_length, num_negatives)``
+            - ``(batch_size, num_negatives)``
+            - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+        :param target_padding_mask: padding mask corresponding for ``positive_labels``
+            of shape ``(batch_size, sequence_length, num_positives)``
+
+        :return: computed loss value.
+        """
+        sampled = self.get_sampled_logits(
+            model_embeddings,
+            positive_labels,
+            negative_labels,
+            target_padding_mask,
+        )
+        positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
+        negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
+        positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
+        negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]
+        target_padding_mask = sampled["target_padding_mask"]  # [masked_batch_size, num_positives]
+
+        # [masked_batch_size, num_negatives] - assign low values to some negative logits
+        negative_logits = mask_negative_logits(
+            negative_logits,
+            negative_labels,
+            positive_labels,
+            self.negative_labels_ignore_index,
+        )
+
+        max_values = torch.max(
+            positive_logits.max(-1, keepdim=True).values,
+            negative_logits.max(-1, keepdim=True).values,
+        )  # [masked_batch_size, 1]
+        positive_logits = positive_logits - max_values
+        negative_logits = negative_logits - max_values
+
+        positive_logits = torch.exp(positive_logits)
+        positive_logits = positive_logits * target_padding_mask
+        # [masked_batch_size, num_positives] -> [masked_batch_size]
+        positive_logits = positive_logits.sum(-1)
+
+        negative_logits = torch.exp(negative_logits)
+        # [masked_batch_size, num_negatives] -> [masked_batch_size]
+        negative_logits = negative_logits.sum(-1)
+
+        probabilities = positive_logits / (positive_logits + negative_logits)
+        loss = -torch.clamp(
+            torch.log(probabilities + self.log_epsilon),
+            -self.clamp_border,
+            self.clamp_border,
+        )
+        return loss.mean()
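
Both LogInCE and LogInCESampled defer the similarity computation to the injectable logits_callback, which in practice is a model head such as SasRec.get_logits. The sketch below is illustrative only: the dot-product callback, the toy sizes, and the random tensors are assumptions made for the example and are not part of the package; shapes follow the docstrings above.

import torch

from replay.nn.loss.login_ce import LogInCESampled

torch.manual_seed(0)
batch_size, seq_len, emb_dim, num_items = 2, 5, 8, 100
item_embeddings = torch.nn.Embedding(num_items, emb_dim)  # stand-in for the model's item table


def dot_product_logits(hidden: torch.Tensor, item_ids: torch.Tensor) -> torch.Tensor:
    # hidden: [masked_batch_size, emb_dim]; item_ids: [masked_batch_size, k] or [k]
    weights = item_embeddings(item_ids)
    if weights.dim() == 2:  # shared negatives: [k, emb_dim]
        return hidden @ weights.T  # -> [masked_batch_size, k]
    return (hidden.unsqueeze(1) * weights).sum(-1)  # -> [masked_batch_size, k]


loss_fn = LogInCESampled()
loss_fn.logits_callback = dot_product_logits  # normally the model head, e.g. SasRec.get_logits

model_embeddings = torch.randn(batch_size, seq_len, emb_dim)              # last hidden state
positive_labels = torch.randint(0, num_items, (batch_size, seq_len, 1))
negative_labels = torch.randint(0, num_items, (16,))                      # same negatives for the whole batch
target_padding_mask = torch.ones(batch_size, seq_len, 1, dtype=torch.bool)
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)          # unused by this loss

loss = loss_fn(model_embeddings, {}, positive_labels, negative_labels, padding_mask, target_padding_mask)
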
replay/nn/loss/logout_ce.py
ADDED
@@ -0,0 +1,230 @@
+from typing import Callable, Optional
+
+import torch
+
+from replay.data.nn import TensorMap
+
+from .base import mask_negative_logits
+
+
+class LogOutCE(torch.nn.Module):
+    """
+    LogOutCE loss.
+
+    .. math::
+
+        L_{\\text{InfoNCE}} = - \\sum_{p \\in P} \\log \\frac{ \\exp(\\mathrm{sim}(q, p))}
+        {\\exp(\\mathrm{sim}(q, p))
+        + \\sum_{n \\in N} \\exp(\\mathrm{sim}(q, n))}.
+
+    where q -- query embedding, P -- set of positive logits, N -- set of negative logits,
+    :math:`sim(\\cdot, \\cdot)` -- similaruty function.\n
+
+    The loss supports the calculation of logits for the case of multi-positive labels
+    (there are several labels for each position in the sequence).
+    """
+
+    def __init__(
+        self,
+        cardinality: int,
+        negative_labels_ignore_index: int = -100,
+        **kwargs,
+    ):
+        """
+        To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
+        You can pass all parameters for initializing the object via kwargs.
+
+        :param cardinality: number of unique items in vocabulary (catalog).
+            The specified cardinality value must not take into account the padding value.
+        :param negative_labels_ignore_index: padding value for negative labels.
+            This may be the case when negative labels
+            are formed at the preprocessing level, rather than the negative sampler.
+            The index is ignored and does not contribute to the loss.
+            Default: ``-100``.
+        """
+        super().__init__()
+        self.cardinality = cardinality
+        self.negative_labels_ignore_index = negative_labels_ignore_index
+        self._loss = torch.nn.CrossEntropyLoss(**kwargs)
+        self._logits_callback = None
+
+    @property
+    def logits_callback(
+        self,
+    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+        """
+        Property for calling a function for the logits computation.\n
+
+        This function is expected to receive model's last hidden state
+        and optionally item IDs, and return a logits tensor.
+
+        It is expected that the corresponding head model method will be used as this function,
+        for example, the ``get_logits`` method of the ``SasRec`` class.
+
+        :return: callable function.
+        """
+        if self._logits_callback is None:
+            msg = "The callback for getting logits is not defined"
+            raise AttributeError(msg)
+        return self._logits_callback
+
+    @logits_callback.setter
+    def logits_callback(self, func: Optional[Callable]) -> None:
+        self._logits_callback = func
+
+    def forward(
+        self,
+        model_embeddings: torch.Tensor,
+        feature_tensors: TensorMap,  # noqa: ARG002
+        positive_labels: torch.LongTensor,
+        negative_labels: torch.LongTensor,  # noqa: ARG002
+        padding_mask: torch.BoolTensor,  # noqa: ARG002
+        target_padding_mask: torch.BoolTensor,
+    ) -> torch.Tensor:
+        """
+        forward(model_embeddings, positive_labels, target_padding_mask)
+        **Note**: At forward pass, the whole catalog of items is used as negatives.
+        Next, negative logits, corresponding to positions where negative labels
+        coincide with positive ones, are masked.
+
+        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+        :param positive_labels: ground truth labels of positive events
+            of shape (batch_size, sequence_length, num_positives).
+        :param target_padding_mask: padding mask corresponding for ``positive_labels``
+            of shape (batch_size, sequence_length, num_positives).
+        :return: computed loss value.
+        """
+        initial_target_padding_mask = target_padding_mask
+        num_positives = target_padding_mask.size(2)
+        # [batch_size, seq_len, num_positives] -> [batch_size * seq_len]
+        if num_positives == 1:
+            target_padding_mask = target_padding_mask.squeeze(-1)
+        else:
+            target_padding_mask = target_padding_mask.sum(-1).bool()
+        masked_batch_size = target_padding_mask.sum().item()
+
+        logits: torch.Tensor = self.logits_callback(model_embeddings)  # [batch_size, seq_len, vocab_size]
+        all_negative_labels = torch.arange(self.cardinality, dtype=torch.long, device=positive_labels.device)
+
+        # [batch_size, seq_len, vocab_size] -> [masked_batch_size, vocab_size]
+        logits = logits[target_padding_mask]
+
+        # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+        positive_labels = positive_labels[target_padding_mask]
+
+        # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+        target_padding_mask = initial_target_padding_mask[target_padding_mask]
+
+        positive_ids = torch.arange(masked_batch_size, dtype=torch.long, device=positive_labels.device)
+        # [masked_batch_size, vocab_size] -> [masked_batch_size, num_positives]
+        positive_logits = logits[positive_ids.unsqueeze(-1), positive_labels]
+
+        # [masked_batch_size, vocab_size] - assign low values to some negative logits
+        negative_logits = mask_negative_logits(
+            logits,
+            all_negative_labels,
+            positive_labels,
+            self.negative_labels_ignore_index,
+        )
+
+        # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
+        negative_logits = negative_logits.unsqueeze(-2)
+        # [masked_batch_size, 1, num_negatives] -> [masked_batch_size, num_positives, num_negatives]
+        negative_logits = negative_logits.repeat(1, target_padding_mask.size(-1), 1)
+        # [masked_batch_size, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
+        negative_logits = negative_logits[target_padding_mask]
+        # [masked_batch_size, num_positives] -> [masked_batch_size]
+        positive_logits = positive_logits[target_padding_mask]
+        # [masked_batch_size] -> [masked_batch_size, 1]
+        positive_logits = positive_logits.unsqueeze(-1)
+
+        # [masked_batch_size, 1 + num_negatives] - all logits
+        logits = torch.cat((positive_logits, negative_logits), dim=-1)
+        # [masked_batch_size] - positives are always at 0 position for all recommendation points
+        target = torch.zeros(logits.size(0), dtype=torch.long, device=positive_labels.device)
+        # [masked_batch_size] - loss for all recommendation points
+        loss = self._loss(logits, target)
+        return loss
+
+
+class LogOutCEWeighted(LogOutCE):
+    """
+    LogOutCE loss with sample weights enabling.
+
+    .. math::
+
+        L_{\\text{InfoNCE}} = - \\sum_{p \\in P} \\log \\frac{ \\exp(\\mathrm{sim}(q, p))}
+        {\\exp(\\mathrm{sim}(q, p))
+        + \\sum_{n \\in N} \\exp(\\mathrm{sim}(q, n))}.
+
+    where q -- query embedding, P -- set of positive logits, N -- set of negative logits,
+    :math:`sim(\\cdot, \\cdot)` -- similaruty function.\n
+
+    In addition to calculating the standard loss,
+    weights are applied for each sample.
+    Therefore, it is expected that the sample weights will be in the generated batch,
+    which is fed into the model.
+
+    The loss supports the calculation of logits for the case of multi-positive labels
+    (there are several labels for each position in the sequence).
+    """
+
+    def __init__(
+        self,
+        cardinality: int,
+        feature_name: str,
+        negative_labels_ignore_index: int = -100,
+        **kwargs,
+    ):
+        """
+        To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
+        You can pass all other parameters for initializing the object via kwargs.
+
+        :param cardinality: number of unique items in vocabulary (catalog).
+            The specified cardinality value must not take into account the padding value.
+        :param feature_name: the name of the key in the batch.
+            The tensor is expected to contain sample weights.
+        :param negative_labels_ignore_index: padding value for negative labels.
+            This may be the case when negative labels
+            are formed at the preprocessing level, rather than the negative sampler.
+            The index is ignored and does not contribute to the loss.
+            Default: ``-100``.
+        """
+        super().__init__(
+            cardinality=cardinality,
+            negative_labels_ignore_index=negative_labels_ignore_index,
+        )
+        self.feature_name = feature_name
+        self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)
+
+    def forward(
+        self,
+        model_embeddings: torch.Tensor,
+        feature_tensors: TensorMap,
+        positive_labels: torch.LongTensor,
+        negative_labels: torch.LongTensor,  # noqa: ARG002
+        padding_mask: torch.BoolTensor,  # noqa: ARG002
+        target_padding_mask: torch.BoolTensor,
+    ) -> torch.Tensor:
+        """
+        forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask)
+        **Note**: At forward pass, the whole catalog of items is used as negatives.
+        Next, negative logits, corresponding to positions where negative labels
+        coincide with positive ones, are masked.
+
+        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
+        :param feature_tensors: a dictionary of tensors from dataloader.
+            This dictionary is expected to contain a key with the name ``feature_name``,
+            which is specified in the constructor.
+            Expected shape of tensor ``(batch_size, sequence_length, num_positives)``.
+        :param positive_labels: ground truth labels of positive events
+            of shape (batch_size, sequence_length, num_positives).
+        :param target_padding_mask: padding mask corresponding for ``positive_labels``
+            of shape (batch_size, sequence_length, num_positives).
+        :return: computed loss value.
+        """
+        loss: torch.Tensor = super().forward(model_embeddings, None, positive_labels, None, None, target_padding_mask)
+        sample_weight = feature_tensors[self.feature_name]
+        sample_weight = sample_weight[target_padding_mask]
+        loss = (loss * sample_weight).mean()
+        return loss
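
LogOutCEWeighted reads per-sample weights from feature_tensors under the configured feature_name. A minimal sketch of that wiring follows; the full-catalog dot-product callback, the "sample_weight" key, and the toy tensors are assumptions made for the example, not part of the package API.

import torch

from replay.nn.loss.logout_ce import LogOutCEWeighted

torch.manual_seed(0)
batch_size, seq_len, emb_dim, num_items = 2, 4, 8, 50
item_embeddings = torch.nn.Embedding(num_items, emb_dim)  # stand-in for the model's item table

loss_fn = LogOutCEWeighted(cardinality=num_items, feature_name="sample_weight")
# The callback maps hidden states to logits over the whole catalog:
# [batch_size, seq_len, emb_dim] -> [batch_size, seq_len, num_items]
loss_fn.logits_callback = lambda hidden: hidden @ item_embeddings.weight.T

model_embeddings = torch.randn(batch_size, seq_len, emb_dim)
positive_labels = torch.randint(0, num_items, (batch_size, seq_len, 1))
target_padding_mask = torch.ones(batch_size, seq_len, 1, dtype=torch.bool)
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)          # unused here
negative_labels = torch.empty(0, dtype=torch.long)                        # unused: the whole catalog is used
feature_tensors = {"sample_weight": torch.rand(batch_size, seq_len, 1)}   # per-position weights

loss = loss_fn(model_embeddings, feature_tensors, positive_labels, negative_labels, padding_mask, target_padding_mask)
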
replay/nn/mask.py
ADDED
@@ -0,0 +1,87 @@
+from abc import ABC, abstractmethod
+from typing import Protocol
+
+import torch
+
+from replay.data.nn.schema import TensorMap
+
+
+class AttentionMaskProto(Protocol):
+    def __call__(self, feature_tensor: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor: ...
+
+
+class AttentionMaskBase(ABC):
+    def __init__(
+        self,
+        num_heads: int,
+    ) -> None:
+        self.num_heads = num_heads
+
+    def __call__(
+        self,
+        feature_tensor: TensorMap,
+        padding_mask: torch.BoolTensor,
+    ) -> torch.FloatTensor:
+        """
+        :param feature_tensor: dict of features tensors.
+        :param padding_mask: Padding mask where ``0`` - ``<PAD>``, ``1`` - otherwise.
+        :returns: Float attention mask of shape ``(B * num_heads, L, L)``,
+            where ``-inf`` for ``<PAD>``, ``0`` - otherwise.
+        """
+        attention_mask = self._get_attention_mask(feature_tensor)
+
+        diagonal_attention_mask = torch.diag(
+            torch.ones(padding_mask.size(1), dtype=torch.bool, device=padding_mask.device)
+        )
+        # (B, L) -> (B, 1, 1, L)
+        key_padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)
+        # (B, 1, 1, L) -> (B, 1, L, L), where 0 - PAD, 1 - otherwise
+        key_padding_mask = key_padding_mask | diagonal_attention_mask
+
+        attention_mask = (attention_mask & key_padding_mask).float()
+        attention_mask = attention_mask.masked_fill(attention_mask == 0, float("-inf")).masked_fill(
+            attention_mask == 1, 0.0
+        )
+        if attention_mask.size(1) != self.num_heads and attention_mask.shape[1] == 1:
+            # for default attention_mask of shape (L, L) it becomes (B, 1, L, L)
+            # (B, 1, L, L) -> (B, num_heads, L, L)
+            attention_mask = attention_mask.repeat(1, self.num_heads, 1, 1)
+        # (B, num_heads, L, L) -> (B * num_heads, L, L)
+        attention_mask = attention_mask.reshape(-1, *attention_mask.shape[-2:])
+        return attention_mask
+
+    @abstractmethod
+    def _get_attention_mask(self, feature_tensor: TensorMap) -> torch.Tensor:
+        raise NotImplementedError()  # pragma: no cover
+
+
+class DefaultAttentionMask(AttentionMaskBase):
+    """
+    Constructs a float lower-triangular attenstion mask
+    of shape ``(batch_size * num_heads, sequence_length, sequence_length)``,
+    where ``-inf`` for ``<PAD>``, ``0`` - otherwise.
+    """
+
+    def __init__(
+        self,
+        reference_feature_name: str,
+        num_heads: int,
+    ) -> None:
+        """
+        :param reference_feature_name: To build a mask, you need a reference tensor.
+            So you need to pass the name of the tensor, which will definitely be in the dictionary of feature tensors.
+            The second dimension (1 in zero indexing) of the tensor will be used to construct the attention mask.
+        :param num_heads: Number of attention heads.
+        """
+        super().__init__(num_heads)
+        self._feature_name = reference_feature_name
+
+    def _get_attention_mask(self, feature_tensor: TensorMap) -> torch.Tensor:
+        input_sequence = feature_tensor[self._feature_name]
+        return torch.tril(
+            torch.ones(
+                (input_sequence.size(1), input_sequence.size(1)),
+                dtype=torch.bool,
+                device=input_sequence.device,
+            )
+        )
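
For reference, a small sketch of what DefaultAttentionMask returns for a toy batch; the "item_id" feature name and the tensors are made up for the example.

import torch

from replay.nn.mask import DefaultAttentionMask

# Two sequences of length 4; the second one has its first position padded.
feature_tensor = {"item_id": torch.tensor([[3, 7, 1, 4], [0, 2, 5, 6]])}
padding_mask = torch.tensor([[True, True, True, True],
                             [False, True, True, True]])

mask_builder = DefaultAttentionMask(reference_feature_name="item_id", num_heads=2)
attention_mask = mask_builder(feature_tensor, padding_mask)

print(attention_mask.shape)  # torch.Size([4, 4, 4]) == (batch_size * num_heads, L, L)
# Per query row: future positions and padded keys are -inf, allowed keys are 0;
# the diagonal keeps every position attending at least to itself.
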