difflayers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- difflayers/__init__.py +965 -0
- difflayers/activation.py +339 -0
- difflayers/attention_operator.py +157 -0
- difflayers/auxiliary/__init__.py +0 -0
- difflayers/auxiliary/data.py +252 -0
- difflayers/diffused_attention.py +427 -0
- difflayers/diffusion.py +395 -0
- difflayers/dynamics_engine.py +540 -0
- difflayers/functional.py +459 -0
- difflayers/graph/__init__.py +18 -0
- difflayers/graph/build_graph.py +77 -0
- difflayers/graph/builder.py +120 -0
- difflayers/graph/laplacian.py +76 -0
- difflayers/graph/laplacian_builder.py +64 -0
- difflayers/transformer.py +212 -0
- difflayers-0.1.0.dist-info/METADATA +210 -0
- difflayers-0.1.0.dist-info/RECORD +20 -0
- difflayers-0.1.0.dist-info/WHEEL +5 -0
- difflayers-0.1.0.dist-info/licenses/LICENSE +79 -0
- difflayers-0.1.0.dist-info/top_level.txt +1 -0
difflayers/__init__.py
ADDED
|
@@ -0,0 +1,965 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from math import sqrt
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
from torch.nn import Module, Parameter
|
|
7
|
+
from typing import Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
from .activation import HopfieldCore
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Hopfield(Module):
    """
    Module with underlying Hopfield association.

    Stored patterns, state patterns and pattern projections are optionally layer-normalised and then
    handed to a ``HopfieldCore`` instance as keys, queries and values, respectively.
    """

    def __init__(self,
                 input_size: Optional[int] = None,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_stored_pattern_eps: float = 1e-5,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_state_pattern_eps: float = 1e-5,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_pattern_projection_eps: float = 1e-5,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 normalize_hopfield_space_eps: float = 1e-5,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False
                 ):
        """
        Initialise new instance of a Hopfield module.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_stored_pattern_eps: offset of the denominator for numerical stability
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_state_pattern_eps: offset of the denominator for numerical stability
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_pattern_projection_eps: offset of the denominator for numerical stability
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param normalize_hopfield_space_eps: offset of the denominator for numerical stability
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: name of a callable in the "torch" namespace (e.g. "tanh") to be applied
            on the result of the Hopfield association
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        :raises ValueError: if "association_activation" does not name a callable in the "torch" namespace
        """
        super(Hopfield, self).__init__()
        assert type(batch_first) == bool, f'"batch_first" needs to be a boolean, not {type(batch_first)}.'
        assert (association_activation is None) or (type(association_activation) == str)

        # Initialise Hopfield association module (stored pattern -> key, state pattern -> query,
        # pattern projection -> value).
        self.association_core = HopfieldCore(
            embed_dim=input_size, num_heads=num_heads, dropout=dropout, bias=input_bias,
            add_bias_kv=concat_bias_pattern, add_zero_attn=add_zero_association, kdim=stored_pattern_size,
            vdim=pattern_projection_size, head_dim=hidden_size, pattern_dim=pattern_size, out_dim=output_size,
            disable_out_projection=disable_out_projection, key_as_static=stored_pattern_as_static,
            query_as_static=state_pattern_as_static, value_as_static=pattern_projection_as_static,
            value_as_connected=pattern_projection_as_connected, normalize_pattern=normalize_hopfield_space,
            normalize_pattern_affine=normalize_hopfield_space_affine,
            normalize_pattern_eps=normalize_hopfield_space_eps)

        # Resolve the optional activation by name. An unknown name used to be silently ignored
        # (no activation applied at all), which hides typos; fail loudly instead.
        self.association_activation = None
        if association_activation is not None:
            activation = getattr(torch, association_activation, None)
            if not callable(activation):
                raise ValueError(
                    f'unknown association activation "{association_activation}": '
                    f'expected the name of a callable in the "torch" namespace.')
            self.association_activation = activation

        # Initialise stored pattern normalization.
        self.norm_stored_pattern = None
        if normalize_stored_pattern_affine:
            assert normalize_stored_pattern, "affine normalization without normalization has no effect."
        if normalize_stored_pattern:
            normalized_shape = input_size if stored_pattern_size is None else stored_pattern_size
            assert normalized_shape is not None, "stored pattern size required for setting up normalisation"
            self.norm_stored_pattern = nn.LayerNorm(
                normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine,
                eps=normalize_stored_pattern_eps)

        # Initialise state pattern normalization.
        self.norm_state_pattern = None
        if normalize_state_pattern_affine:
            assert normalize_state_pattern, "affine normalization without normalization has no effect."
        if normalize_state_pattern:
            assert input_size is not None, "input size required for setting up normalisation"
            self.norm_state_pattern = nn.LayerNorm(
                normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine,
                eps=normalize_state_pattern_eps)

        # Initialise pattern projection normalization.
        self.norm_pattern_projection = None
        if normalize_pattern_projection_affine:
            assert normalize_pattern_projection, "affine normalization without normalization has no effect."
        if normalize_pattern_projection:
            normalized_shape = input_size if pattern_projection_size is None else pattern_projection_size
            assert normalized_shape is not None, "pattern projection size required for setting up normalisation"
            self.norm_pattern_projection = nn.LayerNorm(
                normalized_shape=normalized_shape, elementwise_affine=normalize_pattern_projection_affine,
                eps=normalize_pattern_projection_eps)

        # Initialise remaining auxiliary properties. Without an explicit scaling, static execution
        # uses 1.0 and otherwise the usual 1 / sqrt(head_dim).
        if self.association_core.static_execution:
            self.__scaling = 1.0 if scaling is None else scaling
        else:
            assert self.association_core.head_dim > 0, f'invalid hidden dimension encountered.'
            self.__scaling = (1.0 / sqrt(self.association_core.head_dim)) if scaling is None else scaling
        self.__batch_first = batch_first
        self.__update_steps_max = update_steps_max
        self.__update_steps_eps = update_steps_eps
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset Hopfield association.

        Calls "reset_parameters" on the core module and on every normalisation module that exists.

        :return: None
        """
        for module in (self.association_core, self.norm_stored_pattern,
                       self.norm_state_pattern, self.norm_pattern_projection):
            # Disabled normalisations are None and have no "reset_parameters".
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def _maybe_transpose(self, *args: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
        """
        Eventually transpose specified data.

        When "batch_first" is set, the first two dimensions of every tensor are swapped; otherwise the
        tensors are returned unchanged.

        :param args: tensors to eventually transpose (dependent on the state of "batch_first")
        :return: a single tensor if exactly one was given, otherwise a tuple of tensors
        """
        transposed_result = tuple(_.transpose(0, 1) for _ in args) if self.__batch_first else args
        return transposed_result[0] if len(transposed_result) == 1 else transposed_result

    @staticmethod
    def _apply_normalization(norm: Optional[Module], pattern: Tensor) -> Tensor:
        """
        Optionally apply a layer normalisation over the last dimension of a pattern tensor.

        :param norm: normalisation module, or None to skip normalisation
        :param pattern: pattern tensor, assumed to be three-dimensional (indexing uses ``shape[2]``)
        :return: normalised pattern with the same shape as the input
        """
        if norm is None:
            return pattern
        # Flatten the two leading dimensions, normalise, and restore the original shape.
        return norm(input=pattern.reshape(shape=(-1, pattern.shape[2]))).reshape(shape=pattern.shape)

    def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                   return_raw_associations: bool = False, return_projected_patterns: bool = False,
                   stored_pattern_padding_mask: Optional[Tensor] = None,
                   association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]:
        """
        Apply Hopfield association module on specified data.

        :param data: either one tensor used as stored pattern, state pattern and pattern projection, or a
            tuple/list of three tensors (stored pattern, state pattern, pattern projection); tensor subclasses
            such as ``nn.Parameter`` are accepted
        :param return_raw_associations: return raw association (softmax) values, unmodified
        :param return_projected_patterns: return pattern projection values, unmodified
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: result of the Hopfield core module (not transposed back)
        """
        # isinstance instead of an exact type comparison, so tensor subclasses (e.g. nn.Parameter)
        # are not rejected.
        assert isinstance(data, Tensor) or (isinstance(data, (tuple, list)) and (len(data) == 3)), \
            r'either one tensor to be used as "stored pattern", "state pattern" and' \
            r' "pattern_projection" must be provided, or three separate ones.'
        if isinstance(data, Tensor):
            stored_pattern, state_pattern, pattern_projection = data, data, data
        else:
            stored_pattern, state_pattern, pattern_projection = data

        # Optionally transpose data: the core is fed with the first two dimensions swapped whenever
        # "batch_first" is set.
        stored_pattern, state_pattern, pattern_projection = self._maybe_transpose(
            stored_pattern, state_pattern, pattern_projection)

        # Optionally apply the individual normalizations.
        stored_pattern = self._apply_normalization(self.norm_stored_pattern, stored_pattern)
        state_pattern = self._apply_normalization(self.norm_state_pattern, state_pattern)
        pattern_projection = self._apply_normalization(self.norm_pattern_projection, pattern_projection)

        # Apply Hopfield association.
        return self.association_core(
            query=state_pattern, key=stored_pattern, value=pattern_projection,
            key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask,
            scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps,
            return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns)

    def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield association on specified data.

        :param input: data to be processed by Hopfield association module
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: Hopfield-processed input data (in the same layout as the input, see "batch_first")
        """
        # Element 0 of the core's result is the association output; transpose it back to the caller's layout.
        association_output = self._maybe_transpose(self._associate(
            data=input, return_raw_associations=False,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)[0])
        if self.association_activation is not None:
            association_output = self.association_activation(association_output)
        return association_output

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                               stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield association matrix gathered by passing through the specified data.

        Gradients are not tracked.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module (element 2 of its result)
        """
        with torch.no_grad():
            return self._associate(
                data=input, return_raw_associations=True,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)[2]

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield projected pattern matrix gathered by passing through the specified data.

        Gradients are not tracked.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module (element 3 of its result)
        """
        with torch.no_grad():
            return self._associate(
                data=input, return_projected_patterns=True,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)[3]

    @property
    def batch_first(self) -> bool:
        """Whether the first dimension of data fed to "forward" is the batch dimension."""
        return self.__batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        """Scaling (beta) of the association heads; tensor values are returned as a copy."""
        return self.__scaling.clone() if isinstance(self.__scaling, Tensor) else self.__scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        """Depth of the stored pattern as reported by the core module ("kdim")."""
        return self.association_core.kdim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        """Depth of the state pattern as reported by the core module ("embed_dim")."""
        return self.association_core.embed_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        """Depth of the pattern projection as reported by the core module ("vdim")."""
        return self.association_core.vdim

    @property
    def input_size(self) -> Optional[int]:
        """Depth of the input, i.e. of the state pattern."""
        return self.state_pattern_dim

    @property
    def hidden_size(self) -> Optional[int]:
        """Depth of the association space as reported by the core module ("head_dim")."""
        return self.association_core.head_dim

    @property
    def output_size(self) -> Optional[int]:
        """Depth of the output projection as reported by the core module ("out_dim")."""
        return self.association_core.out_dim

    @property
    def pattern_size(self) -> Optional[int]:
        """Depth of patterns to be selected as reported by the core module ("pattern_dim")."""
        return self.association_core.pattern_dim

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        """Maximum count of association update steps; tensor values are returned as a copy."""
        return self.__update_steps_max.clone() if isinstance(self.__update_steps_max, Tensor) \
            else self.__update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        """Minimum difference threshold between two update steps; tensor values are returned as a copy."""
        return self.__update_steps_eps.clone() if isinstance(self.__update_steps_eps, Tensor) \
            else self.__update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        """Whether stored patterns are interpreted as static (core "key_as_static")."""
        return self.association_core.key_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        """Whether state patterns are interpreted as static (core "query_as_static")."""
        return self.association_core.query_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        """Whether pattern projections are interpreted as static (core "value_as_static")."""
        return self.association_core.value_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        """Whether a stored pattern normalisation module exists."""
        return self.norm_stored_pattern is not None

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        """Whether the stored pattern normalisation exists and is affine."""
        return self.normalize_stored_pattern and self.norm_stored_pattern.elementwise_affine

    @property
    def normalize_state_pattern(self) -> bool:
        """Whether a state pattern normalisation module exists."""
        return self.norm_state_pattern is not None

    @property
    def normalize_state_pattern_affine(self) -> bool:
        """Whether the state pattern normalisation exists and is affine."""
        return self.normalize_state_pattern and self.norm_state_pattern.elementwise_affine

    @property
    def normalize_pattern_projection(self) -> bool:
        """Whether a pattern projection normalisation module exists."""
        return self.norm_pattern_projection is not None

    @property
    def normalize_pattern_projection_affine(self) -> bool:
        """Whether the pattern projection normalisation exists and is affine."""
        return self.normalize_pattern_projection and self.norm_pattern_projection.elementwise_affine

    @property
    def normalize_hopfield_space(self) -> bool:
        """Whether patterns are normalised in the Hopfield space (core "normalize_pattern")."""
        return self.association_core.normalize_pattern

    @property
    def normalize_hopfield_space_affine(self) -> bool:
        """Whether Hopfield-space normalisation is affine (core "normalize_pattern_affine")."""
        return self.association_core.normalize_pattern_affine
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class HopfieldPooling(Module):
|
|
364
|
+
"""
|
|
365
|
+
Wrapper class encapsulating a trainable but fixed state pattern and "Hopfield" in
|
|
366
|
+
one combined module to be used as a Hopfield-based pooling layer.
|
|
367
|
+
"""
|
|
368
|
+
|
|
369
|
+
def __init__(self,
             input_size: int,
             hidden_size: Optional[int] = None,
             output_size: Optional[int] = None,
             pattern_size: Optional[int] = None,
             num_heads: int = 1,
             scaling: Optional[Union[float, Tensor]] = None,
             update_steps_max: Optional[Union[int, Tensor]] = 0,
             update_steps_eps: Union[float, Tensor] = 1e-4,

             normalize_stored_pattern: bool = True,
             normalize_stored_pattern_affine: bool = True,
             normalize_state_pattern: bool = True,
             normalize_state_pattern_affine: bool = True,
             normalize_pattern_projection: bool = True,
             normalize_pattern_projection_affine: bool = True,
             normalize_hopfield_space: bool = False,
             normalize_hopfield_space_affine: bool = False,
             stored_pattern_as_static: bool = False,
             state_pattern_as_static: bool = False,
             pattern_projection_as_static: bool = False,
             pattern_projection_as_connected: bool = False,
             stored_pattern_size: Optional[int] = None,
             pattern_projection_size: Optional[int] = None,

             batch_first: bool = True,
             association_activation: Optional[str] = None,
             dropout: float = 0.0,
             input_bias: bool = True,
             concat_bias_pattern: bool = False,
             add_zero_association: bool = False,
             disable_out_projection: bool = False,
             quantity: int = 1,
             trainable: bool = True,
             normalize_stored_pattern_eps: float = 1e-5,
             normalize_state_pattern_eps: float = 1e-5,
             normalize_pattern_projection_eps: float = 1e-5,
             normalize_hopfield_space_eps: float = 1e-5
             ):
    """
    Initialise a new instance of a Hopfield-based pooling layer.

    :param input_size: depth of the input (state pattern)
    :param hidden_size: depth of the association space
    :param output_size: depth of the output projection
    :param pattern_size: depth of patterns to be selected
    :param num_heads: amount of parallel association heads
    :param scaling: scaling of association heads, often represented as beta (one entry per head)
    :param update_steps_max: maximum count of association update steps (None equals to infinity)
    :param update_steps_eps: minimum difference threshold between two consecutive association update steps
    :param normalize_stored_pattern: apply normalization on stored patterns
    :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
    :param normalize_state_pattern: apply normalization on state patterns
    :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
    :param normalize_pattern_projection: apply normalization on the pattern projection
    :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
    :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
    :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
    :param stored_pattern_as_static: interpret specified stored patterns as being static
    :param state_pattern_as_static: interpret specified state patterns as being static
    :param pattern_projection_as_static: interpret specified pattern projections as being static
    :param pattern_projection_as_connected: connect pattern projection with stored pattern
    :param stored_pattern_size: depth of input (stored pattern)
    :param pattern_projection_size: depth of input (pattern projection)
    :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
    :param association_activation: additional activation to be applied on the result of the Hopfield association
    :param dropout: dropout probability applied on the association matrix
    :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
    :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
    :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
    :param disable_out_projection: disable output projection
    :param quantity: amount of state patterns
    :param trainable: state pattern used for pooling is trainable
    :param normalize_stored_pattern_eps: offset of the denominator for numerical stability (stored pattern)
    :param normalize_state_pattern_eps: offset of the denominator for numerical stability (state pattern)
    :param normalize_pattern_projection_eps: offset of the denominator for numerical stability (pattern projection)
    :param normalize_hopfield_space_eps: offset of the denominator for numerical stability (Hopfield space)
    """
    super(HopfieldPooling, self).__init__()
    self.hopfield = Hopfield(
        input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size,
        num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
        normalize_stored_pattern=normalize_stored_pattern,
        normalize_stored_pattern_affine=normalize_stored_pattern_affine,
        normalize_stored_pattern_eps=normalize_stored_pattern_eps,
        normalize_state_pattern=normalize_state_pattern,
        normalize_state_pattern_affine=normalize_state_pattern_affine,
        normalize_state_pattern_eps=normalize_state_pattern_eps,
        normalize_pattern_projection=normalize_pattern_projection,
        normalize_pattern_projection_affine=normalize_pattern_projection_affine,
        normalize_pattern_projection_eps=normalize_pattern_projection_eps,
        normalize_hopfield_space=normalize_hopfield_space,
        normalize_hopfield_space_affine=normalize_hopfield_space_affine,
        normalize_hopfield_space_eps=normalize_hopfield_space_eps,
        stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static,
        pattern_projection_as_static=pattern_projection_as_static,
        pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size,
        pattern_projection_size=pattern_projection_size, batch_first=batch_first,
        association_activation=association_activation, dropout=dropout, input_bias=input_bias,
        concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association,
        disable_out_projection=disable_out_projection)
    self._quantity = quantity

    # A static state pattern is sized like the association space; otherwise like the input.
    # NOTE(review): presumably because static state patterns bypass the input projection — confirm in HopfieldCore.
    pooling_weight_size = self.hopfield.hidden_size if state_pattern_as_static else self.hopfield.input_size

    # Layout (1, quantity, depth) for batch-first data, (quantity, 1, depth) otherwise; the singleton
    # dimension is expanded to the batch size in "_prepare_input".
    self.pooling_weights = nn.Parameter(torch.empty(size=(*(
        (1, quantity) if batch_first else (quantity, 1)
    ), input_size if pooling_weight_size is None else pooling_weight_size)), requires_grad=trainable)
    self.reset_parameters()
