replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/models/nn/sequential/sasrec/model.py CHANGED
@@ -3,10 +3,15 @@ import contextlib
  from typing import Any, Optional, Union, cast

  import torch
+ from typing_extensions import deprecated

  from replay.data.nn import TensorMap, TensorSchema


+ @deprecated(
+     "`SasRecModel` class is deprecated. Use `replay.nn.sequential.SasRec` instead.",
+     stacklevel=2,
+ )
  class SasRecModel(torch.nn.Module):
      """
      SasRec model
@@ -110,7 +115,7 @@ class SasRecModel(torch.nn.Module):
      ) -> torch.Tensor:
          """
          :param feature_tensor: Batch of features.
-         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
+         :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.

          :returns: Calculated scores.
          """
@@ -127,11 +132,11 @@ class SasRecModel(torch.nn.Module):
      ) -> torch.Tensor:
          """
          :param feature_tensor: Batch of features.
-         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
-         :param candidates_to_score: Item ids to calculate scores.
-             if `None` predicts for all items
+         :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
+         :param candidates_to_score: Item ids to calculate scores.\n
+             If ``None`` then predicts for all items. Default: ``None``.

-         :returns: Prediction among canditates_to_score items.
+         :returns: Prediction among ``canditates_to_score`` items.
          """
          # final_emb: [B x E]
          final_emb = self.get_query_embeddings(feature_tensor, padding_mask)
@@ -145,7 +150,7 @@ class SasRecModel(torch.nn.Module):
      ):
          """
          :param feature_tensor: Batch of features.
-         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
+         :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.

          :returns: Query embeddings.
          """
@@ -158,7 +163,7 @@ class SasRecModel(torch.nn.Module):
      ) -> torch.Tensor:
          """
          :param feature_tensor: Batch of features.
-         :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
+         :param padding_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.

          :returns: Output embeddings.
          """
@@ -176,9 +181,9 @@ class SasRecModel(torch.nn.Module):

      def get_logits(self, out_embeddings: torch.Tensor, item_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
          """
-         Apply head to output embeddings of `forward_step`.
+         Apply head to output embeddings of ``forward_step``.

-         :param out_embeddings: Embeddings after `forward step`.
+         :param out_embeddings: Embeddings after ``forward step``.
          :param item_ids: Item ids to calculate scores.
              Default: ``None``.

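Note on the new decorator: `typing_extensions.deprecated` (PEP 702) warns at runtime, not only in type checkers. A minimal sketch with a toy class, since the real `SasRecModel` constructor arguments are not shown in this hunk:

import warnings
from typing_extensions import deprecated

@deprecated("`Old` is deprecated. Use `New` instead.", stacklevel=2)
class Old:
    pass

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Old()  # instantiating a @deprecated class emits a DeprecationWarning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)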
replay/nn/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from replay.utils import TORCH_AVAILABLE
+
+ if not TORCH_AVAILABLE:
+     msg = (
+         "The replay.nn module is unavailable. "
+         "To use the functionality from this module, please install ``torch`` and ``lightning``."
+     )
+     raise ImportError(msg)
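How this guard surfaces to consumers, as a short hedged sketch (assuming an environment where torch/lightning are missing, so `TORCH_AVAILABLE` is `False`):

try:
    import replay.nn  # the guard above raises at import time
except ImportError as err:
    print(err)  # "The replay.nn module is unavailable. ..."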
replay/nn/agg.py ADDED
@@ -0,0 +1,109 @@
+ import contextlib
+ from typing import Protocol
+
+ import torch
+
+ from replay.data.nn.schema import TensorMap
+
+
+ class AggregatorProto(Protocol):
+     """Class-protocol for working with embedding aggregation functions"""
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,
+     ) -> torch.Tensor: ...
+
+     @property
+     def embedding_dim(self) -> int: ...
+
+     def reset_parameters(self) -> None: ...
+
+
+ class SumAggregator(torch.nn.Module):
+     """
+     The class summarizes the incoming embeddings.
+     Note that for successful aggregation, the dimensions of all embeddings must match.
+     """
+
+     def __init__(self, embedding_dim: int) -> None:
+         """
+         :param embedding_dim: The last dimension of incoming and outcoming embeddings.
+         """
+         super().__init__()
+         self._embedding_dim = embedding_dim
+
+     @property
+     def embedding_dim(self) -> int:
+         """The dimension of the output embedding"""
+         return self._embedding_dim
+
+     def reset_parameters(self) -> None:
+         pass
+
+     def forward(self, feature_tensors: TensorMap) -> torch.Tensor:
+         """
+         :param feature_tensors: a dictionary of tensors to sum up.
+             The dimensions of all tensors in the dictionary must match.
+
+         :returns: torch.Tensor. The last dimension of the tensor is ``embedding_dim``.
+         """
+         out = sum(feature_tensors.values())
+         assert out.size(-1) == self.embedding_dim
+         return out
+
+
+ class ConcatAggregator(torch.nn.Module):
+     """
+     The class concatenates incoming embeddings by the last dimension.
+
+     If you need to concatenate several embeddings,
+     then a linear layer will be applied to get the last dimension equal to ``embedding_dim``.\n
+     If only one embedding comes to the input, then its last dimension is expected to be equal to ``embedding_dim``.
+     """
+
+     def __init__(
+         self,
+         input_embedding_dims: list[int],
+         output_embedding_dim: int,
+     ) -> None:
+         """
+         :param input_embedding_dims: Dimensions of incoming embeddings.
+         :param output_embedding_dim: The dimension of the output embedding after concatenation.
+         """
+         super().__init__()
+         self._embedding_dim = output_embedding_dim
+         embedding_concat_size = sum(input_embedding_dims)
+         self.feat_projection = None
+         if len(input_embedding_dims) > 1:
+             self.feat_projection = torch.nn.Linear(embedding_concat_size, self.embedding_dim)
+         elif embedding_concat_size != self.embedding_dim:
+             msg = f"Input embedding dim is not equal to embedding_dim ({embedding_concat_size} != {self.embedding_dim})"
+             raise ValueError(msg)
+
+     @property
+     def embedding_dim(self) -> int:
+         """The dimension of the output embedding"""
+         return self._embedding_dim
+
+     def reset_parameters(self) -> None:
+         for _, param in self.named_parameters():
+             with contextlib.suppress(ValueError):
+                 torch.nn.init.xavier_normal_(param.data)
+
+     def forward(self, feature_tensors: TensorMap) -> torch.Tensor:
+         """
+         To ensure the deterministic nature of the result,
+         the embeddings are concatenated in the ascending order of the keys in the dictionary.
+
+         :param feature_tensors: a dictionary of tensors to concatenate.
+
+         :returns: The last dimension of the tensor is ``embedding_dim``.
+         """
+         # To maintain determinism, we concatenate the tensors in sorted order by names.
+         sorted_names = sorted(feature_tensors.keys())
+         out = torch.cat([feature_tensors[name] for name in sorted_names], dim=-1)
+         if self.feat_projection is not None:
+             out = self.feat_projection(out)
+         assert out.size(-1) == self.embedding_dim
+         return out
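A hedged usage sketch of the two aggregators above; the batch/sequence shapes and feature names are illustrative assumptions, not taken from the package tests:

import torch
from replay.nn.agg import ConcatAggregator, SumAggregator

features = {
    "item_id": torch.randn(2, 5, 64),   # [batch, seq_len, emb]
    "category": torch.randn(2, 5, 64),
}

# Element-wise sum; all inputs must already share the same last dimension.
summed = SumAggregator(embedding_dim=64)(features)  # -> [2, 5, 64]

# Concatenation of two 64-dim embeddings (128) followed by a linear projection back to 64.
concat = ConcatAggregator(input_embedding_dims=[64, 64], output_embedding_dim=64)
projected = concat(features)  # -> [2, 5, 64]

assert summed.shape == projected.shape == (2, 5, 64)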
replay/nn/attention.py ADDED
@@ -0,0 +1,158 @@
+ import contextlib
+ import math
+ from typing import Optional
+
+ import torch
+
+
+ class MultiHeadDifferentialAttention(torch.nn.Module):
+     """
+     Multi-Head Differential Attention Mechanism.
+     Replaces the conventional softmax attention with a differential attention.
+     Incorporattes a causal mask (if other not specified) to ensure autoregressive behavior.
+
+     Source paper: https://arxiv.org/pdf/2410.05258
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         lambda_init: float,
+         bias: bool = False,
+         kdim: Optional[int] = None,
+         vdim: Optional[int] = None,
+     ):
+         """
+         :param embedding_dim: Total dimension of the model. Must be divisible by ``num_heads``.
+         :param num_heads: Number of parallel attention heads.
+         :param lambda_init: Initial value for lambda.
+         :param bias: If specified, adds bias to input / output projection layers. Default: ``False``.
+         :param kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embedding_dim``).
+         :param vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embedding_dim``).
+         """
+         super().__init__()
+         kdim = kdim or embedding_dim
+         vdim = vdim or embedding_dim
+         assert kdim % num_heads == 0, "Query/Key embedding dim is not divisible by num_heads"
+         assert vdim % num_heads == 0, "Value embedding dim is not divisible by num_heads"
+         self.qk_e_head = kdim // num_heads
+         self.v_e_head = vdim // num_heads
+         self.num_heads = num_heads
+
+         # Linear projections for queries, keys, and values
+         # Project to 2 * d_head per head for differential attention
+         self.W_q = torch.nn.Linear(embedding_dim, 2 * self.qk_e_head * num_heads, bias=bias)
+         self.W_k = torch.nn.Linear(embedding_dim, 2 * self.qk_e_head * num_heads, bias=bias)
+         self.W_v = torch.nn.Linear(embedding_dim, self.v_e_head * num_heads, bias=bias)
+         self.W_o = torch.nn.Linear(self.v_e_head * num_heads, embedding_dim, bias=bias)
+
+         # Learnable parameters for lambda reparameterization
+         self.lambda_q1 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
+         self.lambda_k1 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
+         self.lambda_q2 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
+         self.lambda_k2 = torch.nn.Parameter(torch.randn(num_heads, self.qk_e_head))
+         self.register_buffer("scaling", torch.asarray(1 / math.sqrt(self.qk_e_head), dtype=torch.float32))
+
+         self.lambda_init = lambda_init
+
+         # Scale parameter for RMSNorm
+         self.rms_scale = torch.nn.Parameter(torch.ones(self.v_e_head))
+         self.eps = 1e-5  # Epsilon for numerical stability
+
+     def reset_parameters(self) -> None:
+         for _, param in self.named_parameters():
+             with contextlib.suppress(ValueError):
+                 torch.nn.init.xavier_normal_(param.data)
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attn_mask: torch.FloatTensor,
+     ) -> torch.Tensor:
+         """
+         Forward pass for Multi-Head Differential Attention.
+
+         :param query: Query sequence of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param key: Key sequence of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param value: Value sequence of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param attn_mask: attention mask, where ``-inf`` for ``PAD``, ``0`` - otherwise.\n
+             Possible shapes:\n
+             1. ``(batch_size * num_heads, sequence_length, sequence_length)``
+             2. ``(batch_size, num_heads, sequence_length, sequence_length)``
+         :returns: torch.Tensor: Output tensor after applying differential attention.
+         """
+         batch_size, seq_len, _ = value.shape
+
+         # Project inputs to queries, keys, and values
+         query = self.W_q(query)  # Shape: (batch_size, seq_len, 2 * num_heads * qk_e_head)
+         key = self.W_k(key)  # Shape: (batch_size, seq_len, 2 * num_heads * qk_e_head)
+         value = self.W_v(value)  # Shape: (batch_size, seq_len, num_heads * v_e_head)
+
+         # Reshape and permute for multi-head attention
+         # New shape: (batch_size, num_heads, sequence_length, 2 * qk_e_head or v_e_head)
+         query = query.view(batch_size, seq_len, self.num_heads, 2 * self.qk_e_head).transpose(1, 2)
+         key = key.view(batch_size, seq_len, self.num_heads, 2 * self.qk_e_head).transpose(1, 2)
+         value = value.view(batch_size, seq_len, self.num_heads, self.v_e_head).transpose(1, 2)
+
+         # Split query and key into query1, query2 and key1, key2
+         query1, query2 = query.chunk(2, dim=-1)  # Each of shape: (batch_size, num_heads, seq_len, d_head)
+         key1, key2 = key.chunk(2, dim=-1)  # Each of shape: (batch_size, num_heads, seq_len, d_head)
+
+         # Compute lambda using reparameterization
+         # Compute dot products for each head
+         # Shape of lambda_val: (num_heads,)
+         lambda_q1_dot_k1 = torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()  # (num_heads,)
+         lambda_q2_dot_k2 = torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()  # (num_heads,)
+         lambda_val = torch.exp(lambda_q1_dot_k1) - torch.exp(lambda_q2_dot_k2) + self.lambda_init  # (num_heads,)
+
+         # Expand lambda_val to match attention dimensions (batch_size, num_heads, 1, 1)
+         lambda_val = lambda_val.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+
+         # Reshape attn_mask from 3D to 4D
+         if len(attn_mask.shape) == 3:
+             attn_mask = attn_mask.reshape(attn_mask.shape[0] // self.num_heads, self.num_heads, *attn_mask.shape[1:])
+
+         # check shapes
+         assert attn_mask.dim() == 4
+         assert attn_mask.size() == (batch_size, self.num_heads, seq_len, seq_len)
+
+         # Compute attention scores
+         attention_scores1 = torch.matmul(query1, key1.transpose(-2, -1)) * self.get_buffer(
+             "scaling"
+         )  # (batch_size, num_heads, seq_len, seq_len)
+         attention_scores2 = torch.matmul(query2, key2.transpose(-2, -1)) * self.get_buffer(
+             "scaling"
+         )  # (batch_size, num_heads, seq_len, seq_len)
+
+         # Apply the causal mask
+         attention_scores1 = attention_scores1 + attn_mask  # Mask out future positions
+         attention_scores2 = attention_scores2 + attn_mask  # Mask out future positions
+
+         # Apply softmax to get attention weights
+         attention1 = torch.nn.functional.softmax(attention_scores1, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)
+         attention2 = torch.nn.functional.softmax(attention_scores2, dim=-1)
+         attention = attention1 - lambda_val * attention2
+
+         # Apply attention weights to values
+         output = torch.matmul(attention, value)  # (batch_size, num_heads, seq_len, v_e_head)
+
+         # Normalize each head independently using RMSNorm
+         # Compute RMSNorm
+         rms_norm = torch.sqrt(output.pow(2).mean(dim=-1, keepdim=True) + self.eps)  # (batch_size*num_heads, seq_len, 1)
+         output_normalized = (output / rms_norm) * self.rms_scale  # (batch*num_heads, seq_len, v_e_head)
+
+         # Scale the normalized output
+         output_normalized = output_normalized * (1 - self.lambda_init)  # Scalar scaling
+
+         # Concatenate all heads
+         # New shape: (batch_size, seq_len, num_heads * v_e_head)
+         output_concat = (
+             output_normalized.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.v_e_head)
+         )
+
+         # Final linear projection
+         output_projection = self.W_o(output_concat)  # (batch_size, seq_len, embedding_dim)
+         return output_projection
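A minimal self-attention call for the layer above, assuming the caller builds an additive causal mask; the shapes and the `lambda_init` value are illustrative, not defaults from the package:

import torch
from replay.nn.attention import MultiHeadDifferentialAttention

batch, seq_len, dim, heads = 2, 6, 32, 4
attn = MultiHeadDifferentialAttention(embedding_dim=dim, num_heads=heads, lambda_init=0.8)

x = torch.randn(batch, seq_len, dim)
# Additive mask: 0 on allowed (past/current) positions, -inf on future positions.
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
attn_mask = causal.expand(batch, heads, seq_len, seq_len)  # 4D variant accepted by forward()

out = attn(query=x, key=x, value=x, attn_mask=attn_mask)
assert out.shape == (batch, seq_len, dim)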
replay/nn/embedding.py ADDED
@@ -0,0 +1,283 @@
+ import contextlib
+ import warnings
+ from collections.abc import Sequence
+ from typing import Literal, Optional, Union
+
+ import torch
+
+ from replay.data.nn.schema import TensorFeatureInfo, TensorMap, TensorSchema
+
+
+ class SequenceEmbedding(torch.nn.Module):
+     """
+     The embedding generation class for all types of features given into the sequential models.
+
+     The embedding size for each feature will be taken from ``TensorSchema`` (from field named ``embedding_dim``).
+     For numerical features, it is expected that the last dimension of the tensor will be equal
+     to ``tensor_dim`` field in ``TensorSchema``.
+
+     Keep in mind that the first dimension of the every categorical embedding (the size of embedding table)
+     will equal to the ``cardinality`` + 1. This is necessary to take into account the padding value.
+     """
+
+     def __init__(
+         self,
+         schema: TensorSchema,
+         excluded_features: Optional[list[str]] = None,
+         categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
+     ):
+         """
+         :param schema: TensorSchema containing meta information about all the features
+             for which you need to generate an embedding.
+         :param excluded_features: A list containing the names of features
+             for which you do not need to generate an embedding.
+             Fragments from this list are expected to be contained in ``schema``.
+             Default: ``None``.
+         :param categorical_list_feature_aggregation_method: Mode to aggregate tokens
+             in token item representation (categorical list only).
+             Default: ``"sum"``.
+         """
+         super().__init__()
+         self.excluded_features = excluded_features or []
+         feature_embedders = {}
+
+         for feature_name, tensor_info in schema.items():
+             if feature_name in self.excluded_features:
+                 continue
+             if not tensor_info.is_seq:
+                 msg = f"Non-sequential features is not yet supported. Got {feature_name}"
+                 raise NotImplementedError(msg)
+             if tensor_info.is_cat:
+                 feature_embedders[feature_name] = CategoricalEmbedding(
+                     tensor_info,
+                     categorical_list_feature_aggregation_method,
+                 )
+             else:
+                 feature_embedders[feature_name] = NumericalEmbedding(tensor_info)
+
+         self.feature_names = list(feature_embedders.keys())
+         if not feature_embedders:
+             msg = "Expected to have at least one feature name to generate embedding."
+             raise ValueError(msg)
+         self.feature_embedders: dict[str, Union[CategoricalEmbedding, NumericalEmbedding]] = torch.nn.ModuleDict(
+             feature_embedders
+         )
+         self._item_feature_name = schema.item_id_feature_name
+
+     def reset_parameters(self) -> None:
+         for _, param in self.named_parameters():
+             with contextlib.suppress(ValueError):
+                 torch.nn.init.xavier_normal_(param.data)
+
+     def forward(self, feature_tensor: TensorMap, feature_names: Optional[Sequence[str]] = None) -> TensorMap:
+         """
+         :param feature_tensor: a dictionary of tensors to generate embedding.
+             It is expected that the keys from this dictionary match the names of the features in the given ``schema``.
+         :param feature_names: A custom list of features for which embeddings need to be generated.
+             It is expected that the values from this list match the names of the features in the given ``schema``.\n
+             Default: ``None``. This means that the names of the features from the ``schema`` will be used.
+
+         :returns: a dictionary with tensors that contains embeddings.
+         """
+         return {
+             feature_name: self.feature_embedders[feature_name](feature_tensor[feature_name])
+             for feature_name in (feature_names or self.feature_names)
+         }
+
+     @property
+     def embeddings_dim(self) -> dict[str, int]:
+         """
+         Returns the embedding dimensions for each of the features in the `schema`.
+         """
+         return {name: emb.embedding_dim for name, emb in self.feature_embedders.items()}
+
+     def get_item_weights(self, indices: Optional[torch.LongTensor] = None) -> torch.Tensor:
+         """
+         Getting the embedding weights for a feature that matches the item id feature
+         with the name specified in the ``schema``.
+         It is expected that embeddings for this feature will definitely exist.
+         **Note**: the row corresponding to the padding will be excluded from the returned weights.
+         This logic will work if given ``indices`` is ``None``.
+
+         :param indices: Items indices.
+         :returns: Embeddings for specific items.
+         """
+         if indices is None:
+             return self.feature_embedders[self._item_feature_name].weight
+         return self.feature_embedders[self._item_feature_name](indices)
+
+
+ class CategoricalEmbedding(torch.nn.Module):
+     """
+     The embedding generation class for categorical features.
+     It supports working with single features for each event in sequence, as well as several (categorical list).
+
+     When using this class, keep in mind that
+     the first dimension of the embedding (the size of embedding table) will equal to the ``cardinality`` + 1.
+     This is necessary to take into account the padding value.
+     """
+
+     def __init__(
+         self,
+         feature_info: TensorFeatureInfo,
+         categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
+     ) -> None:
+         """
+         :param feature_info: Meta information about the feature.
+         :param categorical_list_feature_aggregation_method: Mode to aggregate tokens
+             in token item representation (categorical list only). One of {`sum`, `mean`, `max`}
+             Default: ``"sum"``.
+         """
+         super().__init__()
+         assert feature_info.cardinality
+         assert feature_info.embedding_dim
+
+         self._expect_padding_value_setted = True
+         if feature_info.cardinality != feature_info.padding_value:
+             self._expect_padding_value_setted = False
+             msg = (
+                 f"The padding value={feature_info.padding_value} is set for the feature={feature_info.name}. "
+                 f"The expected padding value for this feature should be {feature_info.cardinality}. "
+                 "Keep this in mind when getting the weights via the `weight` property, "
+                 "because the weights are returned there without padding row. "
+                 "Therefore, during the IDs scores generating, "
+                 "all the IDs that greater than the padding value should be increased by 1."
+             )
+             warnings.warn(msg, stacklevel=2)
+
+         if feature_info.is_list:
+             self.emb = torch.nn.EmbeddingBag(
+                 feature_info.cardinality + 1,
+                 feature_info.embedding_dim,
+                 padding_idx=feature_info.padding_value,
+                 mode=categorical_list_feature_aggregation_method,
+             )
+             self._get_embeddings = self._get_cat_list_embeddings
+         else:
+             self.emb = torch.nn.Embedding(
+                 feature_info.cardinality + 1,
+                 feature_info.embedding_dim,
+                 padding_idx=feature_info.padding_value,
+             )
+             self._get_embeddings = self._get_cat_embeddings
+
+     @property
+     def weight(self) -> torch.Tensor:
+         """
+         Returns the weights of the embedding layer,
+         excluding the row that corresponds to the padding.
+         """
+         if not self._expect_padding_value_setted:
+             msg = (
+                 "The weights are returned there do not contain padding row. "
+                 "Therefore, during the IDs scores generating, "
+                 "all the IDs that greater than the padding value should be increased by 1."
+             )
+             warnings.warn(msg, stacklevel=2)
+
+         mask_without_padding = torch.ones(
+             size=(self.emb.weight.size(0),),
+             dtype=torch.bool,
+             device=self.emb.weight.device,
+         )
+         mask_without_padding[self.emb.padding_idx].zero_()
+         return self.emb.weight[mask_without_padding]
+
+     def forward(self, indices: torch.LongTensor) -> torch.Tensor:
+         """
+         :param indices: Items indices.
+
+         :returns: Embeddings for specific items.
+         """
+         return self._get_embeddings(indices)
+
+     @property
+     def embedding_dim(self) -> int:
+         """Embedding dimension after applying the layer"""
+         return self.emb.embedding_dim
+
+     def _get_cat_embeddings(self, indices: torch.LongTensor) -> torch.Tensor:
+         """
+         :param indices: Items indices.
+
+         :returns: Embeddings for specific items.
+         """
+         return self.emb(indices)
+
+     def _get_cat_list_embeddings(self, indices: torch.LongTensor) -> torch.Tensor:
+         """
+         :param indices: Items indices.
+
+         :returns: Embeddings for specific items.
+         """
+         assert indices.dim() >= 2
+         if indices.dim() == 2:
+             embeddings: torch.Tensor = self.emb(indices)
+         else:
+             source_size = indices.size()
+             indices = indices.view(-1, source_size[-1])
+             embeddings = self.emb(indices)
+             embeddings = embeddings.view(*source_size[:-1], -1)
+         return embeddings
+
+
+ class NumericalEmbedding(torch.nn.Module):
+     """
+     The embedding generation class for numerical features.
+     It supports working with single features for each event in sequence, as well as several (numerical list).
+
+     **Note**: if the ``embedding_dim`` field in ``TensorFeatureInfo`` for an incoming feature matches its last dimension
+     (``tensor_dim`` field in ``TensorFeatureInfo``), then transformation will not be applied.
+     """
+
+     def __init__(self, feature_info: TensorFeatureInfo) -> None:
+         """
+         :param feature_info: Meta information about the feature.
+         """
+         super().__init__()
+         assert feature_info.tensor_dim
+         assert feature_info.embedding_dim
+         self._tensor_dim = feature_info.tensor_dim
+         self._embedding_dim = feature_info.embedding_dim
+         self.linear = torch.nn.Linear(feature_info.tensor_dim, self.embedding_dim)
+
+         if feature_info.is_list:
+             if self.embedding_dim == feature_info.tensor_dim:
+                 torch.nn.init.eye_(self.linear.weight.data)
+                 torch.nn.init.zeros_(self.linear.bias.data)
+
+                 self.linear.weight.requires_grad = False
+                 self.linear.bias.requires_grad = False
+         else:
+             assert feature_info.tensor_dim == 1
+             self.linear = torch.nn.Linear(feature_info.tensor_dim, self.embedding_dim)
+
+     @property
+     def weight(self) -> torch.Tensor:
+         """
+         Returns the weight of the applied layer.
+         If ``embedding_dim`` matches ``tensor_dim``, then the identity matrix will be returned.
+         """
+         return self.linear.weight
+
+     def forward(self, values: torch.FloatTensor) -> torch.Tensor:
+         """
+         Numerical embedding forward pass.\n
+         **Note**: if the ``embedding_dim`` for an incoming feature matches its last dimension (``tensor_dim``),
+         then transformation will not be applied.
+
+         :param values: feature values.
+         :returns: Embeddings for specific items.
+         """
+         if values.dim() <= 2 and self._tensor_dim == 1:
+             values = values.unsqueeze(-1).contiguous()
+
+         assert values.size(-1) == self._tensor_dim
+         if self._tensor_dim != self.embedding_dim:
+             return self.linear(values)
+         return values
+
+     @property
+     def embedding_dim(self) -> int:
+         """Embedding dimension after applying the layer"""
+         return self._embedding_dim
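The ``cardinality + 1`` padding-row convention described above can be illustrated with a plain `torch.nn.Embedding`; this is a sketch of the idea, not the `CategoricalEmbedding` API (which is constructed from a `TensorFeatureInfo`):

import torch

cardinality, embedding_dim = 10, 4
# Table has cardinality + 1 rows; the extra row (index == cardinality) holds the padding vector.
emb = torch.nn.Embedding(cardinality + 1, embedding_dim, padding_idx=cardinality)

# Equivalent of the `weight` property: drop the padding row before scoring item IDs.
keep = torch.ones(cardinality + 1, dtype=torch.bool)
keep[emb.padding_idx] = False
item_weights = emb.weight[keep]  # shape: (cardinality, embedding_dim)
assert item_weights.shape == (cardinality, embedding_dim)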