difflayers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
difflayers/__init__.py ADDED
@@ -0,0 +1,965 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from math import sqrt
5
+ from torch import Tensor
6
+ from torch.nn import Module, Parameter
7
+ from typing import Optional, Tuple, Union
8
+
9
+ from .activation import HopfieldCore
10
+
11
+
12
class Hopfield(Module):
    """
    Module with underlying Hopfield association.

    Combines a "HopfieldCore" association module with optional layer normalization of the stored pattern,
    the state pattern and the pattern projection, plus an optional activation applied on the result.
    """

    def __init__(self,
                 input_size: Optional[int] = None,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_stored_pattern_eps: float = 1e-5,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_state_pattern_eps: float = 1e-5,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_pattern_projection_eps: float = 1e-5,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 normalize_hopfield_space_eps: float = 1e-5,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False
                 ):
        """
        Initialise new instance of a Hopfield module.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_stored_pattern_eps: offset of the denominator for numerical stability
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_state_pattern_eps: offset of the denominator for numerical stability
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_pattern_projection_eps: offset of the denominator for numerical stability
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param normalize_hopfield_space_eps: offset of the denominator for numerical stability
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: name of a callable in the "torch" namespace to be applied on the result of
            the Hopfield association (an unknown name triggers a warning and disables the activation)
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        """
        super().__init__()
        assert isinstance(batch_first, bool), f'"batch_first" needs to be a boolean, not {type(batch_first)}.'
        assert (association_activation is None) or isinstance(association_activation, str)

        # Initialise Hopfield association module.
        self.association_core = HopfieldCore(
            embed_dim=input_size, num_heads=num_heads, dropout=dropout, bias=input_bias,
            add_bias_kv=concat_bias_pattern, add_zero_attn=add_zero_association, kdim=stored_pattern_size,
            vdim=pattern_projection_size, head_dim=hidden_size, pattern_dim=pattern_size, out_dim=output_size,
            disable_out_projection=disable_out_projection, key_as_static=stored_pattern_as_static,
            query_as_static=state_pattern_as_static, value_as_static=pattern_projection_as_static,
            value_as_connected=pattern_projection_as_connected, normalize_pattern=normalize_hopfield_space,
            normalize_pattern_affine=normalize_hopfield_space_affine,
            normalize_pattern_eps=normalize_hopfield_space_eps)

        # Resolve the optional activation by name from the "torch" namespace.
        self.association_activation = None
        if association_activation is not None:
            self.association_activation = getattr(torch, association_activation, None)
            if self.association_activation is None:
                # Previously an unknown name was dropped without any notice; keep the module usable
                # (no activation is applied) but make the misconfiguration visible.
                import warnings
                warnings.warn(
                    f'unknown association activation "{association_activation}"; no activation will be applied.',
                    stacklevel=2)

        # Initialise stored pattern normalization (falls back to the input size if no dedicated size is given).
        self.norm_stored_pattern = self._build_norm(
            enabled=normalize_stored_pattern, affine=normalize_stored_pattern_affine,
            eps=normalize_stored_pattern_eps,
            size=input_size if stored_pattern_size is None else stored_pattern_size,
            missing_size_message="stored pattern size required for setting up normalisation")

        # Initialise state pattern normalization.
        self.norm_state_pattern = self._build_norm(
            enabled=normalize_state_pattern, affine=normalize_state_pattern_affine,
            eps=normalize_state_pattern_eps, size=input_size,
            missing_size_message="input size required for setting up normalisation")

        # Initialise pattern projection normalization (falls back to the input size if no dedicated size is given).
        self.norm_pattern_projection = self._build_norm(
            enabled=normalize_pattern_projection, affine=normalize_pattern_projection_affine,
            eps=normalize_pattern_projection_eps,
            size=input_size if pattern_projection_size is None else pattern_projection_size,
            missing_size_message="pattern projection size required for setting up normalisation")

        # Initialise remaining auxiliary properties.
        if self.association_core.static_execution:
            self.__scaling = 1.0 if scaling is None else scaling
        else:
            assert self.association_core.head_dim > 0, f'invalid hidden dimension encountered.'
            # Default scaling is the inverse square root of the association space depth.
            self.__scaling = (1.0 / sqrt(self.association_core.head_dim)) if scaling is None else scaling
        self.__batch_first = batch_first
        self.__update_steps_max = update_steps_max
        self.__update_steps_eps = update_steps_eps
        self.reset_parameters()

    @staticmethod
    def _build_norm(enabled: bool, affine: bool, eps: float, size: Optional[int],
                    missing_size_message: str) -> Optional[nn.LayerNorm]:
        """
        Create the layer normalization module of one input, or None if normalization is disabled.

        :param enabled: flag for specifying if normalization is to be applied at all
        :param affine: flag for specifying if the normalization has element-wise affine parameters
        :param eps: offset of the denominator for numerical stability
        :param size: depth of the data to be normalized
        :param missing_size_message: assertion message used if normalization is enabled but "size" is None
        :return: layer normalization module or None
        """
        if affine:
            assert enabled, "affine normalization without normalization has no effect."
        if not enabled:
            return None
        assert size is not None, missing_size_message
        return nn.LayerNorm(normalized_shape=size, elementwise_affine=affine, eps=eps)

    @staticmethod
    def _apply_norm(norm: Optional[Module], pattern: Tensor) -> Tensor:
        """
        Apply an optional normalization module on the specified pattern.

        :param norm: normalization module to be applied (None leaves the pattern untouched)
        :param pattern: data to be normalized, indexed up to its third dimension
        :return: (eventually normalized) pattern with unchanged shape
        """
        if norm is None:
            return pattern
        # Collapse all leading dimensions, normalize along the third one and restore the original shape.
        flattened = pattern.reshape(shape=(-1, pattern.shape[2]))
        return norm(input=flattened).reshape(shape=pattern.shape)

    def reset_parameters(self) -> None:
        """
        Reset Hopfield association.

        :return: None
        """
        candidates = (self.association_core, self.norm_stored_pattern,
                      self.norm_state_pattern, self.norm_pattern_projection)
        for module in candidates:
            # Disabled normalizations are None and hence skipped by the attribute check.
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def _maybe_transpose(self, *args: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
        """
        Eventually transpose specified data.

        :param args: tensors to eventually transpose (dependent on the state of "batch_first")
        :return: eventually transposed tensors (a single tensor if exactly one was specified)
        """
        transposed_result = tuple(_.transpose(0, 1) for _ in args) if self.__batch_first else args
        return transposed_result[0] if len(transposed_result) == 1 else transposed_result

    def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                   return_raw_associations: bool = False, return_projected_patterns: bool = False,
                   stored_pattern_padding_mask: Optional[Tensor] = None,
                   association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]:
        """
        Apply Hopfield association module on specified data.

        :param data: data to be processed by Hopfield core module, either a single tensor (used as stored
            pattern, state pattern and pattern projection) or a tuple of these three (in that order)
        :param return_raw_associations: return raw association (softmax) values, unmodified
        :param return_projected_patterns: return pattern projection values, unmodified
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: result tuple of the Hopfield core module
        """
        # "isinstance" (instead of exact type equality) also admits parameters and other tensor subclasses.
        is_single_tensor = isinstance(data, Tensor)
        assert is_single_tensor or (isinstance(data, tuple) and (len(data) == 3)), \
            r'either one tensor to be used as "stored pattern", "state pattern" and' \
            r' "pattern_projection" must be provided, or three separate ones.'
        if is_single_tensor:
            stored_pattern, state_pattern, pattern_projection = data, data, data
        else:
            stored_pattern, state_pattern, pattern_projection = data

        # Optionally transpose data.
        stored_pattern, state_pattern, pattern_projection = self._maybe_transpose(
            stored_pattern, state_pattern, pattern_projection)

        # Optionally apply the respective normalizations.
        stored_pattern = self._apply_norm(self.norm_stored_pattern, stored_pattern)
        state_pattern = self._apply_norm(self.norm_state_pattern, state_pattern)
        pattern_projection = self._apply_norm(self.norm_pattern_projection, pattern_projection)

        # Apply Hopfield association.
        return self.association_core(
            query=state_pattern, key=stored_pattern, value=pattern_projection,
            key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask,
            scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps,
            return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns)

    def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield association on specified data.

        :param input: data to be processed by Hopfield association module
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: Hopfield-processed input data
        """
        # The first entry of the core result is the association output.
        association_output = self._maybe_transpose(self._associate(
            data=input, return_raw_associations=False,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)[0])
        if self.association_activation is not None:
            association_output = self.association_activation(association_output)
        return association_output

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                               stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            # The third entry of the core result holds the raw associations.
            return self._associate(
                data=input, return_raw_associations=True,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)[2]

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield projected pattern matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            # The fourth entry of the core result holds the projected patterns.
            return self._associate(
                data=input, return_projected_patterns=True,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)[3]

    @property
    def batch_first(self) -> bool:
        """Flag specifying if the first dimension of data fed to "forward" reflects the batch size."""
        return self.__batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        """Scaling of the association heads (tensors, including subclasses, are returned as a copy)."""
        return self.__scaling.clone() if isinstance(self.__scaling, Tensor) else self.__scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        """Depth of the stored pattern input as reported by the association core ("kdim")."""
        return self.association_core.kdim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        """Depth of the state pattern input as reported by the association core ("embed_dim")."""
        return self.association_core.embed_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        """Depth of the pattern projection input as reported by the association core ("vdim")."""
        return self.association_core.vdim

    @property
    def input_size(self) -> Optional[int]:
        """Alias of "state_pattern_dim"."""
        return self.state_pattern_dim

    @property
    def hidden_size(self) -> Optional[int]:
        """Depth of the association space as reported by the association core ("head_dim")."""
        return self.association_core.head_dim

    @property
    def output_size(self) -> Optional[int]:
        """Depth of the output projection as reported by the association core ("out_dim")."""
        return self.association_core.out_dim

    @property
    def pattern_size(self) -> Optional[int]:
        """Depth of the patterns to be selected as reported by the association core ("pattern_dim")."""
        return self.association_core.pattern_dim

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        """Maximum count of association update steps (tensors are returned as a copy)."""
        return self.__update_steps_max.clone() if isinstance(self.__update_steps_max, Tensor) \
            else self.__update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        """Minimum difference threshold between two update steps (tensors are returned as a copy)."""
        return self.__update_steps_eps.clone() if isinstance(self.__update_steps_eps, Tensor) \
            else self.__update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        """Flag of the association core specifying if stored patterns are treated as static."""
        return self.association_core.key_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        """Flag of the association core specifying if state patterns are treated as static."""
        return self.association_core.query_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        """Flag of the association core specifying if pattern projections are treated as static."""
        return self.association_core.value_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        """True if a stored pattern normalization module is present."""
        return self.norm_stored_pattern is not None

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        """True if the stored pattern normalization is present and element-wise affine."""
        return self.normalize_stored_pattern and self.norm_stored_pattern.elementwise_affine

    @property
    def normalize_state_pattern(self) -> bool:
        """True if a state pattern normalization module is present."""
        return self.norm_state_pattern is not None

    @property
    def normalize_state_pattern_affine(self) -> bool:
        """True if the state pattern normalization is present and element-wise affine."""
        return self.normalize_state_pattern and self.norm_state_pattern.elementwise_affine

    @property
    def normalize_pattern_projection(self) -> bool:
        """True if a pattern projection normalization module is present."""
        return self.norm_pattern_projection is not None

    @property
    def normalize_pattern_projection_affine(self) -> bool:
        """True if the pattern projection normalization is present and element-wise affine."""
        return self.normalize_pattern_projection and self.norm_pattern_projection.elementwise_affine

    @property
    def normalize_hopfield_space(self) -> bool:
        """Flag of the association core specifying if patterns are normalized in the Hopfield space."""
        return self.association_core.normalize_pattern

    @property
    def normalize_hopfield_space_affine(self) -> bool:
        """Flag of the association core specifying if the Hopfield space normalization is affine."""
        return self.association_core.normalize_pattern_affine
361
+
362
+
363
class HopfieldPooling(Module):
    """
    Wrapper class encapsulating a trainable but fixed state pattern and "Hopfield" in
    one combined module to be used as a Hopfield-based pooling layer.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,
                 quantity: int = 1,
                 trainable: bool = True
                 ):
        """
        Initialise a new instance of a Hopfield-based pooling layer.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: additional activation to be applied on the result of the Hopfield association
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        :param quantity: amount of state patterns
        :param trainable: state pattern used for pooling is trainable
        """
        super().__init__()
        self.hopfield = Hopfield(
            input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size,
            num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            normalize_pattern_projection=normalize_pattern_projection,
            normalize_pattern_projection_affine=normalize_pattern_projection_affine,
            normalize_hopfield_space=normalize_hopfield_space,
            normalize_hopfield_space_affine=normalize_hopfield_space_affine,
            stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static,
            pattern_projection_as_static=pattern_projection_as_static,
            pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size,
            pattern_projection_size=pattern_projection_size, batch_first=batch_first,
            association_activation=association_activation, dropout=dropout, input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection)
        self._quantity = quantity

        # A static state pattern lives in the association space, otherwise in the input space.
        pooling_weight_size = self.hopfield.hidden_size if state_pattern_as_static else self.hopfield.input_size
        if pooling_weight_size is None:
            pooling_weight_size = input_size

        # The singleton dimension is the batch dimension, expanded on demand in "_prepare_input".
        leading_dims = (1, quantity) if batch_first else (quantity, 1)
        self.pooling_weights = nn.Parameter(
            torch.empty(size=(*leading_dims, pooling_weight_size)), requires_grad=trainable)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset pooling weights and underlying Hopfield association.

        :return: None
        """
        if hasattr(self.hopfield, r'reset_parameters'):
            self.hopfield.reset_parameters()

        # Explicitly initialise pooling weights.
        nn.init.normal_(self.pooling_weights, mean=0.0, std=0.02)

    def _prepare_input(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Prepare input for Hopfield association.

        :param input: data to be prepared, either a single tensor (used as stored pattern and pattern
            projection) or a tuple of these two (in that order)
        :return: stored pattern, expanded state pattern as well as pattern projection
        """
        # "isinstance" (instead of exact type equality) also admits parameters and other tensor subclasses.
        is_single_tensor = isinstance(input, Tensor)
        assert is_single_tensor or (isinstance(input, tuple) and (len(input) == 2)), \
            r'either one tensor to be used as "stored pattern" and' \
            r' "pattern_projection" must be provided, or two separate ones.'
        if is_single_tensor:
            stored_pattern, pattern_projection = input, input
        else:
            stored_pattern, pattern_projection = input

        # Broadcast the pooling weights along the batch dimension of the stored pattern.
        batch_first = self.batch_first
        batch_size = stored_pattern.shape[0 if batch_first else 1]
        leading_dims = (batch_size, self.quantity) if batch_first else (self.quantity, batch_size)
        state_pattern = self.pooling_weights.expand(size=(*leading_dims, self.pooling_weights.shape[2]))
        return stored_pattern, state_pattern, pattern_projection

    def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor]], stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Compute Hopfield-based pooling on specified data.

        :param input: data to be pooled
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: Hopfield-pooled input data, flattened from the second dimension onwards
        """
        # NOTE(review): with "batch_first=False" the first dimension presumably is the quantity of state
        # patterns, so flattening from dimension 1 would merge the batch dimension -- TODO confirm intent.
        pooled = self.hopfield(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
        return pooled.flatten(start_dim=1)

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                               stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield association matrix used for pooling gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            return self.hopfield.get_association_matrix(
                input=self._prepare_input(input=input),
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield projected pattern matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            return self.hopfield.get_projected_pattern_matrix(
                input=self._prepare_input(input=input),
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    # All properties below, except "quantity", forward to the wrapped "Hopfield" module.

    @property
    def batch_first(self) -> bool:
        """Flag specifying if the first dimension of data fed to "forward" reflects the batch size."""
        return self.hopfield.batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        """Scaling of the association heads of the wrapped Hopfield module."""
        return self.hopfield.scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        """Depth of the stored pattern input of the wrapped Hopfield module."""
        return self.hopfield.stored_pattern_dim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        """Depth of the state pattern input of the wrapped Hopfield module."""
        return self.hopfield.state_pattern_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        """Depth of the pattern projection input of the wrapped Hopfield module."""
        return self.hopfield.pattern_projection_dim

    @property
    def input_size(self) -> Optional[int]:
        """Input size of the wrapped Hopfield module."""
        return self.hopfield.input_size

    @property
    def hidden_size(self) -> Optional[int]:
        """Hidden size of the wrapped Hopfield module."""
        return self.hopfield.hidden_size

    @property
    def output_size(self) -> Optional[int]:
        """Output size of the wrapped Hopfield module."""
        return self.hopfield.output_size

    @property
    def pattern_size(self) -> Optional[int]:
        """Pattern size of the wrapped Hopfield module."""
        return self.hopfield.pattern_size

    @property
    def quantity(self) -> int:
        """Amount of state patterns used for pooling."""
        return self._quantity

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        """Maximum count of association update steps of the wrapped Hopfield module."""
        return self.hopfield.update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        """Minimum difference threshold between two update steps of the wrapped Hopfield module."""
        return self.hopfield.update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        """Flag specifying if stored patterns are treated as static."""
        return self.hopfield.stored_pattern_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        """Flag specifying if state patterns are treated as static."""
        return self.hopfield.state_pattern_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        """Flag specifying if pattern projections are treated as static."""
        return self.hopfield.pattern_projection_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        """Flag specifying if stored patterns are normalized."""
        return self.hopfield.normalize_stored_pattern

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        """Flag specifying if the stored pattern normalization is affine."""
        return self.hopfield.normalize_stored_pattern_affine

    @property
    def normalize_state_pattern(self) -> bool:
        """Flag specifying if state patterns are normalized."""
        return self.hopfield.normalize_state_pattern

    @property
    def normalize_state_pattern_affine(self) -> bool:
        """Flag specifying if the state pattern normalization is affine."""
        return self.hopfield.normalize_state_pattern_affine

    @property
    def normalize_pattern_projection(self) -> bool:
        """Flag specifying if pattern projections are normalized."""
        return self.hopfield.normalize_pattern_projection

    @property
    def normalize_pattern_projection_affine(self) -> bool:
        """Flag specifying if the pattern projection normalization is affine."""
        return self.hopfield.normalize_pattern_projection_affine

    @property
    def normalize_hopfield_space(self) -> bool:
        """Flag specifying if patterns are normalized in the Hopfield space."""
        return self.hopfield.normalize_hopfield_space

    @property
    def normalize_hopfield_space_affine(self) -> bool:
        """Flag specifying if the Hopfield space normalization is affine."""
        return self.hopfield.normalize_hopfield_space_affine
637
+
638
+
639
+ class HopfieldLayer(Module):
640
+ """
641
+ Wrapper class encapsulating a trainable but fixed stored pattern, pattern projection and "Hopfield" in
642
+ one combined module to be used as a Hopfield-based lookup layer.
643
+ """
644
+
645
+ def __init__(self,
646
+ input_size: int,
647
+ hidden_size: Optional[int] = None,
648
+ output_size: Optional[int] = None,
649
+ pattern_size: Optional[int] = None,
650
+ num_heads: int = 1,
651
+ scaling: Optional[Union[float, Tensor]] = None,
652
+ update_steps_max: Optional[Union[int, Tensor]] = 0,
653
+ update_steps_eps: Union[float, Tensor] = 1e-4,
654
+ lookup_weights_as_separated: bool = False,
655
+ lookup_targets_as_trainable: bool = True,
656
+
657
+ normalize_stored_pattern: bool = True,
658
+ normalize_stored_pattern_affine: bool = True,
659
+ normalize_state_pattern: bool = True,
660
+ normalize_state_pattern_affine: bool = True,
661
+ normalize_pattern_projection: bool = True,
662
+ normalize_pattern_projection_affine: bool = True,
663
+ normalize_hopfield_space: bool = False,
664
+ normalize_hopfield_space_affine: bool = False,
665
+ stored_pattern_as_static: bool = False,
666
+ state_pattern_as_static: bool = False,
667
+ pattern_projection_as_static: bool = False,
668
+ pattern_projection_as_connected: bool = False,
669
+ stored_pattern_size: Optional[int] = None,
670
+ pattern_projection_size: Optional[int] = None,
671
+
672
+ batch_first: bool = True,
673
+ association_activation: Optional[str] = None,
674
+ dropout: float = 0.0,
675
+ input_bias: bool = True,
676
+ concat_bias_pattern: bool = False,
677
+ add_zero_association: bool = False,
678
+ disable_out_projection: bool = False,
679
+ quantity: int = 1,
680
+ trainable: bool = True
681
+ ):
682
+ """
683
+ Initialise a new instance of a Hopfield-based lookup layer.
684
+
685
+ :param input_size: depth of the input (state pattern)
686
+ :param hidden_size: depth of the association space
687
+ :param output_size: depth of the output projection
688
+ :param pattern_size: depth of patterns to be selected
689
+ :param num_heads: amount of parallel association heads
690
+ :param scaling: scaling of association heads, often represented as beta (one entry per head)
691
+ :param update_steps_max: maximum count of association update steps (None equals to infinity)
692
+ :param update_steps_eps: minimum difference threshold between two consecutive association update steps
693
+ :param lookup_weights_as_separated: separate lookup weights from lookup target weights
694
+ :param lookup_targets_as_trainable: employ trainable lookup target weights (used as pattern projection input)
695
+ :param normalize_stored_pattern: apply normalization on stored patterns
696
+ :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
697
+ :param normalize_state_pattern: apply normalization on state patterns
698
+ :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
699
+ :param normalize_pattern_projection: apply normalization on the pattern projection
700
+ :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
701
+ :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
702
+ :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
703
+ :param stored_pattern_as_static: interpret specified stored patterns as being static
704
+ :param state_pattern_as_static: interpret specified state patterns as being static
705
+ :param pattern_projection_as_static: interpret specified pattern projections as being static
706
+ :param pattern_projection_as_connected: connect pattern projection with stored pattern
707
+ :param stored_pattern_size: depth of input (stored pattern)
708
+ :param pattern_projection_size: depth of input (pattern projection)
709
+ :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
710
+ :param association_activation: additional activation to be applied on the result of the Hopfield association
711
+ :param dropout: dropout probability applied on the association matrix
712
+ :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
713
+ :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
714
+ :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
715
+ :param disable_out_projection: disable output projection
716
+ :param quantity: amount of stored patterns
717
+ :param trainable: stored pattern used for lookup is trainable
718
+ """
719
+ super(HopfieldLayer, self).__init__()
720
+ self.hopfield = Hopfield(
721
+ input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size,
722
+ num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
723
+ normalize_stored_pattern=normalize_stored_pattern,
724
+ normalize_stored_pattern_affine=normalize_stored_pattern_affine,
725
+ normalize_state_pattern=normalize_state_pattern,
726
+ normalize_state_pattern_affine=normalize_state_pattern_affine,
727
+ normalize_pattern_projection=normalize_pattern_projection,
728
+ normalize_pattern_projection_affine=normalize_pattern_projection_affine,
729
+ normalize_hopfield_space=normalize_hopfield_space,
730
+ normalize_hopfield_space_affine=normalize_hopfield_space_affine,
731
+ stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static,
732
+ pattern_projection_as_static=pattern_projection_as_static,
733
+ pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size,
734
+ pattern_projection_size=pattern_projection_size, batch_first=batch_first,
735
+ association_activation=association_activation, dropout=dropout, input_bias=input_bias,
736
+ concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association,
737
+ disable_out_projection=disable_out_projection)
738
+ self._quantity = quantity
739
+ lookup_weight_size = self.hopfield.hidden_size if stored_pattern_as_static else self.hopfield.stored_pattern_dim
740
+ self.lookup_weights = nn.Parameter(torch.empty(size=(*(
741
+ (1, quantity) if batch_first else (quantity, 1)
742
+ ), input_size if lookup_weight_size is None else lookup_weight_size)), requires_grad=trainable)
743
+
744
+ if lookup_weights_as_separated:
745
+ target_weight_size = self.lookup_weights.shape[
746
+ 2] if pattern_projection_size is None else pattern_projection_size
747
+ self.target_weights = nn.Parameter(torch.empty(size=(*(
748
+ (1, quantity) if batch_first else (quantity, 1)
749
+ ), target_weight_size)), requires_grad=lookup_targets_as_trainable)
750
+ else:
751
+ self.register_parameter(name=r'target_weights', param=None)
752
+ self.reset_parameters()
753
+
754
+ def reset_parameters(self) -> None:
755
+ """
756
+ Reset lookup and lookup target weights, including underlying Hopfield association.
757
+
758
+ :return: None
759
+ """
760
+ if hasattr(self.hopfield, r'reset_parameters'):
761
+ self.hopfield.reset_parameters()
762
+
763
+ # Explicitly initialise lookup and target weights.
764
+ nn.init.normal_(self.lookup_weights, mean=0.0, std=0.02)
765
+ if self.target_weights is not None:
766
+ nn.init.normal_(self.target_weights, mean=0.0, std=0.02)
767
+
768
+ def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
769
+ """
770
+ Prepare input for Hopfield association.
771
+
772
+ :param input: data to be prepared
773
+ :return: stored pattern, expanded state pattern as well as pattern projection
774
+ """
775
+ batch_size = input.shape[0 if self.batch_first else 1]
776
+ stored_pattern = self.lookup_weights.expand(size=(*(
777
+ (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size)
778
+ ), self.lookup_weights.shape[2]))
779
+ if self.target_weights is None:
780
+ pattern_projection = stored_pattern
781
+ else:
782
+ pattern_projection = self.target_weights.expand(size=(*(
783
+ (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size)
784
+ ), self.target_weights.shape[2]))
785
+
786
+ return stored_pattern, input, pattern_projection
787
+
788
+ def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
789
+ association_mask: Optional[Tensor] = None) -> Tensor:
790
+ """
791
+ Compute Hopfield-based lookup on specified data.
792
+
793
+ :param input: data to used in lookup
794
+ :param stored_pattern_padding_mask: mask to be applied on stored patterns
795
+ :param association_mask: mask to be applied on inner association matrix
796
+ :return: result of Hopfield-based lookup on input data
797
+ """
798
+ return self.hopfield(
799
+ input=self._prepare_input(input=input),
800
+ stored_pattern_padding_mask=stored_pattern_padding_mask,
801
+ association_mask=association_mask)
802
+
803
+ def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
804
+ association_mask: Optional[Tensor] = None) -> Tensor:
805
+ """
806
+ Fetch Hopfield association matrix used for lookup gathered by passing through the specified data.
807
+
808
+ :param input: data to be passed through the Hopfield association
809
+ :param stored_pattern_padding_mask: mask to be applied on stored patterns
810
+ :param association_mask: mask to be applied on inner association matrix
811
+ :return: association matrix as computed by the Hopfield core module
812
+ """
813
+ with torch.no_grad():
814
+ return self.hopfield.get_association_matrix(
815
+ input=self._prepare_input(input=input),
816
+ stored_pattern_padding_mask=stored_pattern_padding_mask,
817
+ association_mask=association_mask)
818
+
819
+ def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
820
+ stored_pattern_padding_mask: Optional[Tensor] = None,
821
+ association_mask: Optional[Tensor] = None) -> Tensor:
822
+ """
823
+ Fetch Hopfield projected pattern matrix gathered by passing through the specified data.
824
+
825
+ :param input: data to be passed through the Hopfield association
826
+ :param stored_pattern_padding_mask: mask to be applied on stored patterns
827
+ :param association_mask: mask to be applied on inner association matrix
828
+ :return: pattern projection matrix as computed by the Hopfield core module
829
+ """
830
+ with torch.no_grad():
831
+ return self.hopfield.get_projected_pattern_matrix(
832
+ input=self._prepare_input(input=input),
833
+ stored_pattern_padding_mask=stored_pattern_padding_mask,
834
+ association_mask=association_mask)
835
+
836
+ @property
837
+ def batch_first(self) -> bool:
838
+ return self.hopfield.batch_first
839
+
840
+ @property
841
+ def scaling(self) -> Union[float, Tensor]:
842
+ return self.hopfield.scaling
843
+
844
+ @property
845
+ def stored_pattern_dim(self) -> Optional[int]:
846
+ return self.hopfield.stored_pattern_dim
847
+
848
+ @property
849
+ def state_pattern_dim(self) -> Optional[int]:
850
+ return self.hopfield.state_pattern_dim
851
+
852
+ @property
853
+ def pattern_projection_dim(self) -> Optional[int]:
854
+ return self.hopfield.pattern_projection_dim
855
+
856
+ @property
857
+ def input_size(self) -> Optional[int]:
858
+ return self.hopfield.input_size
859
+
860
+ @property
861
+ def hidden_size(self) -> int:
862
+ return self.hopfield.hidden_size
863
+
864
+ @property
865
+ def output_size(self) -> Optional[int]:
866
+ return self.hopfield.output_size
867
+
868
+ @property
869
+ def pattern_size(self) -> Optional[int]:
870
+ return self.hopfield.pattern_size
871
+
872
+ @property
873
+ def quantity(self) -> int:
874
+ return self._quantity
875
+
876
+ @property
877
+ def update_steps_max(self) -> Optional[Union[int, Tensor]]:
878
+ return self.hopfield.update_steps_max
879
+
880
+ @property
881
+ def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
882
+ return self.hopfield.update_steps_eps
883
+
884
+ @property
885
+ def stored_pattern_as_static(self) -> bool:
886
+ return self.hopfield.stored_pattern_as_static
887
+
888
+ @property
889
+ def state_pattern_as_static(self) -> bool:
890
+ return self.hopfield.state_pattern_as_static
891
+
892
+ @property
893
+ def pattern_projection_as_static(self) -> bool:
894
+ return self.hopfield.pattern_projection_as_static
895
+
896
+ @property
897
+ def normalize_stored_pattern(self) -> bool:
898
+ return self.hopfield.normalize_stored_pattern
899
+
900
+ @property
901
+ def normalize_stored_pattern_affine(self) -> bool:
902
+ return self.hopfield.normalize_stored_pattern_affine
903
+
904
+ @property
905
+ def normalize_state_pattern(self) -> bool:
906
+ return self.hopfield.normalize_state_pattern
907
+
908
+ @property
909
+ def normalize_state_pattern_affine(self) -> bool:
910
+ return self.hopfield.normalize_state_pattern_affine
911
+
912
+ @property
913
+ def normalize_pattern_projection(self) -> bool:
914
+ return self.hopfield.normalize_pattern_projection
915
+
916
+ @property
917
+ def normalize_pattern_projection_affine(self) -> bool:
918
+ return self.hopfield.normalize_pattern_projection_affine
919
+
920
+
921
+ # Diffusion-augmented variant -- imported *after* Hopfield is fully defined to
922
+ # avoid a circular import (diffused_attention.py does 'from . import Hopfield').
923
+ from .diffused_attention import DiffusedHopfield # noqa: E402
924
+ from .attention_operator import AttentionOperator # noqa: E402
925
+ from .diffusion import ( # noqa: E402
926
+ DiffusionOperator,
927
+ SimpleDiffusion,
928
+ IterativeDiffusion,
929
+ SpectralDiffusion,
930
+ FactoredDiffusion,
931
+ apply_diffusion,
932
+ )
933
+ from .dynamics_engine import ( # noqa: E402
934
+ DiffusionConfig,
935
+ GraphCache,
936
+ EnergyTracker,
937
+ DynamicsEngine,
938
+ )
939
+ from .graph import GraphBuilder, LaplacianBuilder # noqa: E402
940
+
941
# Public names of the package, i.e. what "from difflayers import *" provides.
__all__ = [
    # Original Hopfield modules
    "Hopfield", "HopfieldPooling", "HopfieldLayer", "HopfieldCore",
    # DAHN — main model
    "DiffusedHopfield",
    # DAHN — diffusion operators
    "DiffusionOperator", "SimpleDiffusion", "IterativeDiffusion",
    "SpectralDiffusion", "FactoredDiffusion", "apply_diffusion",
    # DAHN — graph utilities
    "GraphBuilder", "LaplacianBuilder",
    # DAHN — runtime components
    "DiffusionConfig", "GraphCache", "EnergyTracker", "DynamicsEngine",
    "AttentionOperator",
]