|
|
464
|
+
|
|
465
|
+
def reset_parameters(self) -> None:
    """
    Re-initialise the wrapped Hopfield module (when it supports resetting) and the pooling weights.

    The pooling weights are redrawn from a normal distribution with mean 0.0 and standard deviation 0.02.

    :return: None
    """
    wrapped = self.hopfield
    if hasattr(wrapped, r'reset_parameters'):
        wrapped.reset_parameters()

    # Pooling weights are always re-drawn explicitly, after the wrapped module has been reset.
    nn.init.normal_(self.pooling_weights, mean=0.0, std=0.02)
|
|
476
|
+
|
|
477
|
+
def _prepare_input(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Prepare input for Hopfield association.

    The pooling weights are expanded along their singleton batch dimension and serve as state pattern.

    :param input: either one tensor used as stored pattern and pattern projection, or a tuple/list
        (stored pattern, pattern projection); tensor subclasses such as ``nn.Parameter`` are accepted
    :return: stored pattern, expanded state pattern as well as pattern projection
    """
    # isinstance instead of an exact type comparison, so tensor subclasses (e.g. nn.Parameter)
    # are not rejected.
    assert isinstance(input, Tensor) or (isinstance(input, (tuple, list)) and (len(input) == 2)), \
        r'either one tensor to be used as "stored pattern" and' \
        r' "pattern_projection" must be provided, or two separate ones.'
    if isinstance(input, Tensor):
        stored_pattern, pattern_projection = input, input
    else:
        stored_pattern, pattern_projection = input

    batch_size = stored_pattern.shape[0 if self.batch_first else 1]
    leading_shape = (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size)
    state_pattern = self.pooling_weights.expand(size=(*leading_shape, self.pooling_weights.shape[2]))
    return stored_pattern, state_pattern, pattern_projection
|
|
496
|
+
|
|
497
|
+
def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor]], stored_pattern_padding_mask: Optional[Tensor] = None,
            association_mask: Optional[Tensor] = None) -> Tensor:
    """
    Pool the specified data with the wrapped Hopfield module, using the pooling weights as state pattern.

    :param input: data to be pooled (one tensor, or stored pattern and pattern projection)
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: Hopfield-pooled data with every dimension after the first flattened into one
    """
    prepared = self._prepare_input(input=input)
    pooled = self.hopfield(
        input=prepared,
        stored_pattern_padding_mask=stored_pattern_padding_mask,
        association_mask=association_mask)
    return pooled.flatten(start_dim=1)
|
|
511
|
+
|
|
512
|
+
def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                           stored_pattern_padding_mask: Optional[Tensor] = None,
                           association_mask: Optional[Tensor] = None) -> Tensor:
    """
    Compute, without tracking gradients, the association matrix used for pooling the specified data.

    :param input: data to be passed through the Hopfield association
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: association matrix as reported by the wrapped Hopfield module
    """
    with torch.no_grad():
        prepared = self._prepare_input(input=input)
        matrix = self.hopfield.get_association_matrix(
            input=prepared,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
    return matrix
|
|
528
|
+
|
|
529
|
+
def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                                 stored_pattern_padding_mask: Optional[Tensor] = None,
                                 association_mask: Optional[Tensor] = None) -> Tensor:
    """
    Compute, without tracking gradients, the projected pattern matrix for the specified data.

    :param input: data to be passed through the Hopfield association
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: pattern projection matrix as reported by the wrapped Hopfield module
    """
    with torch.no_grad():
        prepared = self._prepare_input(input=input)
        projected = self.hopfield.get_projected_pattern_matrix(
            input=prepared,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
    return projected
|
|
545
|
+
|
|
546
|
+
@property
|
|
547
|
+
def batch_first(self) -> bool:
|
|
548
|
+
return self.hopfield.batch_first
|
|
549
|
+
|
|
550
|
+
@property
|
|
551
|
+
def scaling(self) -> Union[float, Tensor]:
|
|
552
|
+
return self.hopfield.scaling
|
|
553
|
+
|
|
554
|
+
@property
|
|
555
|
+
def stored_pattern_dim(self) -> Optional[int]:
|
|
556
|
+
return self.hopfield.stored_pattern_dim
|
|
557
|
+
|
|
558
|
+
@property
|
|
559
|
+
def state_pattern_dim(self) -> Optional[int]:
|
|
560
|
+
return self.hopfield.state_pattern_dim
|
|
561
|
+
|
|
562
|
+
@property
|
|
563
|
+
def pattern_projection_dim(self) -> Optional[int]:
|
|
564
|
+
return self.hopfield.pattern_projection_dim
|
|
565
|
+
|
|
566
|
+
@property
|
|
567
|
+
def input_size(self) -> Optional[int]:
|
|
568
|
+
return self.hopfield.input_size
|
|
569
|
+
|
|
570
|
+
@property
|
|
571
|
+
def hidden_size(self) -> int:
|
|
572
|
+
return self.hopfield.hidden_size
|
|
573
|
+
|
|
574
|
+
@property
|
|
575
|
+
def output_size(self) -> Optional[int]:
|
|
576
|
+
return self.hopfield.output_size
|
|
577
|
+
|
|
578
|
+
@property
|
|
579
|
+
def pattern_size(self) -> Optional[int]:
|
|
580
|
+
return self.hopfield.pattern_size
|
|
581
|
+
|
|
582
|
+
@property
|
|
583
|
+
def quantity(self) -> int:
|
|
584
|
+
return self._quantity
|
|
585
|
+
|
|
586
|
+
@property
|
|
587
|
+
def update_steps_max(self) -> Optional[Union[int, Tensor]]:
|
|
588
|
+
return self.hopfield.update_steps_max
|
|
589
|
+
|
|
590
|
+
@property
|
|
591
|
+
def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
|
|
592
|
+
return self.hopfield.update_steps_eps
|
|
593
|
+
|
|
594
|
+
@property
|
|
595
|
+
def stored_pattern_as_static(self) -> bool:
|
|
596
|
+
return self.hopfield.stored_pattern_as_static
|
|
597
|
+
|
|
598
|
+
@property
|
|
599
|
+
def state_pattern_as_static(self) -> bool:
|
|
600
|
+
return self.hopfield.state_pattern_as_static
|
|
601
|
+
|
|
602
|
+
@property
|
|
603
|
+
def pattern_projection_as_static(self) -> bool:
|
|
604
|
+
return self.hopfield.pattern_projection_as_static
|
|
605
|
+
|
|
606
|
+
@property
|
|
607
|
+
def normalize_stored_pattern(self) -> bool:
|
|
608
|
+
return self.hopfield.normalize_stored_pattern
|
|
609
|
+
|
|
610
|
+
@property
|
|
611
|
+
def normalize_stored_pattern_affine(self) -> bool:
|
|
612
|
+
return self.hopfield.normalize_stored_pattern_affine
|
|
613
|
+
|
|
614
|
+
@property
|
|
615
|
+
def normalize_state_pattern(self) -> bool:
|
|
616
|
+
return self.hopfield.normalize_state_pattern
|
|
617
|
+
|
|
618
|
+
@property
|
|
619
|
+
def normalize_state_pattern_affine(self) -> bool:
|
|
620
|
+
return self.hopfield.normalize_state_pattern_affine
|
|
621
|
+
|
|
622
|
+
@property
|
|
623
|
+
def normalize_pattern_projection(self) -> bool:
|
|
624
|
+
return self.hopfield.normalize_pattern_projection
|
|
625
|
+
|
|
626
|
+
@property
|
|
627
|
+
def normalize_pattern_projection_affine(self) -> bool:
|
|
628
|
+
return self.hopfield.normalize_pattern_projection_affine
|
|
629
|
+
|
|
630
|
+
@property
|
|
631
|
+
def normalize_hopfield_space(self) -> bool:
|
|
632
|
+
return self.hopfield.normalize_hopfield_space
|
|
633
|
+
|
|
634
|
+
@property
|
|
635
|
+
def normalize_hopfield_space_affine(self) -> bool:
|
|
636
|
+
return self.hopfield.normalize_hopfield_space_affine
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
class HopfieldLayer(Module):
|
|
640
|
+
"""
|
|
641
|
+
Wrapper class encapsulating a trainable but fixed stored pattern, pattern projection and "Hopfield" in
|
|
642
|
+
one combined module to be used as a Hopfield-based pooling layer.
|
|
643
|
+
"""
|
|
644
|
+
|
|
645
|
+
def __init__(self,
|
|
646
|
+
input_size: int,
|
|
647
|
+
hidden_size: Optional[int] = None,
|
|
648
|
+
output_size: Optional[int] = None,
|
|
649
|
+
pattern_size: Optional[int] = None,
|
|
650
|
+
num_heads: int = 1,
|
|
651
|
+
scaling: Optional[Union[float, Tensor]] = None,
|
|
652
|
+
update_steps_max: Optional[Union[int, Tensor]] = 0,
|
|
653
|
+
update_steps_eps: Union[float, Tensor] = 1e-4,
|
|
654
|
+
lookup_weights_as_separated: bool = False,
|
|
655
|
+
lookup_targets_as_trainable: bool = True,
|
|
656
|
+
|
|
657
|
+
normalize_stored_pattern: bool = True,
|
|
658
|
+
normalize_stored_pattern_affine: bool = True,
|
|
659
|
+
normalize_state_pattern: bool = True,
|
|
660
|
+
normalize_state_pattern_affine: bool = True,
|
|
661
|
+
normalize_pattern_projection: bool = True,
|
|
662
|
+
normalize_pattern_projection_affine: bool = True,
|
|
663
|
+
normalize_hopfield_space: bool = False,
|
|
664
|
+
normalize_hopfield_space_affine: bool = False,
|
|
665
|
+
stored_pattern_as_static: bool = False,
|
|
666
|
+
state_pattern_as_static: bool = False,
|
|
667
|
+
pattern_projection_as_static: bool = False,
|
|
668
|
+
pattern_projection_as_connected: bool = False,
|
|
669
|
+
stored_pattern_size: Optional[int] = None,
|
|
670
|
+
pattern_projection_size: Optional[int] = None,
|
|
671
|
+
|
|
672
|
+
batch_first: bool = True,
|
|
673
|
+
association_activation: Optional[str] = None,
|
|
674
|
+
dropout: float = 0.0,
|
|
675
|
+
input_bias: bool = True,
|
|
676
|
+
concat_bias_pattern: bool = False,
|
|
677
|
+
add_zero_association: bool = False,
|
|
678
|
+
disable_out_projection: bool = False,
|
|
679
|
+
quantity: int = 1,
|
|
680
|
+
trainable: bool = True
|
|
681
|
+
):
|
|
682
|
+
"""
|
|
683
|
+
Initialise a new instance of a Hopfield-based lookup layer.
|
|
684
|
+
|
|
685
|
+
:param input_size: depth of the input (state pattern)
|
|
686
|
+
:param hidden_size: depth of the association space
|
|
687
|
+
:param output_size: depth of the output projection
|
|
688
|
+
:param pattern_size: depth of patterns to be selected
|
|
689
|
+
:param num_heads: amount of parallel association heads
|
|
690
|
+
:param scaling: scaling of association heads, often represented as beta (one entry per head)
|
|
691
|
+
:param update_steps_max: maximum count of association update steps (None equals to infinity)
|
|
692
|
+
:param update_steps_eps: minimum difference threshold between two consecutive association update steps
|
|
693
|
+
:param lookup_weights_as_separated: separate lookup weights from lookup target weights
|
|
694
|
+
:param lookup_targets_as_trainable: employ trainable lookup target weights (used as pattern projection input)
|
|
695
|
+
:param normalize_stored_pattern: apply normalization on stored patterns
|
|
696
|
+
:param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
|
|
697
|
+
:param normalize_state_pattern: apply normalization on state patterns
|
|
698
|
+
:param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
|
|
699
|
+
:param normalize_pattern_projection: apply normalization on the pattern projection
|
|
700
|
+
:param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
|
|
701
|
+
:param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
|
|
702
|
+
:param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
|
|
703
|
+
:param stored_pattern_as_static: interpret specified stored patterns as being static
|
|
704
|
+
:param state_pattern_as_static: interpret specified state patterns as being static
|
|
705
|
+
:param pattern_projection_as_static: interpret specified pattern projections as being static
|
|
706
|
+
:param pattern_projection_as_connected: connect pattern projection with stored pattern
|
|
707
|
+
:param stored_pattern_size: depth of input (stored pattern)
|
|
708
|
+
:param pattern_projection_size: depth of input (pattern projection)
|
|
709
|
+
:param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
|
|
710
|
+
:param association_activation: additional activation to be applied on the result of the Hopfield association
|
|
711
|
+
:param dropout: dropout probability applied on the association matrix
|
|
712
|
+
:param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
|
|
713
|
+
:param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
|
|
714
|
+
:param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
|
|
715
|
+
:param disable_out_projection: disable output projection
|
|
716
|
+
:param quantity: amount of stored patterns
|
|
717
|
+
:param trainable: stored pattern used for lookup is trainable
|
|
718
|
+
"""
|
|
719
|
+
super(HopfieldLayer, self).__init__()
|
|
720
|
+
self.hopfield = Hopfield(
|
|
721
|
+
input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size,
|
|
722
|
+
num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
|
|
723
|
+
normalize_stored_pattern=normalize_stored_pattern,
|
|
724
|
+
normalize_stored_pattern_affine=normalize_stored_pattern_affine,
|
|
725
|
+
normalize_state_pattern=normalize_state_pattern,
|
|
726
|
+
normalize_state_pattern_affine=normalize_state_pattern_affine,
|
|
727
|
+
normalize_pattern_projection=normalize_pattern_projection,
|
|
728
|
+
normalize_pattern_projection_affine=normalize_pattern_projection_affine,
|
|
729
|
+
normalize_hopfield_space=normalize_hopfield_space,
|
|
730
|
+
normalize_hopfield_space_affine=normalize_hopfield_space_affine,
|
|
731
|
+
stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static,
|
|
732
|
+
pattern_projection_as_static=pattern_projection_as_static,
|
|
733
|
+
pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size,
|
|
734
|
+
pattern_projection_size=pattern_projection_size, batch_first=batch_first,
|
|
735
|
+
association_activation=association_activation, dropout=dropout, input_bias=input_bias,
|
|
736
|
+
concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association,
|
|
737
|
+
disable_out_projection=disable_out_projection)
|
|
738
|
+
self._quantity = quantity
|
|
739
|
+
lookup_weight_size = self.hopfield.hidden_size if stored_pattern_as_static else self.hopfield.stored_pattern_dim
|
|
740
|
+
self.lookup_weights = nn.Parameter(torch.empty(size=(*(
|
|
741
|
+
(1, quantity) if batch_first else (quantity, 1)
|
|
742
|
+
), input_size if lookup_weight_size is None else lookup_weight_size)), requires_grad=trainable)
|
|
743
|
+
|
|
744
|
+
if lookup_weights_as_separated:
|
|
745
|
+
target_weight_size = self.lookup_weights.shape[
|
|
746
|
+
2] if pattern_projection_size is None else pattern_projection_size
|
|
747
|
+
self.target_weights = nn.Parameter(torch.empty(size=(*(
|
|
748
|
+
(1, quantity) if batch_first else (quantity, 1)
|
|
749
|
+
), target_weight_size)), requires_grad=lookup_targets_as_trainable)
|
|
750
|
+
else:
|
|
751
|
+
self.register_parameter(name=r'target_weights', param=None)
|
|
752
|
+
self.reset_parameters()
|
|
753
|
+
|
|
754
|
+
def reset_parameters(self) -> None:
|
|
755
|
+
"""
|
|
756
|
+
Reset lookup and lookup target weights, including underlying Hopfield association.
|
|
757
|
+
|
|
758
|
+
:return: None
|
|
759
|
+
"""
|
|
760
|
+
if hasattr(self.hopfield, r'reset_parameters'):
|
|
761
|
+
self.hopfield.reset_parameters()
|
|
762
|
+
|
|
763
|
+
# Explicitly initialise lookup and target weights.
|
|
764
|
+
nn.init.normal_(self.lookup_weights, mean=0.0, std=0.02)
|
|
765
|
+
if self.target_weights is not None:
|
|
766
|
+
nn.init.normal_(self.target_weights, mean=0.0, std=0.02)
|
|
767
|
+
|
|
768
|
+
def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
|
769
|
+
"""
|
|
770
|
+
Prepare input for Hopfield association.
|
|
771
|
+
|
|
772
|
+
:param input: data to be prepared
|
|
773
|
+
:return: stored pattern, expanded state pattern as well as pattern projection
|
|
774
|
+
"""
|
|
775
|
+
batch_size = input.shape[0 if self.batch_first else 1]
|
|
776
|
+
stored_pattern = self.lookup_weights.expand(size=(*(
|
|
777
|
+
(batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size)
|
|
778
|
+
), self.lookup_weights.shape[2]))
|
|
779
|
+
if self.target_weights is None:
|
|
780
|
+
pattern_projection = stored_pattern
|
|
781
|
+
else:
|
|
782
|
+
pattern_projection = self.target_weights.expand(size=(*(
|
|
783
|
+
(batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size)
|
|
784
|
+
), self.target_weights.shape[2]))
|
|
785
|
+
|
|
786
|
+
return stored_pattern, input, pattern_projection
|
|
787
|
+
|
|
788
|
+
def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
|
|
789
|
+
association_mask: Optional[Tensor] = None) -> Tensor:
|
|
790
|
+
"""
|
|
791
|
+
Compute Hopfield-based lookup on specified data.
|
|
792
|
+
|
|
793
|
+
:param input: data to used in lookup
|
|
794
|
+
:param stored_pattern_padding_mask: mask to be applied on stored patterns
|
|
795
|
+
:param association_mask: mask to be applied on inner association matrix
|
|
796
|
+
:return: result of Hopfield-based lookup on input data
|
|
797
|
+
"""
|
|
798
|
+
return self.hopfield(
|
|
799
|
+
input=self._prepare_input(input=input),
|
|
800
|
+
stored_pattern_padding_mask=stored_pattern_padding_mask,
|
|
801
|
+
association_mask=association_mask)
|
|
802
|
+
|
|
803
|
+
def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
|
|
804
|
+
association_mask: Optional[Tensor] = None) -> Tensor:
|
|
805
|
+
"""
|
|
806
|
+
Fetch Hopfield association matrix used for lookup gathered by passing through the specified data.
|
|
807
|
+
|
|
808
|
+
:param input: data to be passed through the Hopfield association
|
|
809
|
+
:param stored_pattern_padding_mask: mask to be applied on stored patterns
|
|
810
|
+
:param association_mask: mask to be applied on inner association matrix
|
|
811
|
+
:return: association matrix as computed by the Hopfield core module
|
|
812
|
+
"""
|
|
813
|
+
with torch.no_grad():
|
|
814
|
+
return self.hopfield.get_association_matrix(
|
|
815
|
+
input=self._prepare_input(input=input),
|
|
816
|
+
stored_pattern_padding_mask=stored_pattern_padding_mask,
|
|
817
|
+
association_mask=association_mask)
|
|
818
|
+
|
|
819
|
+
def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
820
|
+
stored_pattern_padding_mask: Optional[Tensor] = None,
|
|
821
|
+
association_mask: Optional[Tensor] = None) -> Tensor:
|
|
822
|
+
"""
|
|
823
|
+
Fetch Hopfield projected pattern matrix gathered by passing through the specified data.
|
|
824
|
+
|
|
825
|
+
:param input: data to be passed through the Hopfield association
|
|
826
|
+
:param stored_pattern_padding_mask: mask to be applied on stored patterns
|
|
827
|
+
:param association_mask: mask to be applied on inner association matrix
|
|
828
|
+
:return: pattern projection matrix as computed by the Hopfield core module
|
|
829
|
+
"""
|
|
830
|
+
with torch.no_grad():
|
|
831
|
+
return self.hopfield.get_projected_pattern_matrix(
|
|
832
|
+
input=self._prepare_input(input=input),
|
|
833
|
+
stored_pattern_padding_mask=stored_pattern_padding_mask,
|
|
834
|
+
association_mask=association_mask)
|
|
835
|
+
|
|
836
|
+
@property
|
|
837
|
+
def batch_first(self) -> bool:
|
|
838
|
+
return self.hopfield.batch_first
|
|
839
|
+
|
|
840
|
+
@property
|
|
841
|
+
def scaling(self) -> Union[float, Tensor]:
|
|
842
|
+
return self.hopfield.scaling
|
|
843
|
+
|
|
844
|
+
@property
|
|
845
|
+
def stored_pattern_dim(self) -> Optional[int]:
|
|
846
|
+
return self.hopfield.stored_pattern_dim
|
|
847
|
+
|
|
848
|
+
@property
|
|
849
|
+
def state_pattern_dim(self) -> Optional[int]:
|
|
850
|
+
return self.hopfield.state_pattern_dim
|
|
851
|
+
|
|
852
|
+
@property
|
|
853
|
+
def pattern_projection_dim(self) -> Optional[int]:
|
|
854
|
+
return self.hopfield.pattern_projection_dim
|
|
855
|
+
|
|
856
|
+
@property
|
|
857
|
+
def input_size(self) -> Optional[int]:
|
|
858
|
+
return self.hopfield.input_size
|
|
859
|
+
|
|
860
|
+
@property
|
|
861
|
+
def hidden_size(self) -> int:
|
|
862
|
+
return self.hopfield.hidden_size
|
|
863
|
+
|
|
864
|
+
@property
|
|
865
|
+
def output_size(self) -> Optional[int]:
|
|
866
|
+
return self.hopfield.output_size
|
|
867
|
+
|
|
868
|
+
@property
|
|
869
|
+
def pattern_size(self) -> Optional[int]:
|
|
870
|
+
return self.hopfield.pattern_size
|
|
871
|
+
|
|
872
|
+
@property
|
|
873
|
+
def quantity(self) -> int:
|
|
874
|
+
return self._quantity
|
|
875
|
+
|
|
876
|
+
@property
|
|
877
|
+
def update_steps_max(self) -> Optional[Union[int, Tensor]]:
|
|
878
|
+
return self.hopfield.update_steps_max
|
|
879
|
+
|
|
880
|
+
@property
|
|
881
|
+
def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
|
|
882
|
+
return self.hopfield.update_steps_eps
|
|
883
|
+
|
|
884
|
+
@property
|
|
885
|
+
def stored_pattern_as_static(self) -> bool:
|
|
886
|
+
return self.hopfield.stored_pattern_as_static
|
|
887
|
+
|
|
888
|
+
@property
|
|
889
|
+
def state_pattern_as_static(self) -> bool:
|
|
890
|
+
return self.hopfield.state_pattern_as_static
|
|
891
|
+
|
|
892
|
+
@property
|
|
893
|
+
def pattern_projection_as_static(self) -> bool:
|
|
894
|
+
return self.hopfield.pattern_projection_as_static
|
|
895
|
+
|
|
896
|
+
@property
|
|
897
|
+
def normalize_stored_pattern(self) -> bool:
|
|
898
|
+
return self.hopfield.normalize_stored_pattern
|
|
899
|
+
|
|
900
|
+
@property
|
|
901
|
+
def normalize_stored_pattern_affine(self) -> bool:
|
|
902
|
+
return self.hopfield.normalize_stored_pattern_affine
|
|
903
|
+
|
|
904
|
+
@property
|
|
905
|
+
def normalize_state_pattern(self) -> bool:
|
|
906
|
+
return self.hopfield.normalize_state_pattern
|
|
907
|
+
|
|
908
|
+
@property
|
|
909
|
+
def normalize_state_pattern_affine(self) -> bool:
|
|
910
|
+
return self.hopfield.normalize_state_pattern_affine
|
|
911
|
+
|
|
912
|
+
@property
|
|
913
|
+
def normalize_pattern_projection(self) -> bool:
|
|
914
|
+
return self.hopfield.normalize_pattern_projection
|
|
915
|
+
|
|
916
|
+
@property
|
|
917
|
+
def normalize_pattern_projection_affine(self) -> bool:
|
|
918
|
+
return self.hopfield.normalize_pattern_projection_affine
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
# Diffusion-augmented variant -- imported *after* Hopfield is fully defined to
|
|
922
|
+
# avoid a circular import (diffused_attention.py does 'from . import Hopfield').
|
|
923
|
+
from .diffused_attention import DiffusedHopfield # noqa: E402
|
|
924
|
+
from .attention_operator import AttentionOperator # noqa: E402
|
|
925
|
+
from .diffusion import ( # noqa: E402
|
|
926
|
+
DiffusionOperator,
|
|
927
|
+
SimpleDiffusion,
|
|
928
|
+
IterativeDiffusion,
|
|
929
|
+
SpectralDiffusion,
|
|
930
|
+
FactoredDiffusion,
|
|
931
|
+
apply_diffusion,
|
|
932
|
+
)
|
|
933
|
+
from .dynamics_engine import ( # noqa: E402
|
|
934
|
+
DiffusionConfig,
|
|
935
|
+
GraphCache,
|
|
936
|
+
EnergyTracker,
|
|
937
|
+
DynamicsEngine,
|
|
938
|
+
)
|
|
939
|
+
from .graph import GraphBuilder, LaplacianBuilder # noqa: E402
|
|
940
|
+
|
|
941
|
+
__all__ = [
|
|
942
|
+
# Original Hopfield modules
|
|
943
|
+
"Hopfield",
|
|
944
|
+
"HopfieldPooling",
|
|
945
|
+
"HopfieldLayer",
|
|
946
|
+
"HopfieldCore",
|
|
947
|
+
# DAHN — main model
|
|
948
|
+
"DiffusedHopfield",
|
|
949
|
+
# DAHN — diffusion operators
|
|
950
|
+
"DiffusionOperator",
|
|
951
|
+
"SimpleDiffusion",
|
|
952
|
+
"IterativeDiffusion",
|
|
953
|
+
"SpectralDiffusion",
|
|
954
|
+
"FactoredDiffusion",
|
|
955
|
+
"apply_diffusion",
|
|
956
|
+
# DAHN — graph utilities
|
|
957
|
+
"GraphBuilder",
|
|
958
|
+
"LaplacianBuilder",
|
|
959
|
+
# DAHN — runtime components
|
|
960
|
+
"DiffusionConfig",
|
|
961
|
+
"GraphCache",
|
|
962
|
+
"EnergyTracker",
|
|
963
|
+
"DynamicsEngine",
|
|
964
|
+
"AttentionOperator",
|
|
965
|
+
]
|