difflayers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- difflayers/__init__.py +965 -0
- difflayers/activation.py +339 -0
- difflayers/attention_operator.py +157 -0
- difflayers/auxiliary/__init__.py +0 -0
- difflayers/auxiliary/data.py +252 -0
- difflayers/diffused_attention.py +427 -0
- difflayers/diffusion.py +395 -0
- difflayers/dynamics_engine.py +540 -0
- difflayers/functional.py +459 -0
- difflayers/graph/__init__.py +18 -0
- difflayers/graph/build_graph.py +77 -0
- difflayers/graph/builder.py +120 -0
- difflayers/graph/laplacian.py +76 -0
- difflayers/graph/laplacian_builder.py +64 -0
- difflayers/transformer.py +212 -0
- difflayers-0.1.0.dist-info/METADATA +210 -0
- difflayers-0.1.0.dist-info/RECORD +20 -0
- difflayers-0.1.0.dist-info/WHEEL +5 -0
- difflayers-0.1.0.dist-info/licenses/LICENSE +79 -0
- difflayers-0.1.0.dist-info/top_level.txt +1 -0
difflayers/activation.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
from torch.nn import Linear, Module, Parameter
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from .functional import hopfield_core_forward
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from torch.nn.modules.linear import _LinearWithBias
|
|
12
|
+
except ImportError:
|
|
13
|
+
_LinearWithBias = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HopfieldCore(Module):
    r"""Multi-head association core that jointly attends to information
    from different representation subspaces.

    See references: "Hopfield Networks is All You Need" and
    "Attention Is All You Need" (on which this implementation is partly based).

    .. math::
        \text{HopfieldHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model. May only be None in static-only execution.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        head_dim: dimension of the association space of a single head.
            Default: None, which resolves to ``embed_dim // num_heads``.
        pattern_dim: per-head dimension of the value projection.
            Default: None, which resolves to the (resolved) ``head_dim``.
        out_dim: output projection dimension. Default: None, which resolves to ``embed_dim``.
        disable_out_projection: do not create (or apply) an output projection.
        key_as_static: do not create a projection for the key input.
        query_as_static: do not create a projection for the query input.
        value_as_static: do not create a projection for the value input.
        value_as_connected: requires ``kdim == vdim``; when the key is not static the value
            projection takes ``num_heads * head_dim`` input features instead of ``vdim``.
        normalize_pattern: forwarded to ``hopfield_core_forward`` to enable pattern normalization.
        normalize_pattern_affine: create learnable weight/bias (size ``head_dim``) for the
            pattern normalization; requires ``normalize_pattern``.
        normalize_pattern_eps: epsilon forwarded to the pattern normalization.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

        Note: if all of query/key/value are static, the output projection is disabled and
        no (connected/normalized) pattern handling is requested, the module runs in
        "static-only" mode: ``embed_dim``, ``kdim`` and ``vdim`` are ignored (reset to None)
        and no projection parameters are created.

    Examples::

        >>> hopfield_attn = HopfieldCore(embed_dim, num_heads)
        >>> attn_output, attn_output_weights, attn_matrix = hopfield_attn(query, key, value)
    """
    # Use the public typing.Optional instead of reaching into torch._jit_internal,
    # which is a private module whose re-exports are not guaranteed to stay stable.
    __annotations__ = {
        'bias_k': Optional[torch.Tensor],
        'bias_v': Optional[torch.Tensor],
    }

    def __init__(self,
                 embed_dim=None,                  # type: Optional[int]
                 num_heads=1,                     # type: int
                 dropout=0.0,                     # type: float
                 bias=True,                       # type: bool
                 add_bias_kv=False,               # type: bool
                 add_zero_attn=False,             # type: bool
                 kdim=None,                       # type: Optional[int]
                 vdim=None,                       # type: Optional[int]

                 head_dim=None,                   # type: Optional[int]
                 pattern_dim=None,                # type: Optional[int]
                 out_dim=None,                    # type: Optional[int]
                 disable_out_projection=False,    # type: bool
                 key_as_static=False,             # type: bool
                 query_as_static=False,           # type: bool
                 value_as_static=False,           # type: bool
                 value_as_connected=False,        # type: bool
                 normalize_pattern=False,         # type: bool
                 normalize_pattern_affine=False,  # type: bool
                 normalize_pattern_eps=1e-5       # type: float
                 ):
        super().__init__()

        # The static flags are used arithmetically below, so they must be real booleans.
        assert (isinstance(key_as_static, bool) and isinstance(query_as_static, bool)
                and isinstance(value_as_static, bool))
        self.key_as_static, self.query_as_static, self.value_as_static = key_as_static, query_as_static, value_as_static
        num_non_static = 3 - (self.key_as_static + self.query_as_static + self.value_as_static)
        assert 0 <= num_non_static < 4

        self.value_as_connected = value_as_connected
        self.normalize_pattern, self.normalize_pattern_affine = normalize_pattern, normalize_pattern_affine
        self.normalize_pattern_eps = normalize_pattern_eps
        self.disable_out_projection = disable_out_projection

        # In case of a static-only executions, check corresponding projections and normalizations.
        self.static_execution = self._check_execution_mode()
        if self.static_execution:
            embed_dim, kdim, vdim = None, None, None
        if embed_dim is None:
            assert self.static_execution, r'static-only execution requires all projections to be deactivated.'

        # Check and set all other properties, conditioned on <static_execution>.
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # A single packed input projection is only used if all inputs share <embed_dim>
        # and neither a custom pattern dimension nor a connected value is requested.
        self._qkv_same_embed_dim = all((
            self.kdim == embed_dim, self.vdim == embed_dim, pattern_dim is None, not self.value_as_connected))
        assert (not self.value_as_connected) or (self.kdim == self.vdim), r'key and value need to be of same dimension.'

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = None
        self.pattern_dim = pattern_dim
        self.virtual_hopfield_dim = None
        self.virtual_pattern_dim = None
        if not self.static_execution:
            if head_dim is None:
                self.head_dim = embed_dim // num_heads
                assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads."
            else:
                assert head_dim > 0, "dimension of the association space has to be positive."
                self.head_dim = head_dim
            if self.pattern_dim is None:
                self.pattern_dim = self.head_dim
            self.virtual_hopfield_dim = self.num_heads * self.head_dim
            self.virtual_pattern_dim = self.num_heads * self.pattern_dim

        self.out_dim = embed_dim if out_dim is None else out_dim
        assert disable_out_projection or (self.out_dim > 0), "output projection dimension has to be positive."

        if normalize_pattern_affine:
            assert normalize_pattern, "affine pattern normalization without pattern normalization has no effect."
            # Size by the *resolved* head dimension: the <head_dim> argument may be None
            # (default), in which case it was derived from embed_dim // num_heads above.
            # Affine normalization rules out static-only execution, so self.head_dim is set.
            self.p_norm_weight = Parameter(torch.Tensor(self.head_dim))
            self.p_norm_bias = Parameter(torch.Tensor(self.head_dim))
        else:
            self.register_parameter('p_norm_weight', None)
            self.register_parameter('p_norm_bias', None)

        if not self._qkv_same_embed_dim:
            # Separate projection weights; static inputs get no projection at all.
            if query_as_static:
                self.register_parameter('q_proj_weight', None)
            else:
                self.q_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, embed_dim))
            if key_as_static:
                self.register_parameter('k_proj_weight', None)
            else:
                self.k_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, self.kdim))
            if value_as_static:
                self.register_parameter('v_proj_weight', None)
            else:
                self.v_proj_weight = Parameter(torch.Tensor(
                    self.virtual_pattern_dim,
                    self.virtual_hopfield_dim if (value_as_connected and not key_as_static) else self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            # Packed projection weight: one row block per non-static input.
            if num_non_static > 0:
                self.in_proj_weight = Parameter(torch.empty(
                    (not query_as_static) * self.virtual_hopfield_dim +
                    (not key_as_static) * self.virtual_hopfield_dim +
                    (not value_as_static) * self.virtual_pattern_dim, embed_dim))
            else:
                self.register_parameter('in_proj_weight', None)
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias and (num_non_static > 0):
            # NOTE(review): unlike in_proj_weight, the value slot is reserved here even if
            # <value_as_static> is set. Left as-is because the consumer (hopfield_core_forward)
            # is not visible from this module - confirm before changing the layout.
            self.in_proj_bias = Parameter(torch.empty(
                (not query_as_static) * self.virtual_hopfield_dim +
                (not key_as_static) * self.virtual_hopfield_dim + self.virtual_pattern_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        if disable_out_projection:
            self.register_parameter('out_proj', None)
        else:
            # _LinearWithBias only exists in some PyTorch versions (see guarded import at module top).
            if bias and _LinearWithBias is not None:
                self.out_proj = _LinearWithBias(self.virtual_pattern_dim, self.out_dim)
            else:
                self.out_proj = Linear(self.virtual_pattern_dim, self.out_dim, bias=bias)

        self.bias_k, self.bias_v = None, None
        if add_bias_kv:
            if not key_as_static:
                self.bias_k = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            if not value_as_static:
                # NOTE(review): sized with virtual_hopfield_dim, not virtual_pattern_dim; these
                # differ when pattern_dim != head_dim - confirm against hopfield_core_forward.
                self.bias_v = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            assert not (self.bias_k is None and self.bias_v is None), r'cannot set key/value bias if both are static.'

        self.add_zero_attn = add_zero_attn
        self.reset_parameters()

    def _check_execution_mode(self) -> bool:
        """Return True iff the module is configured for static-only execution.

        Static-only means: query, key and value are all static, the value is not connected,
        no pattern normalization (plain or affine) is requested and the output projection
        is disabled.
        """
        return all((
            self.key_as_static, self.query_as_static, self.value_as_static, not self.value_as_connected,
            not self.normalize_pattern, not self.normalize_pattern_affine, self.disable_out_projection
        ))

    def reset_parameters(self):
        """(Re-)initialize all existing parameters.

        Normalization weight/bias are set to one/zero, projection weights and key/value
        biases are drawn from N(0, 0.02^2), and projection biases are zeroed.
        """
        if self.p_norm_weight is not None:
            nn.init.ones_(self.p_norm_weight)
            nn.init.zeros_(self.p_norm_bias)

        if self._qkv_same_embed_dim and (self.in_proj_weight is not None):
            nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02)
        else:
            if self.q_proj_weight is not None:
                nn.init.normal_(self.q_proj_weight, mean=0.0, std=0.02)
            if self.k_proj_weight is not None:
                nn.init.normal_(self.k_proj_weight, mean=0.0, std=0.02)
            if self.v_proj_weight is not None:
                nn.init.normal_(self.v_proj_weight, mean=0.0, std=0.02)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
        if not self.disable_out_projection:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.02)
            if self.out_proj.bias is not None:
                nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.normal_(self.bias_k, mean=0.0, std=0.02)
        if self.bias_v is not None:
            nn.init.normal_(self.bias_v, mean=0.0, std=0.02)

    def __setstate__(self, state):
        # Plain delegation; kept so the (un)pickling hook stays part of the class interface.
        super().__setstate__(state)

    def forward(self,
                query,                            # type: Tensor
                key,                              # type: Tensor
                value,                            # type: Tensor
                key_padding_mask=None,            # type: Optional[Tensor]
                need_weights=True,                # type: bool
                attn_mask=None,                   # type: Optional[Tensor]

                scaling=None,                     # type: Optional[Tensor]
                update_steps_max=0,               # type: Optional[int]
                update_steps_eps=1e-4,            # type: float
                return_raw_associations=False,    # type: bool
                return_pattern_projections=False  # type: bool
                ):
        # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
        r"""
        Validate the feature dimensions of the inputs and delegate to ``hopfield_core_forward``.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
                See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

            scaling: scaling of association heads, often represented as beta (one entry per head).
            update_steps_max: maximum count of association update steps (None equals to infinity).
            update_steps_eps: minimum difference threshold between two consecutive association update steps.
            return_raw_associations: return raw association (softmax) values, unmodified.
            return_pattern_projections: return pattern projection values, unmodified.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
            - attn_raw: :math:`(N, num_heads, L, S)`, where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if self.query_as_static and self.key_as_static:
            # Neither query nor key is projected, so the association space is given by the inputs.
            assert query.shape[2] == key.shape[2], \
                f'query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal'
            head_dim, embed_dim_to_check = query.shape[2], query.shape[2]
        else:
            # Projected inputs must match the configured input size, static ones the head size.
            assert self.query_as_static or (query.shape[2] == self.embed_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}.'
            assert (not self.query_as_static) or (self.query_as_static and query.shape[2] == self.head_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}'

            assert self.key_as_static or (key.shape[2] == self.kdim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}.'
            assert (not self.key_as_static) or (self.key_as_static and key.shape[2] == self.head_dim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}'
            head_dim, embed_dim_to_check = self.head_dim, self.head_dim if self.query_as_static else self.embed_dim

        assert self.value_as_static or (value.shape[2] == self.vdim), \
            f'value shape[2] of {value.shape[2]} invalid, needs to be {self.vdim}.'
        assert any((
            not self.value_as_static, self.value_as_static and value.shape[2] == self.pattern_dim,
            self.disable_out_projection)
        ), f'value shape[2] of {value.shape[2]} invalid, needs to be {self.pattern_dim}'

        out_weights, out_bias = None, None
        if not self.disable_out_projection:
            out_weights, out_bias = self.out_proj.weight, self.out_proj.bias

        # Keyword arguments shared by the packed and the separate projection paths.
        forward_kwargs = dict(
            query=query, key=key, value=value, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight, in_proj_bias=self.in_proj_bias, bias_k=self.bias_k,
            bias_v=self.bias_v, add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
            out_proj_weight=out_weights, out_proj_bias=out_bias, training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,

            key_as_static=self.key_as_static, query_as_static=self.query_as_static,
            value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
            normalize_pattern=self.normalize_pattern, normalize_pattern_eps=self.normalize_pattern_eps,
            p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
            head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
            update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
            return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
        if not self._qkv_same_embed_dim:
            # Separate per-input projection weights are only passed when no packed weight is used.
            forward_kwargs.update(
                use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight)
        return hopfield_core_forward(**forward_kwargs)
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Attention operator for the diffusion-attention dynamical memory system.
|
|
3
|
+
|
|
4
|
+
Responsibility: Apply scaled-dot-product attention in exactly two modes.
|
|
5
|
+
This is the *only* place that computes attention — no other module does it.
|
|
6
|
+
|
|
7
|
+
Modes
|
|
8
|
+
-----
|
|
9
|
+
dense (default)
|
|
10
|
+
logits = beta * Q @ K.T — (N, N)
|
|
11
|
+
weights = softmax(logits) — (N, N)
|
|
12
|
+
output = weights @ V — (N, d)
|
|
13
|
+
O(N²d) time, O(N²) space.
|
|
14
|
+
Exact match to the Hopfield baseline.
|
|
15
|
+
|
|
16
|
+
graph
|
|
17
|
+
For each query i, attend *only* to its kNN neighbors.
|
|
18
|
+
Requires adj_indices (N, k) LongTensor from ``GraphBuilder``.
|
|
19
|
+
O(kNd) time, O(kN) space.
|
|
20
|
+
Strictly faster than dense when k ≪ N.
|
|
21
|
+
|
|
22
|
+
API
|
|
23
|
+
---
|
|
24
|
+
op = AttentionOperator(beta=10.0, mode="dense")
|
|
25
|
+
out = op(Q, K, V) # dense
|
|
26
|
+
out = op(Q, K, V, adj_indices=adj_idx) # graph
|
|
27
|
+
|
|
28
|
+
Constraints
|
|
29
|
+
-----------
|
|
30
|
+
* Modes are implemented in separate methods (_dense / _graph).
|
|
31
|
+
No conditional logic inside a shared inner loop.
|
|
32
|
+
* Dense mode never uses adj_indices; graph mode never builds N×N logit matrix.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from __future__ import annotations
|
|
36
|
+
|
|
37
|
+
from typing import Optional
|
|
38
|
+
|
|
39
|
+
import torch
|
|
40
|
+
import torch.nn.functional as F
|
|
41
|
+
from torch import Tensor
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class AttentionOperator:
    """
    Dot-product attention scaled by ``beta``, either over all keys or over a
    fixed set of neighbor keys per query.

    Args:
        beta: Multiplicative factor applied to the query-key logits. Default: 1.0.
        mode: ``'dense'`` (every query sees every key) or ``'graph'``
              (every query sees only the keys listed in ``adj_indices``).

    Raises:
        ValueError: if ``mode`` is neither ``'dense'`` nor ``'graph'``.
    """

    def __init__(self, beta: float = 1.0, mode: str = "dense") -> None:
        supported_modes = ("dense", "graph")
        if mode not in supported_modes:
            raise ValueError(
                f"AttentionOperator: mode must be 'dense' or 'graph', got '{mode}'."
            )
        self.mode = mode
        self.beta = beta

    def __call__(
        self,
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        adj_indices: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Run attention in the configured mode.

        Args:
            Q: query patterns, 2D ``(N, d)`` or 3D ``(S, B, d)``.
            K: key patterns, same layout as ``Q``.
            V: value patterns, same layout as ``Q``.
            adj_indices: ``(N, k)`` neighbor indices; mandatory in graph mode,
                         not looked at in dense mode.

        Returns:
            The attended values, laid out like ``Q``.
        """
        if self.mode != "dense":
            return self._graph(Q, K, V, adj_indices)
        return self._dense(Q, K, V)

    # ------------------------------------------------------------------
    # Dense
    # ------------------------------------------------------------------

    def _dense(self, Q: Tensor, K: Tensor, V: Tensor) -> Tensor:
        """
        Full attention: softmax over ``beta * Q K^T`` followed by a weighted
        sum of ``V``.

        Anything that is not 2D is handled as ``(S, B, d)`` and processed with
        a batched matmul over ``B``.
        """
        if Q.dim() != 2:
            # Sequence-first input: bring the batch axis to the front for bmm.
            q_batched = Q.permute(1, 0, 2)
            k_batched = K.permute(1, 0, 2)
            v_batched = V.permute(1, 0, 2)
            scores = self.beta * torch.bmm(q_batched, k_batched.transpose(1, 2))  # (B, S, S)
            attn = scores.softmax(dim=-1)                                          # (B, S, S)
            mixed = torch.bmm(attn, v_batched)                                     # (B, S, d)
            return mixed.permute(1, 0, 2)                                          # (S, B, d)

        scores = self.beta * torch.matmul(Q, K.t())  # (N, N)
        attn = scores.softmax(dim=-1)                # (N, N)
        return torch.matmul(attn, V)                 # (N, d)

    # ------------------------------------------------------------------
    # Graph-constrained
    # ------------------------------------------------------------------

    def _neighbor_mix(self, q_expanded: Tensor, k_nbrs: Tensor, v_nbrs: Tensor, nbr_axis: int) -> Tensor:
        """Softmax over per-neighbor logits, then sum the neighbor values along ``nbr_axis``."""
        scores = self.beta * (q_expanded * k_nbrs).sum(dim=-1)
        attn = scores.softmax(dim=-1)
        return (attn.unsqueeze(-1) * v_nbrs).sum(dim=nbr_axis)

    def _graph(
        self,
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        adj_indices: Optional[Tensor],
    ) -> Tensor:
        """
        Neighbor-restricted attention: each query only scores the keys named in
        its row of ``adj_indices`` and mixes the corresponding values.

        Anything that is not 2D is handled as ``(S, B, d)``; the same neighbor
        table is applied to every batch element.

        Raises:
            ValueError: if ``adj_indices`` is None.
        """
        if adj_indices is None:
            raise ValueError(
                "AttentionOperator(mode='graph') requires adj_indices (N, k)."
            )

        if Q.dim() != 2:
            # Sequence-first input: bring the batch axis to the front before gathering.
            q_batched = Q.permute(1, 0, 2)            # (B, S, d)
            k_batched = K.permute(1, 0, 2)            # (B, S, d)
            v_batched = V.permute(1, 0, 2)            # (B, S, d)
            k_local = k_batched[:, adj_indices]       # (B, S, k, d)
            v_local = v_batched[:, adj_indices]       # (B, S, k, d)
            mixed = self._neighbor_mix(q_batched.unsqueeze(2), k_local, v_local, 2)  # (B, S, d)
            return mixed.permute(1, 0, 2)             # (S, B, d)

        k_local = K[adj_indices]                      # (N, k, d)
        v_local = V[adj_indices]                      # (N, k, d)
        return self._neighbor_mix(Q.unsqueeze(1), k_local, v_local, 1)  # (N, d)
|
|
File without changes
|