replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/loss/bce.py
ADDED
@@ -0,0 +1,216 @@
from typing import Callable, Optional

import torch

from replay.data.nn import TensorMap

from .base import SampledLossBase, mask_negative_logits


class BCE(torch.nn.Module):
    """
    Pointwise Binary Cross-Entropy loss.
    Calculates the loss over the entire item catalog.

    The loss supports the calculation of logits for the case of multi-positive labels
    (there are several labels for each position in the sequence).
    """

    def __init__(self, **kwargs):
        """
        To calculate the loss, ``torch.nn.BCEWithLogitsLoss`` is used with the parameter ``reduction="sum"``.
        You can pass all other parameters for initializing the object via kwargs.
        """
        super().__init__()
        self._loss = torch.nn.BCEWithLogitsLoss(reduction="sum", **kwargs)
        self._logits_callback = None

    @property
    def logits_callback(
        self,
    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
        """
        Property for calling a function for the logits computation.\n

        This function is expected to receive the model's last hidden state
        and optionally item IDs, and return a logits tensor.

        It is expected that the corresponding head model method will be used as this function,
        for example, the ``get_logits`` method of the ``SasRec`` class.

        :return: callable function.
        """
        if self._logits_callback is None:
            msg = "The callback for getting logits is not defined"
            raise AttributeError(msg)
        return self._logits_callback

    @logits_callback.setter
    def logits_callback(self, func: Optional[Callable]) -> None:
        self._logits_callback = func

    def forward(
        self,
        model_embeddings: torch.Tensor,
        feature_tensors: TensorMap,  # noqa: ARG002
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,  # noqa: ARG002
        padding_mask: torch.BoolTensor,  # noqa: ARG002
        target_padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        forward(model_embeddings, positive_labels, target_padding_mask)

        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param positive_labels: labels of positive events
            of shape ``(batch_size, sequence_length, num_positives)``.
        :param target_padding_mask: padding mask corresponding to ``positive_labels``
            of shape ``(batch_size, sequence_length, num_positives)``.
        :return: computed loss value.
        """
        logits = self.logits_callback(model_embeddings)

        # [batch_size, seq_len, num_positives] -> [batch_size, seq_len]
        if target_padding_mask.size(-1) == 1:
            target_padding_mask.squeeze_(-1)
        else:
            target_padding_mask = target_padding_mask.sum(-1).bool()

        # Take only logits which correspond to non-padded tokens
        # [batch_size, seq_len, vocab_size] -> [masked_batch_size, vocab_size]
        logits = logits[target_padding_mask]

        # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
        labels = positive_labels[target_padding_mask]

        bce_labels = torch.zeros_like(logits)

        # Fill positives with ones, all negatives are zeros
        bce_labels.scatter_(
            dim=-1,
            index=labels,
            value=1,
        )

        loss = self._loss(logits, bce_labels) / logits.size(0)
        return loss


class BCESampled(SampledLossBase):
    """
    Sampled Pointwise Binary Cross-Entropy loss (BCE with negative sampling).
    Calculates the loss between one positive item and K negatively sampled items.

    The loss supports the calculation of logits for the case of multi-positive labels
    (there are several labels for each position in the sequence).
    """

    def __init__(
        self,
        log_epsilon: float = 1e-6,
        clamp_border: float = 100.0,
        negative_labels_ignore_index: int = -100,
    ):
        """
        :param log_epsilon: correction to avoid zero in the logarithm during loss calculation.
            Default: ``1e-6``.
        :param clamp_border: upper bound for clamping the loss tensor; the lower bound is set to ``-clamp_border``.
            Default: ``100.0``.
        :param negative_labels_ignore_index: padding value for negative labels.
            This can happen when negative labels
            are formed at the preprocessing stage rather than by the negative sampler.
            The index is ignored and does not contribute to the loss.
            Default: ``-100``.
        """
        super().__init__()
        self.log_epsilon = log_epsilon
        self.clamp_border = clamp_border
        self.negative_labels_ignore_index = negative_labels_ignore_index
        self._logits_callback = None

    @property
    def logits_callback(
        self,
    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
        """
        Property for calling a function for the logits computation.\n

        This function is expected to receive the model's last hidden state
        and optionally item IDs, and return a logits tensor.

        It is expected that the corresponding head model method will be used as this function,
        for example, the ``get_logits`` method of the ``SasRec`` class.

        :return: callable function.
        """
        if self._logits_callback is None:
            msg = "The callback for getting logits is not defined"
            raise AttributeError(msg)
        return self._logits_callback

    @logits_callback.setter
    def logits_callback(self, func: Optional[Callable]) -> None:
        self._logits_callback = func

    def forward(
        self,
        model_embeddings: torch.Tensor,
        feature_tensors: TensorMap,  # noqa: ARG002
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,
        padding_mask: torch.BoolTensor,  # noqa: ARG002
        target_padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)

        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param positive_labels: labels of positive events
            of shape ``(batch_size, sequence_length, num_positives)``.
        :param negative_labels: labels of sampled negative events.
            Expected shape:

            - ``(batch_size, sequence_length, num_negatives)``
            - ``(batch_size, num_negatives)``
            - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
        :param target_padding_mask: padding mask corresponding to ``positive_labels``
            of shape ``(batch_size, sequence_length, num_positives)``.
        :return: computed loss value.
        """
        sampled = self.get_sampled_logits(
            model_embeddings,
            positive_labels,
            negative_labels,
            target_padding_mask,
        )
        positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
        negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
        positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
        negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]

        # Reject negative samples matching target label & correct for remaining samples
        negative_logits = mask_negative_logits(
            negative_logits,
            negative_labels,
            positive_labels,
            self.negative_labels_ignore_index,
        )

        positive_prob = torch.sigmoid(positive_logits)
        negative_prob = torch.sigmoid(negative_logits)

        positive_loss = torch.clamp(
            torch.log((positive_prob) + self.log_epsilon),
            -self.clamp_border,
            self.clamp_border,
        ).sum()
        negative_loss = torch.clamp(
            torch.log((1 - negative_prob) + self.log_epsilon),
            -self.clamp_border,
            self.clamp_border,
        ).sum()

        loss = -(positive_loss + negative_loss)
        loss /= positive_logits.size(0)

        return loss
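A minimal usage sketch (not part of the diff) showing how the BCE loss above can be exercised in isolation. The toy linear head, tensor shapes, and the lambda assigned to logits_callback are illustrative assumptions; in RePlay the callback is expected to be a model method such as SasRec.get_logits.

import torch

from replay.nn.loss.bce import BCE

batch_size, seq_len, hidden_dim, vocab_size = 2, 5, 8, 10
head = torch.nn.Linear(hidden_dim, vocab_size)  # hypothetical stand-in for a model head

loss_fn = BCE()
# assumption: the callback maps hidden states to per-item logits
loss_fn.logits_callback = lambda embeddings, item_ids=None: head(embeddings)

model_embeddings = torch.randn(batch_size, seq_len, hidden_dim)
positive_labels = torch.randint(0, vocab_size, (batch_size, seq_len, 1))
target_padding_mask = torch.ones(batch_size, seq_len, 1, dtype=torch.bool)

loss = loss_fn(
    model_embeddings=model_embeddings,
    feature_tensors={},                 # ignored by BCE.forward
    positive_labels=positive_labels,
    negative_labels=positive_labels,    # ignored by BCE.forward
    padding_mask=torch.ones(batch_size, seq_len, dtype=torch.bool),  # ignored by BCE.forward
    target_padding_mask=target_padding_mask,
)
print(loss)  # scalar tensor: summed BCE over the catalog, averaged over non-padded positions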
replay/nn/loss/ce.py
ADDED
@@ -0,0 +1,317 @@
from typing import Callable, Optional

import torch

from replay.data.nn import TensorMap

from .base import SampledLossBase, mask_negative_logits


class CE(torch.nn.Module):
    """
    Full Cross-Entropy loss.
    Calculates the loss over the entire item catalog.
    """

    def __init__(self, **kwargs):
        """
        To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
        You can pass all parameters for initializing the object via kwargs.
        """
        super().__init__()
        self._loss = torch.nn.CrossEntropyLoss(**kwargs)
        self._logits_callback = None

    @property
    def logits_callback(
        self,
    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
        """
        Property for calling a function for the logits computation.\n

        This function is expected to receive the model's last hidden state
        and optionally item IDs, and return a logits tensor.

        It is expected that the corresponding head model method will be used as this function,
        for example, the ``get_logits`` method of the ``SasRec`` class.

        :return: callable function.
        """
        if self._logits_callback is None:
            msg = "The callback for getting logits is not defined"
            raise AttributeError(msg)
        return self._logits_callback

    @logits_callback.setter
    def logits_callback(self, func: Optional[Callable]) -> None:
        self._logits_callback = func

    def forward(
        self,
        model_embeddings: torch.Tensor,
        feature_tensors: TensorMap,  # noqa: ARG002
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,  # noqa: ARG002
        padding_mask: torch.BoolTensor,  # noqa: ARG002
        target_padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        forward(model_embeddings, positive_labels, target_padding_mask)

        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param positive_labels: labels of positive events
            of shape ``(batch_size, sequence_length, num_positives)``.
        :param target_padding_mask: padding mask corresponding to ``positive_labels``
            of shape ``(batch_size, sequence_length, num_positives)``.
        :return: computed loss value.
        """
        if positive_labels.size(-1) != 1:
            msg = "The case of multi-positive labels is not supported in the CE loss"
            raise NotImplementedError(msg)
        logits: torch.Tensor = self.logits_callback(model_embeddings)  # [batch_size, seq_len, vocab_size]
        labels = positive_labels.masked_fill(
            mask=(~target_padding_mask),
            value=self._loss.ignore_index,
        )  # [batch_size, seq_len, 1]

        # [batch_size, seq_len, vocab_size] -> [batch_size * seq_len, vocab_size]
        logits_flat = logits.view(-1, logits.size(-1))
        # [batch_size, seq_len, 1] -> [batch_size * seq_len]
        labels_flat: torch.LongTensor = labels.view(-1)
        loss = self._loss(logits_flat, labels_flat)
        return loss


class CEWeighted(CE):
    """
    Full Cross-Entropy loss.
    Calculates the loss over the entire item catalog.

    In addition to calculating the standard loss,
    weights are applied for each sample.
    Therefore, the sample weights are expected to be present in the generated batch
    that is fed into the model.
    """

    def __init__(
        self,
        feature_name: str,
        **kwargs,
    ):
        """
        To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
        You can pass all other parameters for initializing the object via kwargs.

        :param feature_name: the name of the key in the batch.
            The tensor is expected to contain sample weights.
        """
        super().__init__()
        self.feature_name = feature_name
        self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)

    def forward(
        self,
        model_embeddings: torch.Tensor,
        feature_tensors: TensorMap,
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,  # noqa: ARG002
        padding_mask: torch.BoolTensor,  # noqa: ARG002
        target_padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask)

        :param feature_tensors: a dictionary of tensors from the dataloader.
            This dictionary is expected to contain a key with the name ``feature_name``,
            which is specified in the constructor.
            Expected shape of the tensor: ``(batch_size, sequence_length, num_positives)``.
        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param positive_labels: labels of positive events
            of shape ``(batch_size, sequence_length, num_positives)``.
        :param target_padding_mask: padding mask corresponding to ``positive_labels``
            of shape ``(batch_size, sequence_length, num_positives)``.
        :return: computed loss value.
        """
        loss: torch.Tensor = super().forward(
            model_embeddings,
            None,
            positive_labels,
            None,
            None,
            target_padding_mask,
        )
        sample_weight = feature_tensors[self.feature_name]
        loss = (loss * sample_weight).mean()
        return loss


class CESampled(SampledLossBase):
    """
    Sampled Cross-Entropy loss (Cross-Entropy with negative sampling).
    Calculates the loss between one positive item and K negatively sampled items.

    The loss supports the calculation of logits for the case of multi-positive labels
    (there are several labels for each position in the sequence).
    """

    def __init__(
        self,
        negative_labels_ignore_index: int = -100,
        **kwargs,
    ):
        """
        To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
        You can pass all parameters for initializing the object via kwargs.

        :param negative_labels_ignore_index: padding value for negative labels.
            This can happen when negative labels
            are formed at the preprocessing stage rather than by the negative sampler.
            The index is ignored and does not contribute to the loss.
            Default: ``-100``.
        """
        super().__init__()
        self.negative_labels_ignore_index = negative_labels_ignore_index
        self._loss = torch.nn.CrossEntropyLoss(**kwargs)
        self._logits_callback = None

    @property
    def logits_callback(
        self,
    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
        """
        Property for calling a function for the logits computation.\n

        This function is expected to receive the model's last hidden state
        and optionally item IDs, and return a logits tensor.

        It is expected that the corresponding head model method will be used as this function,
        for example, the ``get_logits`` method of the ``SasRec`` class.

        :return: callable function.
        """
        if self._logits_callback is None:
            msg = "The callback for getting logits is not defined"
            raise AttributeError(msg)
        return self._logits_callback

    @logits_callback.setter
    def logits_callback(self, func: Optional[Callable]) -> None:
        self._logits_callback = func

    def forward(
        self,
        model_embeddings: torch.Tensor,
        feature_tensors: TensorMap,  # noqa: ARG002
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,
        padding_mask: torch.BoolTensor,  # noqa: ARG002
        target_padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)

        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param positive_labels: labels of positive events
            of shape ``(batch_size, sequence_length, num_positives)``.
        :param negative_labels: labels of sampled negative events.
            Expected shape:

            - ``(batch_size, sequence_length, num_negatives)``
            - ``(batch_size, num_negatives)``
            - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
        :param target_padding_mask: padding mask corresponding to ``positive_labels``
            of shape ``(batch_size, sequence_length, num_positives)``.
        :return: computed loss value.
        """
        sampled = self.get_sampled_logits(
            model_embeddings,
            positive_labels,
            negative_labels,
            target_padding_mask,
        )
        positive_logits = sampled["positive_logits"]  # [masked_batch_size, num_positives]
        negative_logits = sampled["negative_logits"]  # [masked_batch_size, num_negatives]
        positive_labels = sampled["positive_labels"]  # [masked_batch_size, num_positives]
        negative_labels = sampled["negative_labels"]  # [masked_batch_size, num_negatives] or [num_negatives]

        # [masked_batch_size, num_negatives] - assign low values to some negative logits
        negative_logits = mask_negative_logits(
            negative_logits,
            negative_labels,
            positive_labels,
            self.negative_labels_ignore_index,
        )
        # [masked_batch_size, 1 + num_negatives] - all logits
        logits = torch.cat((positive_logits, negative_logits), dim=-1)
        # [masked_batch_size] - positives are always at 0 position for all recommendation points
        target = torch.zeros(positive_logits.size(0), dtype=torch.long, device=logits.device)
        # [masked_batch_size] - loss for all recommendation points
        loss = self._loss(logits, target)
        return loss


class CESampledWeighted(CESampled):
    """
    Sampled Cross-Entropy loss (Cross-Entropy with negative sampling).
    Calculates the loss between one positive item and K negatively sampled items.

    In addition to calculating the standard loss,
    weights are applied for each sample.
    Therefore, the sample weights are expected to be present in the generated batch
    that is fed into the model.

    The loss supports the calculation of logits for the case of multi-positive labels
    (there are several labels for each position in the sequence).
    """

    def __init__(
        self,
        feature_name: str,
        negative_labels_ignore_index: int = -100,
        **kwargs,
    ):
        """
        To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
        You can pass all other parameters for initializing the object via kwargs.

        :param feature_name: the name of the key in the batch.
            The tensor is expected to contain sample weights.
        :param negative_labels_ignore_index: padding value for negative labels.
            This can happen when negative labels
            are formed at the preprocessing stage rather than by the negative sampler.
            The index is ignored and does not contribute to the loss.
            Default: ``-100``.
        """
        super().__init__(negative_labels_ignore_index=negative_labels_ignore_index)
        self.feature_name = feature_name
        self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)

    def forward(
        self,
        model_embeddings: torch.Tensor,
        feature_tensors: TensorMap,
        positive_labels: torch.LongTensor,
        negative_labels: torch.LongTensor,
        padding_mask: torch.BoolTensor,  # noqa: ARG002
        target_padding_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        forward(model_embeddings, feature_tensors, positive_labels, negative_labels, target_padding_mask)

        :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
        :param feature_tensors: a dictionary of tensors from the dataloader.
            This dictionary is expected to contain a key with the name ``feature_name``,
            which is specified in the constructor.
            Expected shape of the tensor: ``(batch_size, sequence_length, num_positives)``.
        :param positive_labels: labels of positive events
            of shape ``(batch_size, sequence_length, num_positives)``.
        :param negative_labels: labels of sampled negative events of shape ``(num_negatives)``.
        :param target_padding_mask: padding mask corresponding to ``positive_labels``
            of shape ``(batch_size, sequence_length, num_positives)``.
        :return: computed loss value.
        """
        loss: torch.Tensor = super().forward(
            model_embeddings, None, positive_labels, negative_labels, None, target_padding_mask
        )
        sample_weight = feature_tensors[self.feature_name]
        sample_weight = sample_weight[target_padding_mask]
        loss = (loss * sample_weight).mean()
        return loss
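A minimal usage sketch (not part of the diff) for the full-catalog CE loss above. The sampled variants additionally rely on SampledLossBase.get_sampled_logits from replay/nn/loss/base.py, which is outside this section, so they are not exercised here. The toy head, tensor shapes, and the callback lambda are illustrative assumptions rather than the library's documented API.

import torch

from replay.nn.loss.ce import CE

batch_size, seq_len, hidden_dim, vocab_size = 2, 5, 8, 10
head = torch.nn.Linear(hidden_dim, vocab_size)  # hypothetical stand-in for a model head

loss_fn = CE()
# assumption: the callback maps hidden states to per-item logits
loss_fn.logits_callback = lambda embeddings, item_ids=None: head(embeddings)

model_embeddings = torch.randn(batch_size, seq_len, hidden_dim)
positive_labels = torch.randint(0, vocab_size, (batch_size, seq_len, 1))

# Mask out the first position of each sequence; CE.forward maps masked positions
# to CrossEntropyLoss.ignore_index, so they do not contribute to the loss.
target_padding_mask = torch.ones(batch_size, seq_len, 1, dtype=torch.bool)
target_padding_mask[:, 0, :] = False

loss = loss_fn(
    model_embeddings=model_embeddings,
    feature_tensors={},                 # ignored by CE.forward
    positive_labels=positive_labels,
    negative_labels=positive_labels,    # ignored by CE.forward
    padding_mask=torch.ones(batch_size, seq_len, dtype=torch.bool),  # ignored by CE.forward
    target_padding_mask=target_padding_mask,
)
print(loss)  # scalar tensor: mean cross-entropy over non-padded positions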