replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/normalization.py ADDED
@@ -0,0 +1,9 @@
+ from typing import Protocol
+
+ import torch
+
+
+ class NormalizerProto(Protocol):
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor: ...
+
+     def reset_parameters(self) -> None: ...
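
``NormalizerProto`` is a structural ``typing.Protocol``: any module that exposes ``forward(inputs)`` and ``reset_parameters()`` satisfies it, with no inheritance required. A minimal sketch of how it can be used; the ``IdentityNormalizer`` below is hypothetical and only illustrates that built-ins such as ``torch.nn.LayerNorm`` already conform:

    import torch

    from replay.nn.normalization import NormalizerProto


    class IdentityNormalizer(torch.nn.Module):
        # Hypothetical no-op normalizer; it still satisfies NormalizerProto structurally.
        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            return inputs

        def reset_parameters(self) -> None:
            pass  # nothing to reset


    # torch.nn.LayerNorm also provides forward() and reset_parameters(), so both type-check.
    normalizers: list[NormalizerProto] = [torch.nn.LayerNorm(64), IdentityNormalizer()]
    for norm in normalizers:
        out = norm(torch.randn(2, 10, 64))  # shape is preserved: (batch, seq_len, dim)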
replay/nn/output.py ADDED
@@ -0,0 +1,37 @@
+ from typing import TypedDict
+
+ import torch
+ from typing_extensions import NotRequired
+
+
+ class TrainOutput(TypedDict):
+     """
+     Stores the outputs of the model's training stage.
+
+     :param loss: a tensor containing the calculated loss.\n
+         The tensor must carry a gradient so that backpropagation can be called from the outside.
+     :param hidden_states: Tuple of `torch.Tensor`.\n
+         One for the output of the embeddings, if the model has an embedding layer, plus
+         one for the output of each layer.\n
+         Expected shape: ``(batch_size, sequence_length, hidden_size)``.
+     """
+
+     loss: torch.Tensor
+     hidden_states: NotRequired[tuple[torch.Tensor, ...]]
+
+
+ class InferenceOutput(TypedDict):
+     """
+     Stores the outputs of the model's inference stage.
+
+     :param logits:
+         Sequence of hidden-states at the output of the last layer of the model.\n
+         Expected shape: ``(batch_size, sequence_length, hidden_size)``.
+     :param hidden_states: Tuple of `torch.Tensor`
+         (one for the output of the embeddings, if the model has an embedding layer, plus
+         one for the output of each layer).\n
+         Expected shape: ``(batch_size, sequence_length, hidden_size)``.
+     """
+
+     logits: torch.Tensor
+     hidden_states: NotRequired[tuple[torch.Tensor, ...]]
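
Both containers are ``TypedDict``s, so they are plain dictionaries at runtime and ``hidden_states`` may be omitted (``NotRequired``). A minimal usage sketch, assuming nothing beyond the definitions above (the tensor shapes are illustrative):

    import torch

    from replay.nn.output import InferenceOutput, TrainOutput

    # hidden_states is NotRequired, so a bare {"loss": ...} is a valid TrainOutput.
    train_out: TrainOutput = {"loss": torch.tensor(0.35, requires_grad=True)}
    train_out["loss"].backward()  # the loss tensor carries the gradient for backpropagation

    infer_out: InferenceOutput = {
        "logits": torch.randn(8, 100),                # e.g. (batch_size, num_scored_items)
        "hidden_states": (torch.randn(8, 50, 192),),  # (batch_size, sequence_length, hidden_size)
    }
    top_items = infer_out["logits"].topk(k=10, dim=-1).indices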
replay/nn/sequential/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .sasrec import (
+     DiffTransformerBlock,
+     DiffTransformerLayer,
+     PositionAwareAggregator,
+     SasRec,
+     SasRecBody,
+     SasRecTransformerLayer,
+ )
+ from .twotower import ItemTower, QueryTower, TwoTower, TwoTowerBody
replay/nn/sequential/sasrec/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .agg import PositionAwareAggregator
+ from .diff_transformer import (
+     DiffTransformerBlock,
+     DiffTransformerLayer,
+ )
+ from .model import SasRec, SasRecBody
+ from .transformer import SasRecTransformerLayer
replay/nn/sequential/sasrec/agg.py ADDED
@@ -0,0 +1,53 @@
+ import contextlib
+
+ import torch
+
+ from replay.data.nn.schema import TensorMap
+ from replay.nn.agg import AggregatorProto
+
+
+ class PositionAwareAggregator(torch.nn.Module):
+     """
+     The layer for aggregating embeddings and adding positional encoding.
+     """
+
+     def __init__(
+         self,
+         embedding_aggregator: AggregatorProto,
+         max_sequence_length: int,
+         dropout: float,
+     ) -> None:
+         """
+         :param embedding_aggregator: An object of a class that performs the logic of aggregating multiple embeddings.\n
+             For example, it can be a ``sum``, a ``mean``, or a ``concatenation``.
+         :param max_sequence_length: Maximum sequence length.
+         :param dropout: Probability of an element to be zeroed.
+         """
+         super().__init__()
+         self.embedding_aggregator = embedding_aggregator
+         self.pe = torch.nn.Embedding(max_sequence_length, self.embedding_aggregator.embedding_dim)
+         self.dropout = torch.nn.Dropout(p=dropout)
+
+     def reset_parameters(self) -> None:
+         self.embedding_aggregator.reset_parameters()
+         for _, param in self.pe.named_parameters():
+             with contextlib.suppress(ValueError):
+                 torch.nn.init.xavier_normal_(param.data)
+
+     def forward(self, feature_tensors: TensorMap) -> torch.Tensor:
+         """
+         :param feature_tensors: a dictionary of tensors to pass into ``embedding_aggregator``.
+
+         :returns: Aggregated embeddings with positional encoding.
+         """
+         seqs: torch.Tensor = self.embedding_aggregator(feature_tensors)
+         assert seqs.dim() == 3
+         batch_size, seq_len, embedding_dim = seqs.size()
+         assert (
+             seq_len <= self.pe.num_embeddings
+         ), f"Sequence length = {seq_len} is greater than positional embedding num = {self.pe.num_embeddings}"
+
+         seqs *= embedding_dim**0.5
+         seqs += self.pe.weight[:seq_len].unsqueeze(0).repeat(batch_size, 1, 1)
+         seqs = self.dropout(seqs)
+         return seqs
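
``PositionAwareAggregator`` accepts any object satisfying ``AggregatorProto``: it reads ``embedding_dim``, calls the object on a map of per-feature embeddings, and delegates ``reset_parameters``. A minimal sketch with a hypothetical stand-in aggregator (``SumStub`` is illustrative only; in practice the package's ``replay.nn.agg.SumAggregator`` plays this role, as in the ``SasRec`` example further below):

    import torch

    from replay.nn.sequential import PositionAwareAggregator


    class SumStub(torch.nn.Module):
        # Hypothetical aggregator: sums the per-feature embedding tensors element-wise.
        def __init__(self, embedding_dim: int) -> None:
            super().__init__()
            self.embedding_dim = embedding_dim  # read by PositionAwareAggregator for its positional table

        def forward(self, feature_tensors):
            return torch.stack(list(feature_tensors.values()), dim=0).sum(dim=0)

        def reset_parameters(self) -> None:
            pass


    agg = PositionAwareAggregator(embedding_aggregator=SumStub(64), max_sequence_length=100, dropout=0.1)
    embeddings = {"item_id": torch.randn(4, 20, 64), "category": torch.randn(4, 20, 64)}
    out = agg(embeddings)  # scaled by sqrt(64), positional encodings added, dropout applied
    assert out.shape == (4, 20, 64)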
replay/nn/sequential/sasrec/diff_transformer.py ADDED
@@ -0,0 +1,125 @@
+ import math
+
+ import torch
+
+ from replay.data.nn import TensorMap
+ from replay.nn.attention import MultiHeadDifferentialAttention
+ from replay.nn.ffn import SwiGLU
+
+
+ class DiffTransformerBlock(torch.nn.Module):
+     """
+     Single Block of the DiffTransformer Architecture.
+     Consists of Multi-Head Differential Attention followed by a SwiGLU Feed-Forward Network.
+
+     Source paper: https://arxiv.org/pdf/2410.05258
+     """
+
+     def __init__(self, embedding_dim: int, num_heads: int, lambda_init: float):
+         """
+         :param embedding_dim: Total dimension of the model. Must be divisible by ``num_heads``.
+         :param num_heads: Number of parallel attention heads.
+         :param lambda_init: Initial value for lambda.
+         """
+         super().__init__()
+         self.attn_norm = torch.nn.RMSNorm(embedding_dim)
+         self.attn = MultiHeadDifferentialAttention(embedding_dim, num_heads, lambda_init, vdim=2 * embedding_dim)
+         self.ff_norm = torch.nn.RMSNorm(embedding_dim)
+         self.ff = SwiGLU(embedding_dim, 2 * embedding_dim)
+
+     def reset_parameters(self) -> None:
+         self.attn_norm.reset_parameters()
+         self.attn.reset_parameters()
+         self.ff_norm.reset_parameters()
+         self.ff.reset_parameters()
+
+     def forward(
+         self,
+         input_embeddings: torch.Tensor,
+         attention_mask: torch.FloatTensor,
+     ) -> torch.Tensor:
+         """
+         Forward pass for a single differential transformer block.
+
+         :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param attention_mask: Causal-like mask for attention pattern, where ``-inf`` for ``PAD``, ``0`` - otherwise.\n
+             Possible shapes:\n
+             1. ``(batch_size * num_heads, sequence_length, sequence_length)``
+             2. ``(batch_size, num_heads, sequence_length, sequence_length)``
+         :returns: Output tensor after processing through the block.
+         """
+         # Apply Multi-Head Differential Attention with residual connection
+         attent_emb = self.attn(
+             input_embeddings,
+             input_embeddings,
+             input_embeddings,
+             attention_mask,
+         )
+         attention_block_out = self.attn_norm(attent_emb + input_embeddings)
+
+         # Apply SwiGLU Feed-Forward Network with residual connection
+         ff_out = self.ff(input_embeddings=attention_block_out)
+         feedforward_block_out = self.ff_norm(ff_out + attention_block_out)
+         return feedforward_block_out
+
+
+ class DiffTransformerLayer(torch.nn.Module):
+     """
+     Stacked blocks of the DiffTransformer Architecture.
+     Single block consists of Multi-Head Differential Attention followed by a SwiGLU Feed-Forward Network.
+
+     Source paper: https://arxiv.org/pdf/2410.05258\n
+     Reference: https://github.com/nanowell/Differential-Transformer-PyTorch/blob/main/DiffTransformer.py
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         num_blocks: int,
+     ) -> None:
+         """
+         :param embedding_dim: Total dimension of the model. Must be divisible by ``num_heads``.
+         :param num_heads: Number of parallel attention heads.
+         :param num_blocks: Number of Transformer blocks.
+         """
+         super().__init__()
+         self.layers = torch.nn.ModuleList(
+             [
+                 DiffTransformerBlock(
+                     embedding_dim=embedding_dim,
+                     num_heads=num_heads,
+                     lambda_init=0.8 - 0.6 * math.exp(-0.3 * block_num),
+                 )
+                 for block_num in range(num_blocks)
+             ]
+         )
+
+     def reset_parameters(self) -> None:
+         for layer in self.layers:
+             layer.reset_parameters()
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         input_embeddings: torch.Tensor,
+         padding_mask: torch.BoolTensor,  # noqa: ARG002
+         attention_mask: torch.FloatTensor,
+     ) -> torch.Tensor:
+         """
+         forward(input_embeddings, attention_mask)
+
+         :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param attention_mask: Causal-like mask for attention pattern, where ``-inf`` for ``PAD``, ``0`` - otherwise.\n
+             Possible shapes:\n
+             1. ``(batch_size * num_heads, sequence_length, sequence_length)``
+             2. ``(batch_size, num_heads, sequence_length, sequence_length)``
+         :returns: Output tensor after processing through the layer.
+         """
+         seqs = input_embeddings
+         for layer in self.layers:
+             seqs = layer(
+                 input_embeddings=seqs,
+                 attention_mask=attention_mask,
+             )
+         return seqs
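
``DiffTransformerLayer`` keeps the same encoder call signature as ``SasRecTransformerLayer`` (``feature_tensors`` and ``padding_mask`` are accepted but unused), so it can be passed to ``SasRecBody`` as the ``encoder``. A rough sketch of a standalone forward pass under the mask convention stated in the docstring (``0`` for visible positions, ``-inf`` otherwise); whether it runs end to end also depends on ``MultiHeadDifferentialAttention`` and ``SwiGLU`` in ``replay/nn/attention.py`` and ``replay/nn/ffn.py``, which are not shown here:

    import torch

    from replay.nn.sequential import DiffTransformerLayer

    batch_size, seq_len, dim, heads = 4, 32, 64, 2
    layer = DiffTransformerLayer(embedding_dim=dim, num_heads=heads, num_blocks=2)

    # Additive causal mask: 0 on and below the diagonal, -inf above it,
    # broadcast to the (batch_size, num_heads, seq_len, seq_len) shape from the docstring.
    causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    attention_mask = causal.expand(batch_size, heads, seq_len, seq_len)

    hidden = layer(
        feature_tensors={},  # unused by this encoder
        input_embeddings=torch.randn(batch_size, seq_len, dim),
        padding_mask=torch.ones(batch_size, seq_len, dtype=torch.bool),  # unused as well
        attention_mask=attention_mask,
    )  # -> (batch_size, seq_len, dim)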
replay/nn/sequential/sasrec/model.py ADDED
@@ -0,0 +1,377 @@
+ from collections.abc import Sequence
+ from typing import Literal, Optional, Protocol, Union
+
+ import torch
+
+ from replay.data.nn import TensorMap, TensorSchema
+ from replay.nn.agg import AggregatorProto
+ from replay.nn.head import EmbeddingTyingHead
+ from replay.nn.loss import LossProto
+ from replay.nn.mask import AttentionMaskProto
+ from replay.nn.normalization import NormalizerProto
+ from replay.nn.output import InferenceOutput, TrainOutput
+ from replay.nn.utils import warning_is_not_none
+
+
+ class EmbedderProto(Protocol):
+     def get_item_weights(
+         self,
+         indices: Optional[torch.LongTensor],
+     ) -> torch.Tensor: ...
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         feature_names: Optional[Sequence[str]] = None,
+     ) -> TensorMap: ...
+
+     def reset_parameters(self) -> None: ...
+
+
+ class EncoderProto(Protocol):
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         input_embeddings: torch.Tensor,
+         padding_mask: torch.BoolTensor,
+         attention_mask: torch.Tensor,
+     ) -> torch.Tensor: ...
+
+     def reset_parameters(self) -> None: ...
+
+
+ class SasRecBody(torch.nn.Module):
+     """
+     Implementation of the architecture of the SasRec model.\n
+     It can include various custom blocks that modify the model,
+     but the order in which the layers are applied is fixed in accordance with the original architecture.
+     """
+
+     def __init__(
+         self,
+         embedder: EmbedderProto,
+         embedding_aggregator: AggregatorProto,
+         attn_mask_builder: AttentionMaskProto,
+         encoder: EncoderProto,
+         output_normalization: NormalizerProto,
+     ):
+         """
+         :param embedder: An object of a class that performs the logic of
+             generating embeddings from an input set of tensors.
+         :param embedding_aggregator: An object of a class that performs the logic of aggregating multiple embeddings.\n
+             For example, it can be a ``sum``, a ``mean``, or a ``concatenation``.
+         :param attn_mask_builder: An object of a class that performs the logic of
+             generating an attention mask based on the features and padding mask given to the model.
+         :param encoder: An object of a class that performs the logic of generating
+             a hidden embedding representation based on
+             features, padding masks, attention mask, and aggregated embedding.
+         :param output_normalization: An object of a class that performs the logic of
+             normalization of the hidden state obtained from the encoder.\n
+             For example, it may be a ``torch.nn.LayerNorm`` or ``torch.nn.RMSNorm``.
+         """
+         super().__init__()
+         self.embedder = embedder
+         self.attn_mask_builder = attn_mask_builder
+         self.embedding_aggregator = embedding_aggregator
+         self.encoder = encoder
+         self.output_normalization = output_normalization
+
+     def reset_parameters(self) -> None:
+         self.embedder.reset_parameters()
+         self.embedding_aggregator.reset_parameters()
+         self.encoder.reset_parameters()
+         self.output_normalization.reset_parameters()
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+     ) -> torch.Tensor:
+         """
+         :param feature_tensors: a dictionary of tensors to generate embeddings.
+         :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
+             indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding").
+             ``False`` value indicates that the corresponding ``key`` value will be ignored.
+         :returns: The final hidden state.\n
+             Expected shape: ``(batch_size, sequence_length, embedding_dim)``
+         """
+         embeddings = self.embedder(feature_tensors)
+         agg_emb: torch.Tensor = self.embedding_aggregator(embeddings)
+         assert agg_emb.dim() == 3
+
+         attn_mask = self.attn_mask_builder(feature_tensors, padding_mask)
+
+         hidden_state: torch.Tensor = self.encoder(
+             feature_tensors=feature_tensors,
+             input_embeddings=agg_emb,
+             padding_mask=padding_mask,
+             attention_mask=attn_mask,
+         )
+         assert agg_emb.size() == hidden_state.size()
+
+         hidden_state = self.output_normalization(hidden_state)
+         return hidden_state
+
+
+ class SasRec(torch.nn.Module):
+     """
+     A model using the SasRec architecture as a hidden state generator.
+     The hidden states are multiplied by the item embeddings,
+     resulting in logits for each of the items.
+
+     Source paper: https://arxiv.org/pdf/1808.09781.
+
+     Example:
+
+     .. code-block:: python
+
+         from replay.data import FeatureHint, FeatureSource, FeatureType
+         from replay.data.nn import TensorFeatureInfo, TensorFeatureSource, TensorSchema
+         from replay.nn.agg import SumAggregator
+         from replay.nn.embedding import SequenceEmbedding
+         from replay.nn.mask import DefaultAttentionMask
+         from replay.nn.loss import CESampled
+         from replay.nn.sequential import PositionAwareAggregator, SasRecTransformerLayer
+
+         tensor_schema = TensorSchema(
+             [
+                 TensorFeatureInfo(
+                     "item_id",
+                     is_seq=True,
+                     feature_type=FeatureType.CATEGORICAL,
+                     embedding_dim=256,
+                     padding_value=NUM_UNIQUE_ITEMS,
+                     cardinality=NUM_UNIQUE_ITEMS+1,
+                     feature_hint=FeatureHint.ITEM_ID,
+                     feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
+                 ),
+             ]
+         )
+
+         body = SasRecBody(
+             embedder=SequenceEmbedding(
+                 schema=tensor_schema,
+             ),
+             embedding_aggregator=PositionAwareAggregator(
+                 embedding_aggregator=SumAggregator(embedding_dim=256),
+                 max_sequence_length=100,
+                 dropout=0.2,
+             ),
+             attn_mask_builder=DefaultAttentionMask(
+                 reference_feature_name=tensor_schema.item_id_feature_name,
+                 num_heads=2,
+             ),
+             encoder=SasRecTransformerLayer(
+                 embedding_dim=256,
+                 num_heads=2,
+                 num_blocks=2,
+                 dropout=0.3,
+                 activation="relu",
+             ),
+             output_normalization=torch.nn.LayerNorm(256),
+         )
+         sasrec = SasRec(
+             body=body,
+             loss=CESampled(padding_idx=tensor_schema.item_id_features.item().padding_value)
+         )
+
+     """
+
+     def __init__(
+         self,
+         body: SasRecBody,
+         loss: LossProto,
+     ):
+         """
+         :param body: An instance of SasRecBody.
+         :param loss: An object of a class that performs loss calculation
+             based on hidden states from the model, positive and optionally negative labels.
+         """
+         super().__init__()
+         self.body = body
+         self.head = EmbeddingTyingHead()
+         self.loss = loss
+         self.loss.logits_callback = self.get_logits
+
+         self.reset_parameters()
+
+     @classmethod
+     def from_params(
+         cls,
+         schema: TensorSchema,
+         embedding_dim: int = 192,
+         num_heads: int = 4,
+         num_blocks: int = 2,
+         max_sequence_length: int = 50,
+         dropout: float = 0.3,
+         excluded_features: Optional[list[str]] = None,
+         categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
+     ) -> "SasRec":
+         from replay.nn.agg import SumAggregator
+         from replay.nn.embedding import SequenceEmbedding
+         from replay.nn.loss import CE
+         from replay.nn.mask import DefaultAttentionMask
+
+         from .agg import PositionAwareAggregator
+         from .transformer import SasRecTransformerLayer
+
+         excluded_features = [
+             schema.query_id_feature_name,
+             schema.timestamp_feature_name,
+             *(excluded_features or []),
+         ]
+         excluded_features = list(set(excluded_features))
+
+         body = SasRecBody(
+             embedder=SequenceEmbedding(
+                 schema=schema,
+                 categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
+                 excluded_features=excluded_features,
+             ),
+             embedding_aggregator=PositionAwareAggregator(
+                 embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),
+                 max_sequence_length=max_sequence_length,
+                 dropout=dropout,
+             ),
+             attn_mask_builder=DefaultAttentionMask(
+                 reference_feature_name=schema.item_id_feature_name,
+                 num_heads=num_heads,
+             ),
+             encoder=SasRecTransformerLayer(
+                 embedding_dim=embedding_dim,
+                 num_heads=num_heads,
+                 num_blocks=num_blocks,
+                 dropout=dropout,
+                 activation="relu",
+             ),
+             output_normalization=torch.nn.LayerNorm(embedding_dim),
+         )
+         return cls(
+             body=body,
+             loss=CE(ignore_index=schema.item_id_features.item().padding_value),
+         )
+
+     def reset_parameters(self) -> None:
+         self.body.reset_parameters()
+
+     def get_logits(
+         self,
+         model_embeddings: torch.Tensor,
+         candidates_to_score: Optional[torch.LongTensor] = None,
+     ) -> torch.Tensor:
+         item_embeddings: torch.Tensor = self.body.embedder.get_item_weights(candidates_to_score)
+         logits: torch.Tensor = self.head(model_embeddings, item_embeddings)
+         return logits
+
+     def forward_train(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+         positive_labels: torch.LongTensor,
+         negative_labels: torch.LongTensor,
+         target_padding_mask: torch.BoolTensor,
+     ) -> TrainOutput:
+         hidden_states: torch.Tensor = self.body(feature_tensors, padding_mask)
+         assert hidden_states.dim() == 3
+
+         loss: torch.Tensor = self.loss(
+             model_embeddings=hidden_states,
+             feature_tensors=feature_tensors,
+             positive_labels=positive_labels,
+             negative_labels=negative_labels,
+             padding_mask=padding_mask,
+             target_padding_mask=target_padding_mask,
+         )
+
+         return {
+             "loss": loss,
+             "hidden_states": (hidden_states,),
+         }
+
+     def forward_inference(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+         candidates_to_score: Optional[torch.LongTensor] = None,
+     ) -> InferenceOutput:
+         hidden_states: torch.Tensor = self.body(feature_tensors, padding_mask)
+         assert hidden_states.dim() == 3
+
+         last_hidden_state = hidden_states[:, -1, :].contiguous()
+         logits = self.get_logits(last_hidden_state, candidates_to_score)
+
+         return {
+             "logits": logits,
+             "hidden_states": (hidden_states,),
+         }
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+         padding_mask: torch.BoolTensor,
+         candidates_to_score: Optional[torch.LongTensor] = None,
+         positive_labels: Optional[torch.LongTensor] = None,
+         negative_labels: Optional[torch.LongTensor] = None,
+         target_padding_mask: Optional[torch.BoolTensor] = None,
+     ) -> Union[TrainOutput, InferenceOutput]:
+         """
+         :param feature_tensors: a dictionary of tensors to generate embeddings.
+         :param padding_mask: A mask of shape ``(batch_size, sequence_length)``
+             indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding").
+             ``False`` value indicates that the corresponding ``key`` value will be ignored.
+         :param candidates_to_score: a tensor containing item IDs
+             for which you need to get logits at the inference stage.\n
+             **Note:** you must take into account the padding value when creating the tensor.\n
+             The tensor participates in calculations only at the inference stage.
+             You do not have to pass this argument at the training stage;
+             if it is passed, it has no effect.\n
+             Default: ``None``.
+         :param positive_labels: a tensor containing positive labels for calculating the loss.\n
+             You do not have to pass this argument at the inference stage;
+             if it is passed, it has no effect.\n
+             Default: ``None``.
+         :param negative_labels: a tensor containing negative labels for calculating the loss.\n
+             **Note:** Before running, make sure that your loss supports calculations with negative labels.\n
+             You do not have to pass this argument at the inference stage;
+             if it is passed, it has no effect.\n
+             Default: ``None``.
+         :param target_padding_mask: A mask of shape ``(batch_size, sequence_length, num_positives)``
+             indicating elements from ``positive_labels`` to ignore during loss calculation.
+             ``False`` value indicates that the corresponding value will be ignored.\n
+             You do not have to pass this argument at the inference stage;
+             if it is passed, it has no effect.\n
+             Default: ``None``.
+         :returns: During training, the model returns an object
+             of the ``TrainOutput`` container class.
+             At the inference stage, an ``InferenceOutput`` object is returned.
+         """
+         if self.training:
+             all(
+                 map(
+                     warning_is_not_none("Variable `{}` is not None. This will have no effect at the training stage."),
+                     [(candidates_to_score, "candidates_to_score")],
+                 )
+             )
+             return self.forward_train(
+                 feature_tensors=feature_tensors,
+                 padding_mask=padding_mask,
+                 positive_labels=positive_labels,
+                 negative_labels=negative_labels,
+                 target_padding_mask=target_padding_mask,
+             )
+
+         all(
+             map(
+                 warning_is_not_none("Variable `{}` is not None. This will have no effect at the inference stage."),
+                 [
+                     (positive_labels, "positive_labels"),
+                     (negative_labels, "negative_labels"),
+                     (target_padding_mask, "target_padding_mask"),
+                 ],
+             )
+         )
+         return self.forward_inference(
+             feature_tensors=feature_tensors,
+             padding_mask=padding_mask,
+             candidates_to_score=candidates_to_score,
+         )
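
``SasRec.from_params`` wires the default blocks together (``SequenceEmbedding``, ``SumAggregator`` inside a ``PositionAwareAggregator``, ``DefaultAttentionMask``, ``SasRecTransformerLayer``, ``torch.nn.LayerNorm``, and a ``CE`` loss). A rough inference sketch reusing the ``tensor_schema`` built in the docstring example above; the exact contents of ``feature_tensors`` depend on ``SequenceEmbedding`` and the sequence tokenizer, which are not shown here, so treat the batch below as illustrative:

    import torch

    from replay.nn.sequential import SasRec

    model = SasRec.from_params(
        schema=tensor_schema,  # the TensorSchema from the docstring example above
        embedding_dim=192,
        num_heads=4,
        num_blocks=2,
        max_sequence_length=50,
    )

    model.eval()  # self.training == False routes forward() to forward_inference()
    with torch.no_grad():
        output = model(
            feature_tensors={"item_id": torch.randint(0, NUM_UNIQUE_ITEMS, (8, 50))},
            padding_mask=torch.ones(8, 50, dtype=torch.bool),
        )
    scores = output["logits"]  # InferenceOutput: scores for the last position of each sequence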