replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/lightning/module.py ADDED
@@ -0,0 +1,123 @@
+ import inspect
+ from typing import Any, Optional, Union
+
+ import lightning
+ import torch
+ from typing_extensions import override
+
+ from replay.nn.lightning.optimizer import BaseOptimizerFactory, OptimizerFactory
+ from replay.nn.lightning.scheduler import BaseLRSchedulerFactory
+ from replay.nn.output import InferenceOutput, TrainOutput
+
+
+ class LightningModule(lightning.LightningModule):
+     """
+     A universal wrapper class over a PyTorch model for working with the Lightning library.\n
+     Pay attention to the format of the ``forward`` function's return value.
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         optimizer_factory: Optional[BaseOptimizerFactory] = None,
+         lr_scheduler_factory: Optional[BaseLRSchedulerFactory] = None,
+     ) -> None:
+         """
+         :param model: Initialized model.\n
+             The model's ``forward`` function is expected to return
+             a ``TrainOutput`` object at the training stage
+             and an ``InferenceOutput`` object at the inference stage.
+         :param optimizer_factory: Optimizer factory.
+             Default: ``None``.
+         :param lr_scheduler_factory: Learning rate scheduler factory.
+             Default: ``None``.
+         """
+         super().__init__()
+         self.save_hyperparameters(ignore=["model"])
+         self.model = model
+
+         self._optimizer_factory = optimizer_factory
+         self._lr_scheduler_factory = lr_scheduler_factory
+         self.candidates_to_score = None
+
+     def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]:
+         """
+         Implementation of the forward function.
+
+         :param batch: A dictionary containing all the information needed to run the model's forward function.
+             The dictionary keys must match the names of the arguments of the model's forward function.
+             Keys that do not match any argument of the model's forward function are filtered out.
+             If the model supports computing logits for custom candidates at the inference stage,
+             they can be passed inside the batch or via the ``candidates_to_score`` property.
+         :returns: During training, an object of the ``TrainOutput``
+             container class or its subclass is returned.
+             At the inference stage, an ``InferenceOutput`` object or its subclass is returned.
+         """
+         if "candidates_to_score" not in batch and self.candidates_to_score is not None and not self.training:
+             batch["candidates_to_score"] = self.candidates_to_score
+         # select only the arguments accepted by model.forward
+         modified_batch = {k: v for k, v in batch.items() if k in inspect.signature(self.model.forward).parameters}
+         return self.model(**modified_batch)
+
+     def training_step(self, batch: dict) -> torch.Tensor:
+         model_output: TrainOutput = self(batch)
+         loss = model_output["loss"]
+         lr = self.optimizers().param_groups[0]["lr"]  # get current learning rate
+         self.log("learning_rate", lr, on_step=True, on_epoch=True, prog_bar=True)
+         self.log(
+             "train_loss",
+             loss,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             sync_dist=True,
+         )
+         return loss
+
+     @override
+     def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> InferenceOutput:
+         model_output: InferenceOutput = self(batch)
+         return model_output
+
+     @override
+     def test_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> InferenceOutput:
+         model_output: InferenceOutput = self(batch)
+         return model_output
+
+     @override
+     def validation_step(self, batch: dict, batch_idx: int, dataloader_idx: int = 0) -> InferenceOutput:
+         model_output: InferenceOutput = self(batch)
+         return model_output
+
+     def configure_optimizers(self) -> Any:
+         """
+         :returns: The configured optimizer alone if no LR scheduler factory was provided,
+             otherwise a pair of lists
+             ``([optimizer], [lr_scheduler])``.
+         """
+         optimizer_factory = self._optimizer_factory or OptimizerFactory()
+         optimizer = optimizer_factory.create(self.model.parameters())
+
+         if self._lr_scheduler_factory is None:
+             return optimizer
+
+         lr_scheduler = self._lr_scheduler_factory.create(optimizer)
+         return [optimizer], [lr_scheduler]
+
+     @property
+     def candidates_to_score(self) -> Optional[torch.LongTensor]:
+         """
+         :getter: Returns a tensor containing the candidate IDs.
+             The tensor is used during the inference stage of the model.\n
+             If the property was not previously set, ``None`` is returned.
+         :setter: Expects a one-dimensional tensor containing unique candidate IDs.
+         """
+         return self._candidates_to_score
+
+     @candidates_to_score.setter
+     def candidates_to_score(self, candidates: Optional[torch.LongTensor] = None) -> None:
+         if (candidates is not None) and bool(candidates.unique().numel() != candidates.numel()):
+             msg = "The tensor of candidates to score must be unique."
+             raise ValueError(msg)
+
+         self._candidates_to_score = candidates
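The wrapper only forwards batch keys that match the wrapped model's ``forward`` signature, so it can be paired with any ``torch.nn.Module`` that returns ``TrainOutput``/``InferenceOutput``. A minimal usage sketch (``MyModel`` is a hypothetical module; the factory classes come from the files added below):

import torch
from replay.nn.lightning.module import LightningModule
from replay.nn.lightning.optimizer import OptimizerFactory
from replay.nn.lightning.scheduler import LRSchedulerFactory

# MyModel is a hypothetical torch.nn.Module whose forward returns
# TrainOutput during training and InferenceOutput during inference.
module = LightningModule(
    model=MyModel(),
    optimizer_factory=OptimizerFactory(optimizer="adam", learning_rate=1e-3),
    lr_scheduler_factory=LRSchedulerFactory(decay_step=10, gamma=0.5),
)
# Restrict inference-time scoring to a fixed, duplicate-free candidate set:
module.candidates_to_score = torch.tensor([3, 17, 42])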
replay/nn/lightning/optimizer.py ADDED
@@ -0,0 +1,60 @@
+ import abc
+ from collections.abc import Iterator
+ from typing import Literal
+
+ import torch
+
+
+ class BaseOptimizerFactory(abc.ABC):
+     """
+     Interface for optimizer factories.
+     """
+
+     @abc.abstractmethod
+     def create(self, parameters: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:  # pragma: no cover
+         """
+         Creates an optimizer for the given parameters.
+
+         :param parameters: torch parameters to initialize the optimizer with
+
+         :returns: torch optimizer
+         """
+
+
+ class OptimizerFactory(BaseOptimizerFactory):
+     """
+     Factory that creates an Adam or SGD optimizer, depending on the passed parameters.
+     """
+
+     def __init__(
+         self,
+         optimizer: Literal["adam", "sgd"] = "adam",
+         learning_rate: float = 0.001,
+         weight_decay: float = 0.0,
+         sgd_momentum: float = 0.0,
+         betas: tuple[float, float] = (0.9, 0.98),
+     ) -> None:
+         super().__init__()
+         self.optimizer = optimizer
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+         self.sgd_momentum = sgd_momentum
+         self.betas = betas
+
+     def create(self, parameters: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
+         """
+         Creates an optimizer for the given parameters.
+
+         :param parameters: torch parameters to initialize the optimizer with
+
+         :returns: torch optimizer
+         """
+         if self.optimizer == "adam":
+             return torch.optim.Adam(parameters, lr=self.learning_rate, weight_decay=self.weight_decay, betas=self.betas)
+         if self.optimizer == "sgd":
+             return torch.optim.SGD(
+                 parameters, lr=self.learning_rate, weight_decay=self.weight_decay, momentum=self.sgd_momentum
+             )
+
+         msg = f"Unexpected optimizer: {self.optimizer}"
+         raise ValueError(msg)
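Other optimizers can be plugged in by implementing ``BaseOptimizerFactory``. A minimal sketch of a hypothetical AdamW factory (``AdamWFactory`` is not part of the package):

from collections.abc import Iterator

import torch

from replay.nn.lightning.optimizer import BaseOptimizerFactory


class AdamWFactory(BaseOptimizerFactory):
    """Hypothetical factory producing torch.optim.AdamW optimizers."""

    def __init__(self, learning_rate: float = 1e-3, weight_decay: float = 0.01) -> None:
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def create(self, parameters: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        # AdamW decouples weight decay from the gradient update
        return torch.optim.AdamW(parameters, lr=self.learning_rate, weight_decay=self.weight_decay)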
replay/nn/lightning/postprocessor/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from ._base import PostprocessorBase
+ from .seen_items import SeenItemsFilter
replay/nn/lightning/postprocessor/_base.py ADDED
@@ -0,0 +1,51 @@
+ import abc
+ from typing import Optional, Union
+
+ import torch
+
+
+ class PostprocessorBase(abc.ABC):  # pragma: no cover
+     """
+     Abstract base class for postprocessors.
+     """
+
+     @abc.abstractmethod
+     def on_validation(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+         """
+         The method is called externally inside the callback at the validation stage.
+
+         :param batch: the batch sent to the model from the dataloader
+         :param logits: logits from the model
+
+         :returns: modified logits
+         """
+
+     @abc.abstractmethod
+     def on_prediction(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+         """
+         The method is called externally inside the callback at the prediction (inference) stage.
+
+         :param batch: the batch sent to the model from the dataloader
+         :param logits: logits from the model
+
+         :returns: modified logits
+         """
+
+     @property
+     def candidates(self) -> Union[torch.LongTensor, None]:
+         """
+         Returns the tensor of item IDs for which scores are calculated.
+         """
+         return self._candidates
+
+     @candidates.setter
+     def candidates(self, candidates: Optional[torch.LongTensor] = None) -> None:
+         """
+         Sets the tensor of item IDs for which scores are calculated.
+         :param candidates: Tensor of item IDs for which scores are calculated.
+         """
+         if (candidates is not None) and bool(candidates.unique().numel() != candidates.numel()):
+             msg = "The tensor of candidates to score must be unique."
+             raise ValueError(msg)
+
+         self._candidates = candidates
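Any logits transformation can be wired into the validation and prediction callbacks by implementing both hooks. A minimal sketch of a hypothetical temperature-scaling postprocessor (``TemperatureScaler`` is not part of the package):

import torch

from replay.nn.lightning.postprocessor import PostprocessorBase


class TemperatureScaler(PostprocessorBase):
    """Hypothetical postprocessor that softens logits by a constant temperature."""

    def __init__(self, temperature: float = 2.0) -> None:
        super().__init__()
        self.temperature = temperature
        self._candidates = None  # backing field for the base class `candidates` property

    def on_validation(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.temperature

    def on_prediction(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.temperature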
replay/nn/lightning/postprocessor/seen_items.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+
+ from replay.data.nn import TensorMap
+
+ from ._base import PostprocessorBase
+
+
+ class SeenItemsFilter(PostprocessorBase):
+     """
+     Masks (sets the logit value to ``-inf``) the items that have already been seen in the given dataset
+     (i.e. in the sequence of items for which the logits are calculated).\n
+     Should be used in Lightning callbacks for inference or metrics computation.
+
+     .. rubric:: Input example:
+
+     logits [B=2 users, I=3 items]::
+
+         logits =
+             [[0.1, 0.2, 0.3],  # user0
+              [-0.1, -0.2, -0.3]]  # user1
+
+     Seen items per user::
+
+         seen_items =
+             user0: [1, 0]
+             user1: [1, 2, 1]
+
+     .. rubric:: Output example:
+
+     SeenItemsFilter sets the logits of seen items to ``-inf``::
+
+         processed_logits =
+             [[   -inf,   -inf, 0.3000],  # user0
+              [-0.1000,   -inf,   -inf]]  # user1
+
+     """
+
+     def __init__(self, item_count: int, seen_items_column: str = "seen_ids") -> None:
+         """
+         :param item_count: Total number of items the model knows about (``cardinality``).
+             It is recommended to take this value from ``TensorSchema``.\n
+             Please note that values outside the range [0, ``item_count - 1``] are filtered out (treated as padding).
+         :param seen_items_column: Name of the column in the batch that contains the users' interactions (seen item IDs).
+         """
+         super().__init__()
+         self.item_count = item_count
+         self.seen_items_column = seen_items_column
+         self._candidates = None
+
+     def on_validation(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+         return self._compute_scores(batch, logits.detach().clone())
+
+     def on_prediction(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
+         return self._compute_scores(batch, logits.detach().clone())
+
+     def _compute_scores(
+         self,
+         batch: TensorMap,
+         logits: torch.Tensor,
+     ) -> torch.Tensor:
+         seen_ids_padded = batch[self.seen_items_column]
+         padding_mask = (seen_ids_padded < self.item_count) & (seen_ids_padded >= 0)
+
+         batch_factors = torch.arange(0, logits.size(0), device=logits.device) * self.item_count
+         factored_ids = seen_ids_padded + batch_factors.unsqueeze(1)
+         seen_ids_flat = factored_ids[padding_mask]
+
+         if self._candidates is not None:
+             _logits = torch.full((logits.size(0), self.item_count), -torch.inf, device=logits.device)
+             _logits[:, self._candidates] = torch.reshape(logits, _logits[:, self._candidates].shape)
+             logits = _logits
+
+         if logits.is_contiguous():
+             logits.view(-1)[seen_ids_flat] = -torch.inf
+         else:
+             flat_scores = logits.flatten()
+             flat_scores[seen_ids_flat] = -torch.inf
+             logits = flat_scores.reshape(logits.shape)
+
+         if self._candidates is not None:
+             logits = logits[:, self._candidates]
+
+         return logits
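The docstring example can be reproduced directly. The seen-ID tensor is rectangular, so user0's shorter sequence is padded with an out-of-range value (here ``-1``), which the padding mask filters out. A minimal sketch:

import torch

from replay.nn.lightning.postprocessor import SeenItemsFilter

flt = SeenItemsFilter(item_count=3, seen_items_column="seen_ids")
batch = {"seen_ids": torch.tensor([[1, 0, -1], [1, 2, 1]])}  # -1 pads user0
logits = torch.tensor([[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]])
print(flt.on_prediction(batch, logits))
# tensor([[   -inf,    -inf, 0.3000],
#         [-0.1000,    -inf,    -inf]])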
replay/nn/lightning/scheduler.py ADDED
@@ -0,0 +1,91 @@
+ import abc
+ import warnings
+ from typing import Literal
+
+ import torch
+
+
+ class BaseLRSchedulerFactory(abc.ABC):
+     """
+     Interface for learning rate scheduler factories.
+     """
+
+     @abc.abstractmethod
+     def create(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:  # pragma: no cover
+         """
+         Creates a learning rate scheduler for the given optimizer.
+
+         :param optimizer: torch optimizer
+
+         :returns: torch LRScheduler
+         """
+
+
+ class LRSchedulerFactory(BaseLRSchedulerFactory):
+     """
+     Factory that creates a ``StepLR`` scheduler decaying the learning rate by ``gamma`` every ``decay_step`` epochs.
+     """
+
+     def __init__(self, decay_step: int = 25, gamma: float = 1.0) -> None:
+         super().__init__()
+         self.decay_step = decay_step
+         self.gamma = gamma
+
+     def create(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+         """
+         Creates a learning rate scheduler for the given optimizer.
+
+         :param optimizer: torch optimizer
+
+         :returns: torch LRScheduler
+         """
+         return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.decay_step, gamma=self.gamma)
+
+
+ class LambdaLRSchedulerFactory(BaseLRSchedulerFactory):
+     """
+     Factory that creates a ``LambdaLR`` warmup schedule: ``warmup_lr`` scales the base LR during the first ``warmup_steps`` intervals, ``normal_lr`` afterwards.
+     """
+
+     def __init__(
+         self,
+         warmup_steps: int,
+         warmup_lr: float = 1.0,
+         normal_lr: float = 0.1,
+         update_interval: Literal["epoch", "step"] = "epoch",
+     ) -> None:
+         super().__init__()
+
+         if normal_lr <= 0.0:
+             msg = f"Normal LR must be positive. Got {normal_lr}"
+             raise ValueError(msg)
+         if warmup_lr <= 0.0:
+             msg = f"Warmup LR must be positive. Got {warmup_lr}"
+             raise ValueError(msg)
+         if normal_lr >= warmup_lr:
+             msg = f"Suspicious LR pair: {normal_lr=}, {warmup_lr=}"
+             warnings.warn(msg, stacklevel=2)
+
+         self.warmup_lr = warmup_lr
+         self.normal_lr = normal_lr
+         self.warmup_steps = warmup_steps
+         self.update_interval = update_interval
+
+     def create(self, optimizer: torch.optim.Optimizer) -> dict:
+         """
+         Creates a learning rate scheduler for the given optimizer.
+
+         :param optimizer: torch optimizer
+
+         :returns: Lightning scheduler configuration dict wrapping a torch ``LambdaLR``
+         """
+         return {
+             "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, self.lr_lambda),
+             "interval": self.update_interval,
+             "frequency": 1,
+         }
+
+     def lr_lambda(self, current_step: int) -> float:
+         if current_step >= self.warmup_steps:
+             return self.normal_lr
+         return self.warmup_lr
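Note that ``LambdaLR`` multiplies the optimizer's base learning rate by ``lr_lambda(step)``, so ``warmup_lr`` and ``normal_lr`` act as multipliers rather than absolute rates. A minimal sketch, assuming a base LR of 0.001:

import torch

from replay.nn.lightning.scheduler import LambdaLRSchedulerFactory

factory = LambdaLRSchedulerFactory(warmup_steps=2, warmup_lr=1.0, normal_lr=0.1)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.001)
scheduler = factory.create(optimizer)["scheduler"]  # create() returns a Lightning config dict

for step in range(4):
    print(step, scheduler.get_last_lr())  # 0.001 while step < 2, then 0.0001
    optimizer.step()
    scheduler.step()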
replay/nn/loss/__init__.py ADDED
@@ -0,0 +1,22 @@
+ from .base import LossProto
+ from .bce import BCE, BCESampled
+ from .ce import CE, CESampled, CESampledWeighted, CEWeighted
+ from .login_ce import LogInCE, LogInCESampled
+ from .logout_ce import LogOutCE, LogOutCEWeighted
+
+ LogOutCESampled = CE
+
+ __all__ = [
+     "BCE",
+     "CE",
+     "BCESampled",
+     "CESampled",
+     "CESampledWeighted",
+     "CEWeighted",
+     "LogInCE",
+     "LogInCESampled",
+     "LogOutCE",
+     "LogOutCESampled",
+     "LogOutCEWeighted",
+     "LossProto",
+ ]
replay/nn/loss/base.py ADDED
@@ -0,0 +1,197 @@
+ from typing import Callable, Optional, Protocol, TypedDict
+
+ import torch
+
+ from replay.data.nn import TensorMap
+
+
+ class LossProto(Protocol):
+     """Class-protocol for working with losses inside models"""
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: ...
+
+     @logits_callback.setter
+     def logits_callback(self, func: Optional[Callable]) -> None: ...
+
+     def forward(
+         self,
+         model_embeddings: torch.Tensor,
+         feature_tensors: TensorMap,
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,
+         padding_mask: torch.BoolTensor,
+         target_padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor: ...
+
+
+ class SampledLossOutput(TypedDict):
+     """A class containing the result of the ``get_sampled_logits`` function in sampled losses"""
+
+     positive_logits: torch.Tensor
+     negative_logits: torch.Tensor
+     positive_labels: torch.LongTensor
+     negative_labels: torch.LongTensor
+
+
+ class SampledLossBase(torch.nn.Module):
+     """The base class for calculating sampled losses"""
+
+     @property
+     def logits_callback(
+         self,
+     ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+         raise NotImplementedError()  # pragma: no cover
+
+     def get_sampled_logits(
+         self,
+         model_embeddings: torch.Tensor,
+         positive_labels: torch.LongTensor,  # [batch_size, seq_len, num_positives]
+         negative_labels: torch.LongTensor,  # [num_negatives] or [batch_size, seq_len, num_negatives]
+         target_padding_mask: torch.BoolTensor,  # [batch_size, seq_len, num_positives]
+     ) -> SampledLossOutput:
+         """
+         Computes positive and negative logits
+         based on the model's last hidden state and the positive and negative labels.
+
+         The function supports computing logits for the multi-positive case
+         (several labels for each position in the sequence).
+
+         :param model_embeddings: Embeddings from the model. This is usually the last hidden state.
+             Expected shape: ``(batch_size, sequence_length, embedding_dim)``
+         :param positive_labels: A tensor containing labels of positive events.
+             Expected shape: ``(batch_size, sequence_length, num_positives)``
+         :param negative_labels: A tensor containing labels of negative events.
+             Expected shape:
+             - ``(batch_size, sequence_length, num_negatives)``
+             - ``(batch_size, num_negatives)``
+             - ``(num_negatives)`` - a case where the same negative events are used for the entire batch.
+         :param target_padding_mask: Padding mask for ``positive_labels`` (targets).
+             A ``False`` value indicates that the corresponding target position will be ignored.
+             Expected shape: ``(batch_size, sequence_length, num_positives)``
+
+         :returns: SampledLossOutput. A dictionary containing positive and negative logits with labels.
+         """
+         initial_positive_labels = positive_labels
+         ################## SHAPE CHECKING STAGE START ##################
+         batch_size, seq_len, num_positives = positive_labels.size()
+         assert target_padding_mask.size() == (batch_size, seq_len, num_positives)
+         num_negatives = negative_labels.size(-1)
+
+         if negative_labels.size() == (batch_size, num_negatives):
+             # [batch_size, num_negatives] -> [batch_size, seq_len, num_negatives]
+             negative_labels = negative_labels.unsqueeze(1).repeat(1, seq_len, 1)
+
+         if negative_labels.dim() == 3:  # pragma: no cover
+             # [batch_size, seq_len, num_negatives] -> [batch_size, seq_len, 1, num_negatives]
+             negative_labels = negative_labels.unsqueeze(-2)
+             if num_positives != 1:
+                 # [batch_size, seq_len, 1, num_negatives] -> [batch_size, seq_len, num_positives, num_negatives]
+                 negative_labels = negative_labels.repeat((1, 1, num_positives, 1))
+         assert (
+             negative_labels.size() == (batch_size, seq_len, num_positives, num_negatives) or negative_labels.dim() == 1
+         )
+         ################## SHAPE CHECKING STAGE END ##################
+
+         # Get output embedding for every user event
+         embedding_dim = model_embeddings.size(-1)
+         assert model_embeddings.size() == (batch_size, seq_len, embedding_dim)
+
+         # [batch_size, seq_len, emb_dim] -> [batch_size, seq_len, 1, emb_dim]
+         model_embeddings = model_embeddings.unsqueeze(-2)
+         if num_positives != 1:  # multi-positive branch
+             model_embeddings = model_embeddings.repeat((1, 1, num_positives, 1))
+         assert model_embeddings.size() == (
+             batch_size,
+             seq_len,
+             num_positives,
+             embedding_dim,
+         )
+
+         # Apply target mask
+         # [batch_size, seq_len, num_positives] -> scalar count of unmasked targets
+         masked_batch_size = target_padding_mask.sum().item()
+
+         # [batch_size, seq_len, num_positives] -> [masked_batch_size, 1]
+         positive_labels = positive_labels[target_padding_mask].unsqueeze(-1)
+         assert positive_labels.size() == (masked_batch_size, 1)
+
+         if negative_labels.dim() != 1:  # pragma: no cover
+             # [batch_size, seq_len, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
+             negative_labels = negative_labels[target_padding_mask]
+             assert negative_labels.size() == (masked_batch_size, num_negatives)
+
+         # [batch_size, seq_len, num_positives, emb_dim] -> [masked_batch_size, emb_dim]
+         model_embeddings = model_embeddings[target_padding_mask]
+         assert model_embeddings.size() == (masked_batch_size, embedding_dim)
+
+         # Get positive and negative logits
+         positive_logits = self.logits_callback(model_embeddings, positive_labels)
+         assert positive_logits.size() == (masked_batch_size, 1)
+
+         negative_logits = self.logits_callback(model_embeddings, negative_labels)
+         assert negative_logits.size() == (masked_batch_size, num_negatives)
+
+         if num_positives != 1:
+             # [batch_size, seq_len, num_positives] -> [batch_size * seq_len]
+             masked_target_padding_mask = target_padding_mask.sum(-1).view(-1)
+             # [batch_size, seq_len, num_positives] -> [masked_batch_size, num_positives]
+             positive_labels = torch.repeat_interleave(
+                 initial_positive_labels.view(-1, num_positives),
+                 masked_target_padding_mask,
+                 dim=0,
+             )
+
+         return {
+             "positive_logits": positive_logits,
+             "negative_logits": negative_logits,
+             "positive_labels": positive_labels,
+             "negative_labels": negative_labels,
+         }
+
+
+ def mask_negative_logits(
+     negative_logits: torch.Tensor,
+     negative_labels: torch.LongTensor,
+     positive_labels: torch.LongTensor,
+     negative_labels_ignore_index: int,
+ ) -> torch.Tensor:
+     """
+     Assigns very small values to negative logits
+     at positions where a negative label coincides with a positive one.
+
+     :param negative_logits: Logits from the model for the ``negative_labels``.
+         Expected shape: (masked_batch_size, num_negatives)
+     :param negative_labels: A tensor containing labels of negative events.
+         Expected shape:
+         - (masked_batch_size, num_negatives)
+         - (num_negatives) - a case where the same negative events are used for the entire batch
+     :param positive_labels: A tensor containing labels of positive events.
+         Expected shape: (masked_batch_size, num_positives)
+     :param negative_labels_ignore_index: Padding value for negative labels.
+         This may be the case when negative labels
+         are formed at the preprocessing level rather than by the negative sampler.
+         The index is ignored and does not contribute to the loss.
+
+     :returns: Negative logits with modified elements at those positions
+         where positive labels are equal to negative ones.
+     """
+     if negative_labels_ignore_index >= 0:
+         negative_logits.masked_fill_(negative_labels == negative_labels_ignore_index, -1e9)
+
+     if negative_labels.dim() > 1:  # pragma: no cover
+         # [masked_batch_size, num_negatives] -> [masked_batch_size, 1, num_negatives]
+         negative_labels = negative_labels.unsqueeze(-2)
+
+     # [masked_batch_size, num_positives] -> [masked_batch_size, num_positives, 1]
+     positive_labels = positive_labels.unsqueeze(-1)
+     negative_mask = positive_labels == negative_labels  # [masked_batch_size, num_positives, num_negatives]
+
+     # [masked_batch_size, num_positives, num_negatives] -> [masked_batch_size, num_negatives]
+     negative_mask = negative_mask.sum(-2).bool()
+     negative_logits.masked_fill_(negative_mask, -1e9)
+     return negative_logits
+ return negative_logits