difflayers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,339 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch import Tensor
5
+ from torch.nn import Linear, Module, Parameter
6
+ from typing import Optional
7
+
8
+ from .functional import hopfield_core_forward
9
+
10
+ try:
11
+ from torch.nn.modules.linear import _LinearWithBias
12
+ except ImportError:
13
+ _LinearWithBias = None
14
+
15
+
16
class HopfieldCore(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See references: "Hopfield Networks is All You Need" and
    "Attention Is All You Need" (on which this implementation is partly based on).

    .. math::
        \text{HopfieldHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model. May only be ``None`` for a
            static-only execution (see below).
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        head_dim: dimension of the association space per head.
            Default: ``embed_dim // num_heads`` (which must divide evenly).
        pattern_dim: dimension of the projected value patterns per head.
            Default: ``head_dim``.
        out_dim: number of output features of the output projection.
            Default: ``embed_dim``.
        disable_out_projection: do not create an output projection
            (``out_proj`` is registered as ``None``).
        key_as_static, query_as_static, value_as_static: use the respective
            input as-is; no learned projection weight is created for it and
            ``forward`` expects its last dimension to match ``head_dim``
            (``pattern_dim`` for the value).
        value_as_connected: when the key is not static, the value projection
            takes the projected key size (``num_heads * head_dim``) as input;
            requires ``kdim == vdim``.
        normalize_pattern: forwarded to ``hopfield_core_forward``.
        normalize_pattern_affine: create learnable weight/bias of size
            ``head_dim`` (``p_norm_weight``/``p_norm_bias``) for the pattern
            normalization; requires ``normalize_pattern``.
        normalize_pattern_eps: epsilon forwarded to ``hopfield_core_forward``.

    Note: if kdim and vdim are None, they will be set to embed_dim such that
    query, key, and value have the same number of features.

    A "static-only execution" (query, key and value all static, no connected
    value, no pattern normalization, output projection disabled) creates no
    projection parameters; ``embed_dim``, ``kdim`` and ``vdim`` are ignored.

    Examples::

        >>> hopfield_attn = HopfieldCore(embed_dim, num_heads)
        >>> attn_output, attn_output_weights, attn_matrix = hopfield_attn(query, key, value)
    """
    # typing.Optional is what torch._jit_internal re-exports; use the public name.
    __annotations__ = {
        'bias_k': Optional[torch.Tensor],
        'bias_v': Optional[torch.Tensor],
    }

    def __init__(self,
                 embed_dim=None,  # type: Optional[int]
                 num_heads=1,  # type: int
                 dropout=0.0,  # type: float
                 bias=True,  # type: bool
                 add_bias_kv=False,  # type: bool
                 add_zero_attn=False,  # type: bool
                 kdim=None,  # type: Optional[int]
                 vdim=None,  # type: Optional[int]

                 head_dim=None,  # type: Optional[int]
                 pattern_dim=None,  # type: Optional[int]
                 out_dim=None,  # type: Optional[int]
                 disable_out_projection=False,  # type: bool
                 key_as_static=False,  # type: bool
                 query_as_static=False,  # type: bool
                 value_as_static=False,  # type: bool
                 value_as_connected=False,  # type: bool
                 normalize_pattern=False,  # type: bool
                 normalize_pattern_affine=False,  # type: bool
                 normalize_pattern_eps=1e-5  # type: float
                 ):
        super(HopfieldCore, self).__init__()

        assert all(isinstance(flag, bool) for flag in (key_as_static, query_as_static, value_as_static))
        self.key_as_static, self.query_as_static, self.value_as_static = key_as_static, query_as_static, value_as_static
        # Number of inputs (query/key/value) that receive a learned projection.
        num_non_static = 3 - (self.key_as_static + self.query_as_static + self.value_as_static)
        assert 0 <= num_non_static < 4

        self.value_as_connected = value_as_connected
        self.normalize_pattern, self.normalize_pattern_affine = normalize_pattern, normalize_pattern_affine
        self.normalize_pattern_eps = normalize_pattern_eps
        self.disable_out_projection = disable_out_projection

        # In case of a static-only executions, check corresponding projections and normalizations.
        self.static_execution = self._check_execution_mode()
        if self.static_execution:
            embed_dim, kdim, vdim = None, None, None
        if embed_dim is None:
            assert self.static_execution, r'static-only execution requires all projections to be deactivated.'

        # Check and set all other properties, conditioned on <static_execution>.
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # A single packed in-projection is only possible when all inputs share one feature size.
        self._qkv_same_embed_dim = all((
            self.kdim == embed_dim, self.vdim == embed_dim, pattern_dim is None, not self.value_as_connected))
        assert (not self.value_as_connected) or (self.kdim == self.vdim), r'key and value need to be of same dimension.'

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = None
        self.pattern_dim = pattern_dim
        self.virtual_hopfield_dim = None
        self.virtual_pattern_dim = None
        if not self.static_execution:
            if head_dim is None:
                self.head_dim = embed_dim // num_heads
                assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads."
            else:
                assert head_dim > 0, "dimension of the association space has to be positive."
                self.head_dim = head_dim
            if self.pattern_dim is None:
                self.pattern_dim = self.head_dim
            # Total (all heads) sizes of the association space and of the value patterns.
            self.virtual_hopfield_dim = self.num_heads * self.head_dim
            self.virtual_pattern_dim = self.num_heads * self.pattern_dim

        self.out_dim = embed_dim if out_dim is None else out_dim
        assert disable_out_projection or (self.out_dim > 0), "output projection dimension has to be positive."

        if normalize_pattern_affine:
            assert normalize_pattern, "affine pattern normalization without pattern normalization has no effect."
            # Use the *resolved* per-head size: the <head_dim> argument is None by default.
            self.p_norm_weight = Parameter(torch.empty(self.head_dim))
            self.p_norm_bias = Parameter(torch.empty(self.head_dim))
        else:
            self.register_parameter('p_norm_weight', None)
            self.register_parameter('p_norm_bias', None)

        if not self._qkv_same_embed_dim:
            # Separate projection weights, one per non-static input.
            if query_as_static:
                self.register_parameter('q_proj_weight', None)
            else:
                self.q_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, embed_dim))
            if key_as_static:
                self.register_parameter('k_proj_weight', None)
            else:
                self.k_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, self.kdim))
            if value_as_static:
                self.register_parameter('v_proj_weight', None)
            else:
                self.v_proj_weight = Parameter(torch.Tensor(
                    self.virtual_pattern_dim,
                    self.virtual_hopfield_dim if (value_as_connected and not key_as_static) else self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            # One packed projection; rows are stacked for each non-static input.
            if num_non_static > 0:
                self.in_proj_weight = Parameter(torch.empty(
                    (not query_as_static) * self.virtual_hopfield_dim +
                    (not key_as_static) * self.virtual_hopfield_dim +
                    (not value_as_static) * self.virtual_pattern_dim, embed_dim))
            else:
                self.register_parameter('in_proj_weight', None)
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias and (num_non_static > 0):
            self.in_proj_bias = Parameter(torch.empty(
                (not query_as_static) * self.virtual_hopfield_dim +
                (not key_as_static) * self.virtual_hopfield_dim + self.virtual_pattern_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        if disable_out_projection:
            self.register_parameter('out_proj', None)
        else:
            if bias and _LinearWithBias is not None:
                self.out_proj = _LinearWithBias(self.virtual_pattern_dim, self.out_dim)
            else:
                self.out_proj = Linear(self.virtual_pattern_dim, self.out_dim, bias=bias)

        self.bias_k, self.bias_v = None, None
        if add_bias_kv:
            if not key_as_static:
                self.bias_k = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            if not value_as_static:
                # The value bias extends the *projected* value sequence, whose feature size is
                # virtual_pattern_dim (cf. v_proj_weight), not virtual_hopfield_dim.
                self.bias_v = Parameter(torch.empty(1, 1, self.virtual_pattern_dim))
            assert not (self.bias_k is None and self.bias_v is None), r'cannot set key/value bias if both are static.'

        self.add_zero_attn = add_zero_attn
        self.reset_parameters()

    def _check_execution_mode(self) -> bool:
        """Return True iff the configuration requires no learned parameters at all."""
        return all((
            self.key_as_static, self.query_as_static, self.value_as_static, not self.value_as_connected,
            not self.normalize_pattern, not self.normalize_pattern_affine, self.disable_out_projection
        ))

    def reset_parameters(self):
        """(Re-)initialize all learned parameters.

        Projection weights and key/value biases ~ N(0, 0.02); projection biases are zero;
        affine pattern-normalization weight/bias are one/zero.
        """
        if self.p_norm_weight is not None:
            nn.init.ones_(self.p_norm_weight)
            nn.init.zeros_(self.p_norm_bias)

        if self._qkv_same_embed_dim and (self.in_proj_weight is not None):
            nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02)
        else:
            if self.q_proj_weight is not None:
                nn.init.normal_(self.q_proj_weight, mean=0.0, std=0.02)
            if self.k_proj_weight is not None:
                nn.init.normal_(self.k_proj_weight, mean=0.0, std=0.02)
            if self.v_proj_weight is not None:
                nn.init.normal_(self.v_proj_weight, mean=0.0, std=0.02)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
        if not self.disable_out_projection:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.02)
            if self.out_proj.bias is not None:
                nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.normal_(self.bias_k, mean=0.0, std=0.02)
        if self.bias_v is not None:
            nn.init.normal_(self.bias_v, mean=0.0, std=0.02)

    def __setstate__(self, state):
        super(HopfieldCore, self).__setstate__(state)

    def forward(self,
                query,  # type: Tensor
                key,  # type: Tensor
                value,  # type: Tensor
                key_padding_mask=None,  # type: Optional[Tensor]
                need_weights=True,  # type: bool
                attn_mask=None,  # type: Optional[Tensor]

                scaling=None,  # type: Optional[Tensor]
                update_steps_max=0,  # type: Optional[int]
                update_steps_eps=1e-4,  # type: float
                return_raw_associations=False,  # type: bool
                return_pattern_projections=False  # type: bool
                ):
        # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
                See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

            scaling: scaling of association heads, often represented as beta (one entry per head).
            update_steps_max: maximum count of association update steps (None equals to infinity).
            update_steps_eps: minimum difference threshold between two consecutive association update steps.
            return_raw_associations: return raw association (softmax) values, unmodified.
            return_pattern_projections: return pattern projection values, unmodified.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
            - attn_raw: :math:`(N, num_heads, L, S)`, where N is the batch size,
              L is the target sequence length, S is the source sequence length.

        Returns:
            Whatever ``hopfield_core_forward`` returns for the given flags.
        """
        # Validate feature sizes and derive the per-head / check dimensions for the functional core.
        if self.query_as_static and self.key_as_static:
            assert query.shape[2] == key.shape[2], \
                f'query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal'
            head_dim, embed_dim_to_check = query.shape[2], query.shape[2]
        else:
            assert self.query_as_static or (query.shape[2] == self.embed_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}.'
            assert (not self.query_as_static) or (self.query_as_static and query.shape[2] == self.head_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}'

            assert self.key_as_static or (key.shape[2] == self.kdim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}.'
            assert (not self.key_as_static) or (self.key_as_static and key.shape[2] == self.head_dim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}'
            head_dim, embed_dim_to_check = self.head_dim, self.head_dim if self.query_as_static else self.embed_dim

        assert self.value_as_static or (value.shape[2] == self.vdim), \
            f'value shape[2] of {value.shape[2]} invalid, needs to be {self.vdim}.'
        assert any((
            not self.value_as_static, self.value_as_static and value.shape[2] == self.pattern_dim,
            self.disable_out_projection)
        ), f'value shape[2] of {value.shape[2]} invalid, needs to be {self.pattern_dim}'

        out_weights, out_bias = None, None
        if not self.disable_out_projection:
            out_weights, out_bias = self.out_proj.weight, self.out_proj.bias

        # NOTE(review): the two calls are kept as explicit, separate calls (as in torch's
        # MultiheadAttention) rather than merged via **kwargs, which would likely not be
        # TorchScript-compatible.
        if not self._qkv_same_embed_dim:
            return hopfield_core_forward(
                query=query, key=key, value=value, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
                in_proj_weight=self.in_proj_weight, in_proj_bias=self.in_proj_bias, bias_k=self.bias_k,
                bias_v=self.bias_v, add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
                out_proj_weight=out_weights, out_proj_bias=out_bias, training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,

                key_as_static=self.key_as_static, query_as_static=self.query_as_static,
                value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
                normalize_pattern=self.normalize_pattern, normalize_pattern_eps=self.normalize_pattern_eps,
                p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
                head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
                update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
                return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
        else:
            return hopfield_core_forward(
                query=query, key=key, value=value, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
                in_proj_weight=self.in_proj_weight, in_proj_bias=self.in_proj_bias, bias_k=self.bias_k,
                bias_v=self.bias_v, add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
                out_proj_weight=out_weights, out_proj_bias=out_bias, training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,

                key_as_static=self.key_as_static, query_as_static=self.query_as_static,
                value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
                normalize_pattern=self.normalize_pattern, normalize_pattern_eps=self.normalize_pattern_eps,
                p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
                head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
                update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
                return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
@@ -0,0 +1,157 @@
1
+ """
2
+ Attention operator for the diffusion-attention dynamical memory system.
3
+
4
+ Responsibility: Apply scaled-dot-product attention in exactly two modes.
5
+ This is the *only* place that computes attention — no other module does it.
6
+
7
+ Modes
8
+ -----
9
+ dense (default)
10
+ logits = beta * Q @ K.T — (N, N)
11
+ weights = softmax(logits) — (N, N)
12
+ output = weights @ V — (N, d)
13
+ O(N²d) time, O(N²) space.
14
+ Exact match to the Hopfield baseline.
15
+
16
+ graph
17
+ For each query i, attend *only* to its kNN neighbors.
18
+ Requires adj_indices (N, k) LongTensor from ``GraphBuilder``.
19
+ O(kNd) time, O(kN) space.
20
+ Strictly faster than dense when k ≪ N.
21
+
22
+ API
23
+ ---
24
+ op = AttentionOperator(beta=10.0, mode="dense")
25
+ out = op(Q, K, V) # dense
26
+ out = op(Q, K, V, adj_indices=adj_idx) # graph
27
+
28
+ Constraints
29
+ -----------
30
+ * Modes are implemented in separate methods (_dense / _graph).
31
+ No conditional logic inside a shared inner loop.
32
+ * Dense mode never uses adj_indices; graph mode never builds N×N logit matrix.
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ from typing import Optional
38
+
39
+ import torch
40
+ import torch.nn.functional as F
41
+ from torch import Tensor
42
+
43
+
44
class AttentionOperator:
    """
    Scaled-dot-product attention in one of two fixed modes.

    Args:
        beta: Inverse temperature multiplied into every logit. Default: 1.0.
        mode: ``'dense'`` attends from each query to every key (O(N²));
            ``'graph'`` attends only to the neighbours listed in
            ``adj_indices`` (O(kN)).

    Time complexity:
        dense : O(N²d)
        graph : O(kNd)

    Space complexity:
        dense : O(N²)
        graph : O(kN)
    """

    def __init__(self, beta: float = 1.0, mode: str = "dense") -> None:
        # Reject unknown modes up front so dispatch in __call__ can stay binary.
        if mode in ("dense", "graph"):
            self.beta = beta
            self.mode = mode
            return
        raise ValueError(
            f"AttentionOperator: mode must be 'dense' or 'graph', got '{mode}'."
        )

    def __call__(
        self,
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        adj_indices: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Run attention in the configured mode.

        Args:
            Q: (N, d) queries, or (S, B, d) for a batch.
            K: keys, same layout as ``Q``.
            V: values, same layout as ``Q``.
            adj_indices: (N, k) LongTensor of neighbour indices per query.
                Mandatory in graph mode; dense mode never looks at it.

        Returns:
            Tensor with the same layout as ``Q``: (N, d) or (S, B, d).
        """
        if self.mode != "dense":
            return self._graph(Q, K, V, adj_indices)
        return self._dense(Q, K, V)

    # ------------------------------------------------------------------
    # Dense  O(N²d)
    # ------------------------------------------------------------------

    def _dense(self, Q: Tensor, K: Tensor, V: Tensor) -> Tensor:
        """
        Full attention over every key.

        Unbatched (2D) inputs use plain matmuls; any other rank is treated as
        (S, B, d) and moved to batch-first for ``torch.bmm``.

        Complexity: O(N²d) time, O(N²) space per batch element.
        """
        if Q.dim() != 2:
            # (S, B, d) -> (B, S, d): bmm needs the batch axis in front.
            q, k, v = (t.permute(1, 0, 2) for t in (Q, K, V))
            scores = self.beta * torch.bmm(q, k.transpose(1, 2))   # (B, S, S)
            attn = F.softmax(scores, dim=-1)
            return torch.bmm(attn, v).permute(1, 0, 2)             # back to (S, B, d)

        scores = self.beta * (Q @ K.t())                           # (N, N)
        attn = F.softmax(scores, dim=-1)
        return attn @ V                                            # (N, d)

    # ------------------------------------------------------------------
    # Graph-constrained  O(kNd)
    # ------------------------------------------------------------------

    def _graph(
        self,
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        adj_indices: Optional[Tensor],
    ) -> Tensor:
        """
        Attention restricted to each query's listed neighbours.

        Keys/values are gathered through ``adj_indices`` so only a (.., k)
        logit tensor is built — never an N×N matrix. Non-2D inputs are
        treated as (S, B, d) and gathered per batch element.

        Raises:
            ValueError: if ``adj_indices`` is None.

        Complexity: O(kNd) time, O(kN) space per batch element.
        """
        if adj_indices is None:
            raise ValueError(
                "AttentionOperator(mode='graph') requires adj_indices (N, k)."
            )

        if Q.dim() == 2:
            # (N, d) inputs -> neighbour stacks of shape (N, k, d)
            return self._attend_neighbours(Q, K[adj_indices], V[adj_indices], axis=1)

        # (S, B, d) -> (B, S, d); gathering along S gives (B, S, k, d)
        q, k, v = (t.permute(1, 0, 2) for t in (Q, K, V))
        batched = self._attend_neighbours(q, k[:, adj_indices], v[:, adj_indices], axis=2)
        return batched.permute(1, 0, 2)

    def _attend_neighbours(self, q: Tensor, k_nbrs: Tensor, v_nbrs: Tensor, axis: int) -> Tensor:
        """Softmax over the neighbour axis ``axis`` and return the weighted value sum."""
        scores = self.beta * (q.unsqueeze(axis) * k_nbrs).sum(dim=-1)
        attn = F.softmax(scores, dim=-1)
        return (attn.unsqueeze(-1) * v_nbrs).sum(dim=axis)
File without changes