replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/lightning/module.py
ADDED

@@ -0,0 +1,123 @@
+import inspect
+from typing import Any, Optional, Union
+
+import lightning
+import torch
+from typing_extensions import override
+
+from replay.nn.lightning.optimizer import BaseOptimizerFactory, OptimizerFactory
+from replay.nn.lightning.scheduler import BaseLRSchedulerFactory
+from replay.nn.output import InferenceOutput, TrainOutput
+
+
+class LightningModule(lightning.LightningModule):
+    """
+    A universal wrapper class above a PyTorch model for working with the Lightning library.\n
+    Pay attention to the format of the ``forward`` function's return value.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer_factory: Optional[BaseOptimizerFactory] = None,
+        lr_scheduler_factory: Optional[BaseLRSchedulerFactory] = None,
+    ) -> None:
+        """
+        :param model: Initialized model.\n
+            The expected result of the model's ``forward`` function
+            is an object of the ``TrainOutput`` class at the training stage
+            and ``InferenceOutput`` at the inference stage.
+        :param optimizer_factory: Optimizer factory.
+            Default: ``None``.
+        :param lr_scheduler_factory: Learning rate scheduler factory.
+            Default: ``None``.
+        """
+        super().__init__()
+        self.save_hyperparameters(ignore=["model"])
+        self.model = model
+
+        self._optimizer_factory = optimizer_factory
+        self._lr_scheduler_factory = lr_scheduler_factory
+        self.candidates_to_score = None
+
+    def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]:
+        """
+        Implementation of the forward function.
+
+        :param batch: A dictionary containing all the necessary information to run the forward function on the model.
+            The dictionary keys must match the names of the arguments in the model's forward function.
+            Keys that do not match the arguments of the model's forward function are filtered out.
+            If the model supports calculating logits for custom candidates at the inference stage,
+            then you can submit them inside the batch or via the ``candidates_to_score`` field.
+        :returns: During training, the model returns an object
+            of the ``TrainOutput`` container class or its successor.
+            At the inference stage, the ``InferenceOutput`` class or its successor is returned.
+        """
+        if "candidates_to_score" not in batch and self.candidates_to_score is not None and not self.training:
+            batch["candidates_to_score"] = self.candidates_to_score
+        # select only args for model.forward
+        modified_batch = {k: v for k, v in batch.items() if k in inspect.signature(self.model.forward).parameters}
+        return self.model(**modified_batch)
+
+    def training_step(self, batch: dict) -> torch.Tensor:
+        model_output: TrainOutput = self(batch)
+        loss = model_output["loss"]
+        lr = self.optimizers().param_groups[0]["lr"]  # Get current learning rate
+        self.log("learning_rate", lr, on_step=True, on_epoch=True, prog_bar=True)
+        self.log(
+            "train_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+        return loss
+
+    @override
+    def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> InferenceOutput:
+        model_output: InferenceOutput = self(batch)
+        return model_output
+
+    @override
+    def test_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> InferenceOutput:
+        model_output: InferenceOutput = self(batch)
+        return model_output
+
+    @override
+    def validation_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> InferenceOutput:
+        model_output: InferenceOutput = self(batch)
+        return model_output
+
+    def configure_optimizers(self) -> Any:
+        """
+        :returns: The configured optimizer alone, or a pair of lists
+            containing the configured optimizer and lr scheduler
+            when a scheduler factory was provided.
+        """
+        optimizer_factory = self._optimizer_factory or OptimizerFactory()
+        optimizer = optimizer_factory.create(self.model.parameters())
+
+        if self._lr_scheduler_factory is None:
+            return optimizer
+
+        lr_scheduler = self._lr_scheduler_factory.create(optimizer)
+        return [optimizer], [lr_scheduler]
+
+    @property
+    def candidates_to_score(self) -> Optional[torch.LongTensor]:
+        """
+        :getter: Returns a tensor containing the candidate IDs.
+            The tensor will be used during the inference stage of the model.\n
+            If the parameter was not previously set, ``None`` will be returned.
+        :setter: A one-dimensional tensor containing candidate IDs is expected.
+        """
+        return self._candidates_to_score
+
+    @candidates_to_score.setter
+    def candidates_to_score(self, candidates: Optional[torch.LongTensor] = None) -> None:
+        if (candidates is not None) and bool(candidates.unique().numel() != candidates.numel()):
+            msg = "The tensor of candidates to score must be unique."
+            raise ValueError(msg)
+
+        self._candidates_to_score = candidates
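Reviewer note: a minimal usage sketch of the new wrapper, not part of the diff. ``my_model`` is a hypothetical ``torch.nn.Module`` whose ``forward`` returns ``TrainOutput`` / ``InferenceOutput`` dictionaries, and the import paths follow the file layout above:

import torch
from replay.nn.lightning.module import LightningModule
from replay.nn.lightning.optimizer import OptimizerFactory
from replay.nn.lightning.scheduler import LRSchedulerFactory

module = LightningModule(
    model=my_model,  # hypothetical model; forward() must return TrainOutput / InferenceOutput
    optimizer_factory=OptimizerFactory(optimizer="adam", learning_rate=1e-3),
    lr_scheduler_factory=LRSchedulerFactory(decay_step=25, gamma=0.5),
)
# Restrict inference-time scoring to a set of candidate item ids;
# the setter rejects tensors containing duplicates.
module.candidates_to_score = torch.tensor([3, 14, 15])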
replay/nn/lightning/optimizer.py
ADDED

@@ -0,0 +1,60 @@
+import abc
+from collections.abc import Iterator
+from typing import Literal
+
+import torch
+
+
+class BaseOptimizerFactory(abc.ABC):
+    """
+    Interface for optimizer factory
+    """
+
+    @abc.abstractmethod
+    def create(self, parameters: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:  # pragma: no cover
+        """
+        Creates optimizer based on parameters.
+
+        :param parameters: torch parameters to initialize optimizer
+
+        :returns: torch optimizer
+        """
+
+
+class OptimizerFactory(BaseOptimizerFactory):
+    """
+    Factory that creates optimizer depending on passed parameters
+    """
+
+    def __init__(
+        self,
+        optimizer: Literal["adam", "sgd"] = "adam",
+        learning_rate: float = 0.001,
+        weight_decay: float = 0.0,
+        sgd_momentum: float = 0.0,
+        betas: tuple[float, float] = (0.9, 0.98),
+    ) -> None:
+        super().__init__()
+        self.optimizer = optimizer
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.sgd_momentum = sgd_momentum
+        self.betas = betas
+
+    def create(self, parameters: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
+        """
+        Creates optimizer based on parameters.
+
+        :param parameters: torch parameters to initialize optimizer
+
+        :returns: torch optimizer
+        """
+        if self.optimizer == "adam":
+            return torch.optim.Adam(parameters, lr=self.learning_rate, weight_decay=self.weight_decay, betas=self.betas)
+        if self.optimizer == "sgd":
+            return torch.optim.SGD(
+                parameters, lr=self.learning_rate, weight_decay=self.weight_decay, momentum=self.sgd_momentum
+            )
+
+        msg = "Unexpected optimizer"
+        raise ValueError(msg)
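Reviewer note: a short sketch of the factory in use, not part of the diff; the ``torch.nn.Linear`` stands in for a real model:

import torch
from replay.nn.lightning.optimizer import OptimizerFactory

model = torch.nn.Linear(16, 4)  # stand-in for a real model
factory = OptimizerFactory(optimizer="sgd", learning_rate=0.01, sgd_momentum=0.9)
optimizer = factory.create(model.parameters())  # -> torch.optim.SGD
# Any value outside Literal["adam", "sgd"] makes create() raise ValueError("Unexpected optimizer").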
replay/nn/lightning/postprocessor/_base.py
ADDED

@@ -0,0 +1,51 @@
+import abc
+from typing import Optional, Union
+
+import torch
+
+
+class PostprocessorBase(abc.ABC):  # pragma: no cover
+    """
+    Abstract base class for post processor
+    """
+
+    @abc.abstractmethod
+    def on_validation(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+        """
+        The method is called externally inside the callback at the validation stage.
+
+        :param batch: the batch sent to the model from the dataloader
+        :param logits: logits from the model
+
+        :returns: modified logits
+        """
+
+    @abc.abstractmethod
+    def on_prediction(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+        """
+        The method is called externally inside the callback at the prediction (inference) stage.
+
+        :param batch: the batch sent to the model from the dataloader
+        :param logits: logits from the model
+
+        :returns: modified logits
+        """
+
+    @property
+    def candidates(self) -> Union[torch.LongTensor, None]:
+        """
+        Returns tensor of item ids to calculate scores.
+        """
+        return self._candidates
+
+    @candidates.setter
+    def candidates(self, candidates: Optional[torch.LongTensor] = None) -> None:
+        """
+        Sets tensor of item ids to calculate scores.
+        :param candidates: Tensor of item ids to calculate scores.
+        """
+        if (candidates is not None) and bool(candidates.unique().numel() != candidates.numel()):
+            msg = "The tensor of candidates to score must be unique."
+            raise ValueError(msg)
+
+        self._candidates = candidates
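Reviewer note: a minimal concrete subclass, sketched for illustration only. ``TemperatureScaler`` is hypothetical, and the import path assumes the private module shown above (the package ``__init__`` may re-export the base class):

import torch

from replay.nn.lightning.postprocessor._base import PostprocessorBase


class TemperatureScaler(PostprocessorBase):
    """Hypothetical postprocessor: softens logits by a fixed temperature."""

    def __init__(self, temperature: float = 2.0) -> None:
        super().__init__()
        self.temperature = temperature
        self._candidates = None  # backing field read by the inherited `candidates` property

    def on_validation(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.temperature

    def on_prediction(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.temperature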
replay/nn/lightning/postprocessor/seen_items.py
ADDED

@@ -0,0 +1,83 @@
+import torch
+
+from replay.data.nn import TensorMap
+
+from ._base import PostprocessorBase
+
+
+class SeenItemsFilter(PostprocessorBase):
+    """
+    Masks (sets the logits value to ``-inf``) the items that have already been seen in the given dataset
+    (i.e. in the sequence of items for which the logits are calculated).\n
+    Should be used in Lightning callbacks for inference or metrics computation.
+
+    .. rubric:: Input example:
+
+    logits [B=2 users, I=3 items]::
+
+        logits =
+            [[0.1, 0.2, 0.3],  # user0
+             [-0.1, -0.2, -0.3]]  # user1
+
+    Seen items per user::
+
+        seen_items =
+            user0: [1, 0]
+            user1: [1, 2, 1]
+
+    .. rubric:: Output example:
+
+    SeenItemsFilter sets logits of seen items to ``-inf``::
+
+        processed_logits =
+            [[ -inf, -inf, 0.3000],  # user0
+             [-0.1000, -inf, -inf]]  # user1
+
+    """
+
+    def __init__(self, item_count: int, seen_items_column: str = "seen_ids") -> None:
+        """
+        :param item_count: Total number of items that the model knows about (``cardinality``).
+            It is recommended to take this value from ``TensorSchema``.\n
+            Please note that values outside the range [0, ``item_count - 1``] are filtered out (considered as padding).
+        :param seen_items_column: Name of the column in the batch that contains users' interactions (seen item ids).
+        """
+        super().__init__()
+        self.item_count = item_count
+        self.seen_items_column = seen_items_column
+        self._candidates = None
+
+    def on_validation(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+        return self._compute_scores(batch, logits.detach().clone())
+
+    def on_prediction(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+        return self._compute_scores(batch, logits.detach().clone())
+
+    def _compute_scores(
+        self,
+        batch: TensorMap,
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        seen_ids_padded = batch[self.seen_items_column]
+        padding_mask = (seen_ids_padded < self.item_count) & (seen_ids_padded >= 0)
+
+        batch_factors = torch.arange(0, logits.size(0), device=logits.device) * self.item_count
+        factored_ids = seen_ids_padded + batch_factors.unsqueeze(1)
+        seen_ids_flat = factored_ids[padding_mask]
+
+        if self._candidates is not None:
+            _logits = torch.full((logits.size(0), self.item_count), -torch.inf, device=logits.device)
+            _logits[:, self._candidates] = torch.reshape(logits, _logits[:, self._candidates].shape)
+            logits = _logits
+
+        if logits.is_contiguous():
+            logits.view(-1)[seen_ids_flat] = -torch.inf
+        else:
+            flat_scores = logits.flatten()
+            flat_scores[seen_ids_flat] = -torch.inf
+            logits = flat_scores.reshape(logits.shape)
+
+        if self._candidates is not None:
+            logits = logits[:, self._candidates]
+
+        return logits
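Reviewer note: the docstring example reproduced end to end, not part of the diff. The import path is assumed from the file layout above, and the ``-1`` pad value is an illustration of an out-of-range id, which the filter treats as padding:

import torch
from replay.nn.lightning.postprocessor.seen_items import SeenItemsFilter

filt = SeenItemsFilter(item_count=3, seen_items_column="seen_ids")
batch = {"seen_ids": torch.tensor([[1, 0, -1], [1, 2, 1]])}  # user0 padded with -1
logits = torch.tensor([[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]])
print(filt.on_prediction(batch, logits))
# -> [[  -inf,   -inf, 0.3000],
#     [-0.1000,  -inf,   -inf]]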
replay/nn/lightning/scheduler.py
ADDED

@@ -0,0 +1,91 @@
+import abc
+import warnings
+from typing import Literal
+
+import torch
+
+
+class BaseLRSchedulerFactory(abc.ABC):
+    """
+    Interface for learning rate scheduler factory
+    """
+
+    @abc.abstractmethod
+    def create(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:  # pragma: no cover
+        """
+        Creates learning rate scheduler based on optimizer.
+
+        :param optimizer: torch optimizer
+
+        :returns: torch LRScheduler
+        """
+
+
+class LRSchedulerFactory(BaseLRSchedulerFactory):
+    """
+    Factory that creates a ``StepLR`` learning rate schedule depending on the passed parameters
+    """
+
+    def __init__(self, decay_step: int = 25, gamma: float = 1.0) -> None:
+        super().__init__()
+        self.decay_step = decay_step
+        self.gamma = gamma
+
+    def create(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+        """
+        Creates learning rate scheduler based on optimizer.
+
+        :param optimizer: torch optimizer
+
+        :returns: torch LRScheduler
+        """
+        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.decay_step, gamma=self.gamma)
+
+
+class LambdaLRSchedulerFactory(BaseLRSchedulerFactory):
+    """
+    Factory that creates a warmup ``LambdaLR`` learning rate schedule depending on the passed parameters
+    """
+
+    def __init__(
+        self,
+        warmup_steps: int,
+        warmup_lr: float = 1.0,
+        normal_lr: float = 0.1,
+        update_interval: Literal["epoch", "step"] = "epoch",
+    ) -> None:
+        super().__init__()
+
+        if normal_lr <= 0.0:
+            msg = f"Normal LR must be positive. Got {normal_lr}"
+            raise ValueError(msg)
+        if warmup_lr <= 0.0:
+            msg = f"Warmup LR must be positive. Got {warmup_lr}"
+            raise ValueError(msg)
+        if normal_lr >= warmup_lr:
+            msg = f"Suspicious LR pair: {normal_lr=}, {warmup_lr=}"
+            warnings.warn(msg, stacklevel=2)
+
+        self.warmup_lr = warmup_lr
+        self.normal_lr = normal_lr
+        self.warmup_steps = warmup_steps
+        self.update_interval = update_interval
+
+    def create(self, optimizer: torch.optim.Optimizer) -> dict:
+        """
+        Creates a learning rate scheduler configuration based on the optimizer.
+
+        :param optimizer: torch optimizer
+
+        :returns: Lightning lr scheduler config wrapping a torch LambdaLR
+        """
+        return {
+            "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, self.lr_lambda),
+            "interval": self.update_interval,
+            "frequency": 1,
+        }
+
+    def lr_lambda(self, current_step: int) -> float:
+        if current_step >= self.warmup_steps:
+            return self.normal_lr
+        return self.warmup_lr
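Reviewer note: a quick sanity check of the two-phase multiplier, not part of the diff. ``lr_lambda`` returns a constant ``warmup_lr`` multiplier until ``warmup_steps``, then drops to ``normal_lr``; the effective rate is the optimizer's base LR times this value:

import torch
from replay.nn.lightning.scheduler import LambdaLRSchedulerFactory

factory = LambdaLRSchedulerFactory(warmup_steps=3, warmup_lr=1.0, normal_lr=0.1, update_interval="step")
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
config = factory.create(optimizer)  # {"scheduler": LambdaLR(...), "interval": "step", "frequency": 1}
print([factory.lr_lambda(step) for step in range(5)])  # [1.0, 1.0, 1.0, 0.1, 0.1]
# Effective LR: 0.01 * 1.0 during warmup, then 0.01 * 0.1 = 0.001.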
replay/nn/loss/__init__.py
ADDED

@@ -0,0 +1,22 @@
+from .base import LossProto
+from .bce import BCE, BCESampled
+from .ce import CE, CESampled, CESampledWeighted, CEWeighted
+from .login_ce import LogInCE, LogInCESampled
+from .logout_ce import LogOutCE, LogOutCEWeighted
+
+LogOutCESampled = CE
+
+__all__ = [
+    "BCE",
+    "CE",
+    "BCESampled",
+    "CESampled",
+    "CESampledWeighted",
+    "CEWeighted",
+    "LogInCE",
+    "LogInCESampled",
+    "LogOutCE",
+    "LogOutCESampled",
+    "LogOutCEWeighted",
+    "LossProto",
+]
replay/nn/loss/base.py
ADDED

@@ -0,0 +1,197 @@
+from typing import Callable, Optional, Protocol, TypedDict
+
+import torch
+
+from replay.data.nn import TensorMap
+
+
+class LossProto(Protocol):
+    """Class-protocol for working with losses inside models"""
+
+    @property
+    def logits_callback(
+        self,
+    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: ...
+
+    @logits_callback.setter
+    def logits_callback(self, func: Optional[Callable]) -> None: ...
+
+    def forward(
+        self,
+        model_embeddings: torch.Tensor,
+        feature_tensors: TensorMap,
+        positive_labels: torch.LongTensor,
+        negative_labels: torch.LongTensor,
+        padding_mask: torch.BoolTensor,
+        target_padding_mask: torch.BoolTensor,
+    ) -> torch.Tensor: ...
+
+
+class SampledLossOutput(TypedDict):
+    """A class containing the result of the `get_sampled_logits` function in sampled losses"""
+
+    positive_logits: torch.Tensor
+    negative_logits: torch.Tensor
+    positive_labels: torch.LongTensor
+    negative_labels: torch.LongTensor
+
+
+class SampledLossBase(torch.nn.Module):
+    """The base class for calculating sampled losses"""
+
+    @property
+    def logits_callback(
+        self,
+    ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+        raise NotImplementedError()  # pragma: no cover
+
+    def get_sampled_logits(
+        self,
+        model_embeddings: torch.Tensor,
+        positive_labels: torch.LongTensor,  # [batch_size, seq_len, num_positives]
+        negative_labels: torch.LongTensor,  # [num_negatives] or [batch_size, seq_len, num_negatives]
+        target_padding_mask: torch.BoolTensor,  # [batch_size, seq_len, num_positives]
+    ) -> SampledLossOutput:
+        """
+        Calculates positive and negative logits
+        based on the model's last hidden state and the positive and negative labels.
+
+        The function supports the calculation of logits for the case of multi-positive labels
+        (there are several labels for each position in the sequence).
+
+        :param model_embeddings: Embeddings from the model. This is usually the last hidden state.
+            Expected shape: ``(batch_size, sequence_length, embedding_dim)``
+        :param positive_labels: a tensor containing labels with positive events.
+            Expected shape: ``(batch_size, sequence_length, num_positives)``
+        :param negative_labels: a tensor containing labels with negative events.
+            Expected shape:
+            - ``(batch_size, sequence_length, num_negatives)``
+            - ``(batch_size, num_negatives)``
+            - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+        :param target_padding_mask: Padding mask for ``positive_labels`` (targets).
+            A ``False`` value indicates that the corresponding position will be ignored.
+            Expected shape: ``(batch_size, sequence_length, num_positives)``
+
+        :returns: SampledLossOutput. A dictionary containing positive and negative logits with labels.
+        """
+
+        initial_positive_labels = positive_labels
+        ################## SHAPE CHECKING STAGE START ##################
+        batch_size, seq_len, num_positives = positive_labels.size()
+        assert target_padding_mask.size() == (batch_size, seq_len, num_positives)
+        num_negatives = negative_labels.size(-1)
+
+        if negative_labels.size() == (batch_size, num_negatives):
+            # [batch_size, num_negatives] -> [batch_size, seq_len, num_negatives]
+            negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)
+
+        if negative_labels.dim() == 3:  # pragma: no cover
+            # [batch_size, seq_len, num_negatives] -> [batch_size, seq_len, 1, num_negatives]
+            negative_labels = negative_labels.unsqueeze(-2)
+            if num_positives != 1:
+                # [batch_size, seq_len, 1, num_negatives] -> [batch_size, seq_len, num_positives, num_negatives]
+                negative_labels = negative_labels.repeat((1, 1, num_positives, 1))
+        assert (
+            negative_labels.size() == (batch_size, seq_len, num_positives, num_negatives) or negative_labels.dim() == 1
+        )
+        ################## SHAPE CHECKING STAGE END ##################
+
+        # Get output embedding for every user event
+        embedding_dim = model_embeddings.size(-1)
+        assert model_embeddings.size() == (batch_size, seq_len, embedding_dim)
+
+        # [batch_size, seq_len, emb_dim] -> [batch_size, seq_len, 1, emb_dim]
+        model_embeddings = model_embeddings.unsqueeze(-2)
+        if num_positives != 1:  # multi-positive branch
+            model_embeddings = model_embeddings.repeat((1, 1, num_positives, 1))
+        assert model_embeddings.size() == (
+            batch_size,
+            seq_len,
+            num_positives,
+            embedding_dim,
+        )
+
+        # Apply target mask
+        # [batch_size, seq_len, num_positives] -> scalar count of unmasked target positions
+        masked_batch_size = target_padding_mask.sum().item()
+
+        # [batch_size, seq_len, num_positives] -> [masked_batch_size, 1]
+        positive_labels = positive_labels[target_padding_mask].unsqueeze(-1)
+        assert positive_labels.size() == (masked_batch_size, 1)
+
+        if negative_labels.dim() != 1:  # pragma: no cover
+            # [batch_size, seq_len, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
+            negative_labels = negative_labels[target_padding_mask]
+            assert negative_labels.size() == (masked_batch_size, num_negatives)
+
+        # [batch_size, seq_len, num_positives, emb_dim] -> [masked_batch_size, emb_dim]
+        model_embeddings = model_embeddings[target_padding_mask]
+        assert model_embeddings.size() == (masked_batch_size, embedding_dim)
+
+        # Get positive and negative logits
+        positive_logits = self.logits_callback(model_embeddings, positive_labels)
+        assert positive_logits.size() == (masked_batch_size, 1)
+
+        negative_logits = self.logits_callback(model_embeddings, negative_labels)
+        assert negative_logits.size() == (masked_batch_size, num_negatives)
+
+        if num_positives != 1:
+            # [batch_size, seq_len, num_positives] -> [batch_size * seq_len]
+            masked_target_padding_mask = target_padding_mask.sum(-1).view(-1)
+            # [batch_size * seq_len, num_positives] -> [masked_batch_size, num_positives]
+            positive_labels = torch.repeat_interleave(
+                initial_positive_labels.view(-1, num_positives),
+                masked_target_padding_mask,
+                dim=0,
+            )
+
+        return {
+            "positive_logits": positive_logits,
+            "negative_logits": negative_logits,
+            "positive_labels": positive_labels,
+            "negative_labels": negative_labels,
+        }
+
+
+def mask_negative_logits(
+    negative_logits: torch.Tensor,
+    negative_labels: torch.LongTensor,
+    positive_labels: torch.LongTensor,
+    negative_labels_ignore_index: int,
+) -> torch.Tensor:
+    """
+    Assigns very small values to negative logits
+    at positions where positive labels are equal to negative ones.
+
+    :param negative_logits: Logits from the model for ``negative labels``.
+        Expected shape: (masked_batch_size, num_negatives)
+    :param negative_labels: a tensor containing labels with negative events.
+        Expected shape:
+        - (masked_batch_size, num_negatives)
+        - (num_negatives) - a case where the same negative events are used for the entire batch
+    :param positive_labels: a tensor containing labels with positive events.
+        Expected shape: (masked_batch_size, num_positives)
+    :param negative_labels_ignore_index: padding value for negative labels.
+        This may be the case when negative labels
+        are formed at the preprocessing level rather than by the negative sampler.
+        The index is ignored and does not contribute to the loss.
+
+    :returns: Negative logits with modified elements at those positions
+        where positive labels are equal to negative ones.
+    """
+
+    if negative_labels_ignore_index >= 0:
+        negative_logits.masked_fill_(negative_labels == negative_labels_ignore_index, -1e9)
+
+    if negative_labels.dim() > 1:  # pragma: no cover
+        # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
+        negative_labels = negative_labels.unsqueeze(-2)
+
+    # [masked_batch_size, num_positives] -> [masked_batch_size, num_positives, 1]
+    positive_labels = positive_labels.unsqueeze(-1)
+    negative_mask = positive_labels == negative_labels  # [masked_batch_size, num_positives, num_negatives]
+
+    # [masked_batch_size, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
+    negative_mask = negative_mask.sum(-2).bool()
+    negative_logits.masked_fill_(negative_mask, -1e9)
+    return negative_logits
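Reviewer note: a toy trace of the collision masking, not part of the diff; the numbers are illustrative. The shared negative label 7 collides with row 0's positive label and is pushed to -1e9, so it cannot act as a false negative in the loss:

import torch
from replay.nn.loss.base import mask_negative_logits

negative_logits = torch.tensor([[0.5, 0.2], [0.4, 0.1]])  # [masked_batch_size=2, num_negatives=2]
negative_labels = torch.tensor([7, 9])      # same negatives for the whole batch
positive_labels = torch.tensor([[7], [8]])  # [masked_batch_size, num_positives=1]
out = mask_negative_logits(negative_logits, negative_labels, positive_labels, negative_labels_ignore_index=-1)
print(out)  # note: negative_logits is also modified in place
# -> [[-1e+09, 0.2],
#     [   0.4, 0.1]]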