difflayers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- difflayers/__init__.py +965 -0
- difflayers/activation.py +339 -0
- difflayers/attention_operator.py +157 -0
- difflayers/auxiliary/__init__.py +0 -0
- difflayers/auxiliary/data.py +252 -0
- difflayers/diffused_attention.py +427 -0
- difflayers/diffusion.py +395 -0
- difflayers/dynamics_engine.py +540 -0
- difflayers/functional.py +459 -0
- difflayers/graph/__init__.py +18 -0
- difflayers/graph/build_graph.py +77 -0
- difflayers/graph/builder.py +120 -0
- difflayers/graph/laplacian.py +76 -0
- difflayers/graph/laplacian_builder.py +64 -0
- difflayers/transformer.py +212 -0
- difflayers-0.1.0.dist-info/METADATA +210 -0
- difflayers-0.1.0.dist-info/RECORD +20 -0
- difflayers-0.1.0.dist-info/WHEEL +5 -0
- difflayers-0.1.0.dist-info/licenses/LICENSE +79 -0
- difflayers-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from math import ceil
|
|
4
|
+
from torch.utils.data import Dataset
|
|
5
|
+
from typing import Dict, Optional, Sequence, Tuple, Union
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BitPatternSet(Dataset):
    """
    Binary multiple instance learning (MIL) data set comprising bit patterns as instances,
    with implanted bit patterns unique to one of the classes.

    The first ``num_targets_high`` bags are the positive class (target 1) and receive
    ``num_signals_per_bag`` implanted signal patterns each; the remaining bags (target 0)
    are guaranteed to contain neither a signal pattern nor an all-zero instance.

    Generation uses private ``torch.Generator`` objects seeded with ``seed_signals`` and
    ``seed_data``, so constructing a data set does not reseed or advance the global RNG.
    """

    def __init__(self, num_bags: int, num_instances: int, num_signals: int, num_signals_per_bag: int = 1,
                 fraction_targets: float = 0.5, num_bits: int = 8, dtype: torch.dtype = torch.float32,
                 seed_signals: int = 43, seed_data: int = 44):
        """
        Create new binary bit pattern data set conforming to the specified properties.

        :param num_bags: amount of bags
        :param num_instances: amount of instances per bag
        :param num_signals: amount of unique signals used to distinguish both classes
        :param num_signals_per_bag: amount of unique signals to be implanted per bag
        :param fraction_targets: fraction of targets
        :param num_bits: amount of bits per instance
        :param dtype: data type of instances
        :param seed_signals: random seed used to generate the signals of the data set (excl. samples)
        :param seed_data: random seed used to generate the samples of the data set (excl. signals)
        """
        super(BitPatternSet, self).__init__()
        assert (type(num_bags) == int) and (num_bags > 0), r'"num_bags" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        assert (type(num_signals) == int) and (num_signals > 0), r'"num_signals" must be a positive integer!'
        assert (type(num_signals_per_bag) == int) and (num_signals_per_bag >= 0) and (
                num_signals_per_bag <= num_instances), r'"num_signals_per_bag" must be a non-negative integer!'
        assert (type(fraction_targets) == float) and (fraction_targets > 0) and (
                fraction_targets < 1), r'"fraction_targets" must be in interval (0; 1)!'
        assert (type(num_bits) == int) and (num_bits > 0), r'"num_bits" must be a positive integer!'
        # At least one non-zero, non-signal pattern must remain, otherwise re-sampling cannot terminate.
        assert ((2 ** num_bits) - 1) > num_signals, r'"num_signals" must be smaller than "2 ** num_bits - 1"!'
        assert type(seed_signals) == int, r'"seed_signals" must be an integer!'
        assert type(seed_data) == int, r'"seed_data" must be an integer!'

        self.__num_bags = num_bags
        self.__num_instances = num_instances
        self.__num_signals = num_signals
        self.__num_signals_per_bag = num_signals_per_bag
        self.__fraction_targets = fraction_targets
        # Clamp to [1, num_bags] so there is always at least one positive bag.
        self.__num_targets = min(self.__num_bags, max(1, ceil(self.__num_bags * self.__fraction_targets)))
        self.__num_bits = num_bits
        self.__dtype = dtype
        self.__seed_signals = seed_signals
        self.__seed_data = seed_data
        self.__data, self.__targets, self.__signals = self._generate_bit_pattern_set()

    def __len__(self) -> int:
        """
        Fetch amount of bags.

        :return: amount of bags
        """
        return self.__num_bags

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific bag.

        :param item_index: specific bag to fetch
        :return: specific bag as dictionary of tensors (keys ``data`` and ``target``), cast to ``dtype``
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_bit_pattern_set(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate underlying bit pattern data set.

        :return: tuple containing generated bags, targets and signals
        """
        num_bits = self.__num_bits

        # Private generators: same seeded streams as seeding globally, without touching global state.
        signal_rng = torch.Generator().manual_seed(self.__seed_signals)

        # Generate signal patterns; top up until enough unique, non-zero patterns exist.
        signals = torch.randint(low=0, high=2, size=(self.__num_signals, num_bits), generator=signal_rng)
        while True:
            signals = torch.unique(signals, dim=0)
            signals = signals[signals.sum(dim=1) != 0]
            missing_signals = self.__num_signals - signals.shape[0]
            if missing_signals <= 0:
                break
            signals = torch.cat(
                (signals, torch.randint(low=0, high=2, size=(missing_signals, num_bits), generator=signal_rng)),
                dim=0)

        # Generate data and target tensors.
        data_rng = torch.Generator().manual_seed(self.__seed_data)
        data = torch.randint(
            low=0, high=2, size=(self.__num_bags, self.__num_instances, num_bits), generator=data_rng)
        targets = torch.zeros(size=(self.__num_bags,), dtype=data.dtype)
        targets[:self.__num_targets] = 1

        # Check invalid (all-zero and signal) instances and re-sample them.
        while True:
            matches_signal = torch.stack([(data == signal).all(dim=2) for signal in signals]).any(dim=0)
            invalid_instances = (data.sum(dim=2) == 0) | matches_signal
            num_invalid = int(invalid_instances.sum())
            if num_invalid == 0:
                break
            data[invalid_instances] = torch.randint(
                low=0, high=2, size=(num_invalid, num_bits), generator=data_rng)

        # Re-implant signal into respective bags (positive bags only).
        for bag_index in range(self.__num_targets):
            used_positions = set()
            for _ in range(self.__num_signals_per_bag):
                # Rejection-sample an instance position not yet used within this bag.
                while True:
                    position = int(torch.randint(low=0, high=data.shape[1], size=(1,), generator=data_rng))
                    if position not in used_positions:
                        used_positions.add(position)
                        break
                signal_index = int(torch.randint(low=0, high=signals.shape[0], size=(1,), generator=data_rng))
                data[bag_index, position] = signals[signal_index]

        return data, targets, signals

    @property
    def num_bags(self) -> int:
        return self.__num_bags

    @property
    def num_instances(self) -> int:
        return self.__num_instances

    @property
    def num_bits(self) -> int:
        return self.__num_bits

    @property
    def num_targets_high(self) -> int:
        return self.__num_targets

    @property
    def num_targets_low(self) -> int:
        return self.__num_bags - self.__num_targets

    @property
    def num_signals(self) -> int:
        return self.__num_signals

    @property
    def num_signals_per_bag(self) -> int:
        return self.__num_signals_per_bag

    @property
    def initial_seed(self) -> int:
        # Reports the signal seed only; the data seed is not exposed here.
        return self.__seed_signals

    @property
    def bags(self) -> torch.Tensor:
        # Cloned so callers cannot mutate the stored data.
        return self.__data.clone()

    @property
    def targets(self) -> torch.Tensor:
        return self.__targets.clone()

    @property
    def signals(self) -> torch.Tensor:
        return self.__signals.clone()
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class LatchSequenceSet(Dataset):
    """
    Latch data set comprising patterns as one-hot encoded instances.

    The first instance of every sample carries the binary signal (character 0 or 1), which is
    also the sample's target; all remaining instances are distractor characters drawn from
    ``2 .. num_characters - 1``.

    Generation uses a private ``torch.Generator`` seeded with ``seed``, so constructing a data
    set does not reseed or advance the global RNG.
    """

    def __init__(self, num_samples: int, num_instances: int = 20, num_characters: int = 6,
                 dtype: torch.dtype = torch.float32, seed: int = 43):
        """
        Create new latch sequence data set conforming to the specified properties.

        :param num_samples: amount of samples
        :param num_instances: amount of instances per sample
        :param num_characters: amount of different characters (two signal characters plus at least one distractor)
        :param dtype: data type of samples
        :param seed: random seed used to generate the samples of the data set
        """
        super(LatchSequenceSet, self).__init__()
        assert (type(num_samples) == int) and (num_samples > 0), r'"num_samples" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        # Characters 0 and 1 are reserved for the signal and distractors are drawn from [2, num_characters),
        # so fewer than three characters leaves an empty sampling range.
        assert (type(num_characters) == int) and (num_characters > 2), \
            r'"num_characters" must be an integer greater than 2!'
        assert type(seed) == int, r'"seed" must be an integer!'

        self.__num_samples = num_samples
        self.__num_instances = num_instances
        self.__num_characters = num_characters
        self.__dtype = dtype
        self.__seed = seed
        self.__data, self.__targets = self._generate_latch_sequences()

    def __len__(self) -> int:
        """
        Fetch amount of samples.

        :return: amount of samples
        """
        return self.__num_samples

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific sample.

        :param item_index: specific sample to fetch
        :return: specific sample as dictionary of tensors (keys ``data`` and ``target``), cast to ``dtype``
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_latch_sequences(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate underlying latch sequence data set.

        :return: tuple containing generated data (one-hot encoded) and targets
        """
        # Private generator: same seeded stream as seeding globally, without touching global state.
        rng = torch.Generator().manual_seed(self.__seed)

        # Generate data and target tensors.
        characters = torch.randint(
            low=2, high=self.__num_characters, size=(self.__num_samples, self.__num_instances), generator=rng)
        signal = torch.randint(low=0, high=2, size=(self.__num_samples,), generator=rng)
        # The first instance of each sequence holds the signal to be latched.
        characters[:, 0] = signal
        one_hot_data = torch.nn.functional.one_hot(characters, num_classes=self.__num_characters)

        return one_hot_data, signal

    @property
    def num_samples(self) -> int:
        return self.__num_samples

    @property
    def num_instances(self) -> int:
        return self.__num_instances

    @property
    def num_characters(self) -> int:
        return self.__num_characters

    @property
    def initial_seed(self) -> int:
        return self.__seed

    @property
    def targets(self) -> torch.Tensor:
        # Cloned so callers cannot mutate the stored targets.
        return self.__targets.clone()
|
|
@@ -0,0 +1,427 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Graph-Regularized (Diffusion-Augmented) Hopfield Attention.
|
|
3
|
+
|
|
4
|
+
Core idea
|
|
5
|
+
---------
|
|
6
|
+
Core model / dynamics loop
|
|
7
|
+
--------------------------
|
|
8
|
+
|
|
9
|
+
x_{t+1} = Attention(D · x_t) — spec Section 0
|
|
10
|
+
|
|
11
|
+
In practice the stored key patterns (and optionally queries) are diffused
|
|
12
|
+
before the Hopfield attention layer:
|
|
13
|
+
|
|
14
|
+
for t in range(T): # via DynamicsEngine.run_diffusion
|
|
15
|
+
K' = D @ K # D = I - η*L (or factored form)
|
|
16
|
+
Q' = D @ Q
|
|
17
|
+
Attention = softmax(β Q' K'ᵀ) V — dense O(N²) or graph O(kN)
|
|
18
|
+
|
|
19
|
+
Optionally, post-softmax attention weights are also smoothed over the graph
|
|
20
|
+
(logit-level diffusion):
|
|
21
|
+
weights' = diffuse(weights, L_K, η_logit)
|
|
22
|
+
output = weights' @ V
|
|
23
|
+
|
|
24
|
+
Four diffusion modes (implemented in ``DiffusionOperator`` subclasses):
|
|
25
|
+
* **factored** — x' = (1-η·deg)⊙x + η·W@x. O(kNd), no L formed.
|
|
26
|
+
* **simple** — D = I - η*L, applied once. O(N²d).
|
|
27
|
+
* **iterative** — D^T X, i.e. D applied T times (T = number of diffusion steps, not a transpose). Deeper smoothing; early-stop guard. O(T·N²d).
|
|
28
|
+
* **spectral** — H = U exp(-η Λ) U^T. Exact heat kernel. O(N³) precompute.
|
|
29
|
+
|
|
30
|
+
Two attention modes (implemented in ``AttentionOperator``):
|
|
31
|
+
* **dense** (default) — full softmax(β Q Kᵀ) V. O(N²d). Exact baseline.
|
|
32
|
+
* **graph** — attend only to kNN neighbors. O(kNd). Faster.
|
|
33
|
+
|
|
34
|
+
Design
|
|
35
|
+
------
|
|
36
|
+
``DiffusedHopfield`` is a drop-in replacement for ``Hopfield``. It inherits
|
|
37
|
+
the full constructor signature, adds diffusion-specific keyword arguments, and overrides ``_associate`` to splice in diffusion.
|
|
38
|
+
|
|
39
|
+
Internally it delegates all graph/diffusion work to:
|
|
40
|
+
* ``GraphCache`` — builds and caches (W, deg, adj_idx, L, op) once.
|
|
41
|
+
* ``DynamicsEngine`` — runs T-step loop; no rebuild inside loop.
|
|
42
|
+
* ``AttentionOperator`` — dense or graph-constrained attention.
|
|
43
|
+
* ``EnergyTracker`` — (optional) per-step energy + early-stop.
|
|
44
|
+
|
|
45
|
+
This satisfies Open-Closed: new diffusion modes are added by subclassing
|
|
46
|
+
``DiffusionOperator`` in ``diffusion.py``; no changes needed here.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
from typing import Dict, Optional, Tuple, Union
|
|
50
|
+
|
|
51
|
+
import copy
|
|
52
|
+
|
|
53
|
+
import torch
|
|
54
|
+
import torch.nn.functional as F
|
|
55
|
+
from torch import Tensor
|
|
56
|
+
|
|
57
|
+
from . import Hopfield
|
|
58
|
+
from .attention_operator import AttentionOperator
|
|
59
|
+
from .diffusion import DiffusionOperator
|
|
60
|
+
from .dynamics_engine import DiffusionConfig, DynamicsEngine, EnergyTracker, GraphCache
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class DiffusedHopfield(Hopfield):
    """
    Hopfield association module augmented with Laplacian graph diffusion.

    Adds diffusion-specific parameters on top of the full ``Hopfield`` API.
    All graph construction, caching, and diffusion are delegated to
    ``GraphCache`` and ``DynamicsEngine``; this class only orchestrates the
    pre-processing hook inside ``_associate``.

    New parameters (beyond the standard Hopfield signature)
    -------------------------------------------------------
    eta                      : float  — Diffusion strength η. Default: 0.1.
    k_neighbors              : int    — kNN graph degree. Default: 5.
    use_normalized_laplacian : bool   — Symmetric-normalised L. Default: True.
    diffuse_query            : bool   — Diffuse query patterns. Default: False.
    diffuse_key              : bool   — Diffuse key patterns. Default: True.
    diffusion_mode           : str    — 'factored'|'simple'|'iterative'|'spectral'.
    diffusion_steps          : int    — Number of diffusion steps. Default: 3.
    attention_mode           : str    — 'dense' (O(N²)) or 'graph' (O(kN)).
    use_sparse               : bool   — Sparse adjacency for O(kN) products.
    use_logit_diffusion      : bool   — Smooth post-softmax weights.
    logit_eta                : float  — η for logit-level diffusion; ``None`` means "same as ``eta``".
    adaptive_eta             : bool   — Entropy-gated η scaling.
    adaptive_temperature     : float  — Sigmoid temperature for adaptive η.
    adaptive_threshold       : float  — Entropy midpoint for adaptive η.
    cache_graph              : bool   — Cache graph/operator between passes.
    energy_stop_tol          : float  — Early-stop tolerance; 0 = off.

    All remaining arguments are forwarded unchanged to ``Hopfield``.
    """

    def __init__(
        self,
        input_size: Optional[int] = None,
        hidden_size: Optional[int] = None,
        output_size: Optional[int] = None,
        pattern_size: Optional[int] = None,
        num_heads: int = 1,
        scaling: Optional[Union[float, Tensor]] = None,
        update_steps_max: Optional[Union[int, Tensor]] = 0,
        update_steps_eps: Union[float, Tensor] = 1e-4,

        normalize_stored_pattern: bool = True,
        normalize_stored_pattern_affine: bool = True,
        normalize_stored_pattern_eps: float = 1e-5,
        normalize_state_pattern: bool = True,
        normalize_state_pattern_affine: bool = True,
        normalize_state_pattern_eps: float = 1e-5,
        normalize_pattern_projection: bool = True,
        normalize_pattern_projection_affine: bool = True,
        normalize_pattern_projection_eps: float = 1e-5,
        normalize_hopfield_space: bool = False,
        normalize_hopfield_space_affine: bool = False,
        normalize_hopfield_space_eps: float = 1e-5,
        stored_pattern_as_static: bool = False,
        state_pattern_as_static: bool = False,
        pattern_projection_as_static: bool = False,
        pattern_projection_as_connected: bool = False,
        stored_pattern_size: Optional[int] = None,
        pattern_projection_size: Optional[int] = None,

        batch_first: bool = True,
        association_activation: Optional[str] = None,
        dropout: float = 0.0,
        input_bias: bool = True,
        concat_bias_pattern: bool = False,
        add_zero_association: bool = False,
        disable_out_projection: bool = False,

        # --- Diffusion-specific parameters ---
        eta: float = 0.1,
        k_neighbors: int = 5,
        use_normalized_laplacian: bool = True,
        diffuse_query: bool = False,
        diffuse_key: bool = True,
        diffusion_mode: str = "factored",
        diffusion_steps: int = 3,
        attention_mode: str = "dense",
        use_sparse: bool = False,
        use_logit_diffusion: bool = False,
        logit_eta: Optional[float] = None,
        adaptive_eta: bool = False,
        adaptive_temperature: float = 5.0,
        adaptive_threshold: float = 1.0,
        cache_graph: bool = True,
        energy_stop_tol: float = 0.0,
    ):
        """
        Build the underlying ``Hopfield`` module and the diffusion helpers.

        The standard Hopfield arguments are passed through to ``Hopfield.__init__``
        unchanged; the diffusion-specific arguments are collected into one
        ``DiffusionConfig`` shared by both graph caches.
        """
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            pattern_size=pattern_size,
            num_heads=num_heads,
            scaling=scaling,
            update_steps_max=update_steps_max,
            update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_stored_pattern_eps=normalize_stored_pattern_eps,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            normalize_state_pattern_eps=normalize_state_pattern_eps,
            normalize_pattern_projection=normalize_pattern_projection,
            normalize_pattern_projection_affine=normalize_pattern_projection_affine,
            normalize_pattern_projection_eps=normalize_pattern_projection_eps,
            normalize_hopfield_space=normalize_hopfield_space,
            normalize_hopfield_space_affine=normalize_hopfield_space_affine,
            normalize_hopfield_space_eps=normalize_hopfield_space_eps,
            stored_pattern_as_static=stored_pattern_as_static,
            state_pattern_as_static=state_pattern_as_static,
            pattern_projection_as_static=pattern_projection_as_static,
            pattern_projection_as_connected=pattern_projection_as_connected,
            stored_pattern_size=stored_pattern_size,
            pattern_projection_size=pattern_projection_size,
            batch_first=batch_first,
            association_activation=association_activation,
            dropout=dropout,
            input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern,
            add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection,
        )

        # Only plain Python numbers are used as beta; None or a Tensor scaling falls back to 1.0.
        # NOTE(review): confirm that 1.0 is the intended beta when ``scaling`` is None or a Tensor.
        _beta = float(scaling) if isinstance(scaling, (int, float)) else 1.0
        resolved_logit_eta = logit_eta if logit_eta is not None else eta

        # Build unified config
        self._diff_cfg = DiffusionConfig(
            eta=eta,
            beta=_beta,
            steps=diffusion_steps,
            diffusion_mode=diffusion_mode,
            attention_mode=attention_mode,
            k_neighbors=k_neighbors,
            use_normalized_laplacian=use_normalized_laplacian,
            use_sparse=use_sparse,
            diffuse_key=diffuse_key,
            diffuse_query=diffuse_query,
            use_logit_diffusion=use_logit_diffusion,
            logit_eta=resolved_logit_eta,
            adaptive_eta=adaptive_eta,
            adaptive_temperature=adaptive_temperature,
            adaptive_threshold=adaptive_threshold,
            cache_graph=cache_graph,
            energy_stop_tol=energy_stop_tol,
        )

        # Separate cache per role (key vs query patterns may differ in shape)
        self._key_cache = GraphCache(self._diff_cfg)
        self._query_cache = GraphCache(self._diff_cfg)

        # Attention operator — dense O(N²) or graph-constrained O(kN)
        self._attn_op = AttentionOperator(beta=_beta, mode=attention_mode)

        # Energy tracker (shared across key/query for the last step); only created when early stopping is on.
        self._energy_tracker: Optional[EnergyTracker] = (
            EnergyTracker(beta=_beta, eta=eta, tol=energy_stop_tol)
            if energy_stop_tol > 0.0 else None
        )

        # Expose scalar hypers for backward compat property access
        self.eta = eta
        self.k_neighbors = k_neighbors
        self.use_normalized_laplacian = use_normalized_laplacian
        self.diffuse_query = diffuse_query
        self.diffuse_key = diffuse_key
        self.diffusion_mode = diffusion_mode
        self.diffusion_steps = diffusion_steps
        self.attention_mode = attention_mode
        self.use_logit_diffusion = use_logit_diffusion
        self.logit_eta = resolved_logit_eta
        self.adaptive_eta = adaptive_eta
        self.adaptive_temperature = adaptive_temperature
        self.adaptive_threshold = adaptive_threshold
        self.cache_graph = cache_graph

        # Cache Hopfield name-mangled attrs used in association_core call
        self._d_update_steps_max = update_steps_max
        self._d_update_steps_eps = update_steps_eps

    # ------------------------------------------------------------------
    # Cache management
    # ------------------------------------------------------------------

    def invalidate_cache(self) -> None:
        """Clear cached Laplacians and operators. Call when patterns change."""
        self._key_cache.invalidate()
        self._query_cache.invalidate()

    # ------------------------------------------------------------------
    # Operator construction helper
    # ------------------------------------------------------------------

    def _build_diffusion_operator(self, eta, W, deg, L):
        """
        Build a diffusion operator for the configured mode with an explicit ``eta``.

        'factored' mode is precomputed from the adjacency ``W`` and degrees ``deg``;
        every other mode goes through ``DiffusionOperator.create`` and is precomputed
        from the Laplacian ``L``. The step count is taken from the shared config.
        """
        cfg = self._diff_cfg
        if cfg.diffusion_mode == "factored":
            from .diffusion import FactoredDiffusion
            return FactoredDiffusion(eta=eta, steps=cfg.steps).precompute_from_graph(W, deg)
        return DiffusionOperator.create(cfg.diffusion_mode, eta, cfg.steps).precompute(L)

    # ------------------------------------------------------------------
    # Override _associate to inject diffusion
    # ------------------------------------------------------------------

    def _associate(
        self,
        data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
        return_raw_associations: bool = False,
        return_projected_patterns: bool = False,
        stored_pattern_padding_mask: Optional[Tensor] = None,
        association_mask: Optional[Tensor] = None,
    ) -> Tuple[Optional[Tensor], ...]:
        """
        Hopfield association with graph-diffusion pre-processing.

        Runs a ``DynamicsEngine`` pass on the stored / state patterns after the
        optional LayerNorm but before ``association_core``, then optionally
        smooths the returned association weights over the key graph.

        :param data: one tensor (used as stored pattern, state pattern and pattern
            projection alike) or a 3-tuple (stored_pattern, state_pattern, pattern_projection)
        :param return_raw_associations: forwarded to ``association_core``
        :param return_projected_patterns: forwarded to ``association_core``
        :param stored_pattern_padding_mask: forwarded as ``key_padding_mask``
        :param association_mask: forwarded as ``attn_mask``
        :return: the tuple returned by ``association_core``; with logit diffusion
            active and weights available, its first element is replaced by the
            output recomputed from the smoothed weights
        """
        assert (type(data) == Tensor) or (
            (type(data) == tuple) and (len(data) == 3)
        ), (
            "either one tensor or a 3-tuple "
            "(stored_pattern, state_pattern, pattern_projection) must be provided."
        )

        if type(data) == Tensor:
            stored_pattern = state_pattern = pattern_projection = data
        else:
            stored_pattern, state_pattern, pattern_projection = data

        # --- batch_first transpose ---
        stored_pattern, state_pattern, pattern_projection = self._maybe_transpose(
            stored_pattern, state_pattern, pattern_projection
        )

        # --- Optional LayerNorm: flatten to 2-D, normalise, restore the original shape ---
        if self.norm_stored_pattern is not None:
            stored_pattern = self.norm_stored_pattern(
                input=stored_pattern.reshape(-1, stored_pattern.shape[2])
            ).reshape(stored_pattern.shape)

        if self.norm_state_pattern is not None:
            state_pattern = self.norm_state_pattern(
                input=state_pattern.reshape(-1, state_pattern.shape[2])
            ).reshape(state_pattern.shape)

        if self.norm_pattern_projection is not None:
            pattern_projection = self.norm_pattern_projection(
                input=pattern_projection.reshape(-1, pattern_projection.shape[2])
            ).reshape(pattern_projection.shape)

        # --- Full dynamics loop: interleaved diffusion + attention ---
        cfg = self._diff_cfg
        W_k = deg_k = adj_k = L_k = op_k = None
        if cfg.eta > 0.0 and (cfg.diffuse_key or cfg.diffuse_query):
            # The key graph is built from the detached dim-1 mean of the stored patterns.
            key_repr = stored_pattern.detach().mean(dim=1).float()
            W_k, deg_k, adj_k, L_k, op_k = self._key_cache.get(key_repr)

            # Build a query-specific graph so Q is diffused over its own topology.
            W_q = deg_q = L_q = op_q = None
            if cfg.diffuse_query:
                q_repr = state_pattern.detach().mean(dim=1).float()
                W_q, deg_q, _adj_q, L_q, op_q = self._query_cache.get(q_repr)

            # Adaptive η: scale by attention entropy before dynamics loop
            if cfg.adaptive_eta:
                raw_logits = torch.bmm(
                    state_pattern.permute(1, 0, 2),
                    stored_pattern.permute(1, 2, 0),
                )  # (B, L, S)
                eta_eff = DynamicsEngine(op_k).compute_adaptive_eta(
                    raw_logits, cfg.eta,
                    cfg.adaptive_temperature, cfg.adaptive_threshold,
                )
                # Only rebuild the operators when η moved by more than 5 % (relative).
                if abs(eta_eff - cfg.eta) / (cfg.eta + 1e-9) > 0.05:
                    op_k = self._build_diffusion_operator(eta_eff, W_k, deg_k, L_k)
                    # Rebuild query op with same eta_eff
                    if op_q is not None:
                        op_q = self._build_diffusion_operator(eta_eff, W_q, deg_q, L_q)

            # Single-step copy for dynamics (outer loop controls iteration count);
            # copying keeps the cached operator's own step count intact.
            op_dyn = copy.copy(op_k)
            op_dyn.steps = 1

            op_q_dyn = None
            if op_q is not None:
                op_q_dyn = copy.copy(op_q)
                op_q_dyn.steps = 1

            engine = DynamicsEngine(
                diffusion_op=op_dyn,
                attention_op=self._attn_op,
                steps=self.diffusion_steps,
                energy_tracker=self._energy_tracker,
                query_diffusion_op=op_q_dyn,
            )
            state_pattern, stored_pattern = engine.run_dynamics(
                Q=state_pattern, K=stored_pattern, V=pattern_projection,
                adj_indices=adj_k, L=L_k, W=W_k, deg=deg_k,
                diffuse_query=cfg.diffuse_query, diffuse_key=cfg.diffuse_key,
            )

        # --- Core association; weights are only requested when they will be smoothed ---
        result = self.association_core(
            query=state_pattern,
            key=stored_pattern,
            value=pattern_projection,
            key_padding_mask=stored_pattern_padding_mask,
            need_weights=cfg.use_logit_diffusion,
            attn_mask=association_mask,
            scaling=self.scaling,
            update_steps_max=self._d_update_steps_max,
            update_steps_eps=self._d_update_steps_eps,
            return_raw_associations=return_raw_associations,
            return_pattern_projections=return_projected_patterns,
        )

        # --- Logit-level diffusion (post-softmax weight smoothing) ---
        if cfg.use_logit_diffusion and cfg.logit_eta > 0.0:
            # result[0] = output, result[1] = attn_weights, unpacked below as (B, H, L, S), or None.
            attn_weights = result[1]
            if attn_weights is not None:
                # Reuse cached graph if available; otherwise build from keys
                if op_k is None:
                    key_repr = stored_pattern.detach().mean(dim=1).float()
                    W_k, deg_k, adj_k, L_k, op_k = self._key_cache.get(key_repr)

                # Honour an explicitly different logit_eta by building a dedicated operator;
                # when it equals eta (the default) the key operator is reused unchanged.
                if cfg.logit_eta != cfg.eta:
                    op_logit = self._build_diffusion_operator(cfg.logit_eta, W_k, deg_k, L_k)
                else:
                    op_logit = op_k

                # Treat (S,) distribution per query as the diffusion signal.
                # Flatten to (S, B*H*L), diffuse, reshape and renormalise.
                B, H, L_q_len, S = attn_weights.shape
                w_flat = attn_weights.permute(3, 0, 1, 2).reshape(S, B * H * L_q_len)
                w_diff = op_logit(w_flat)  # diffuse along S
                w_diff = w_diff.reshape(S, B, H, L_q_len).permute(1, 2, 3, 0)
                # Diffusion may produce small negatives; clamp, then renormalise each row to sum to 1.
                w_diff = w_diff.clamp(min=0.0)
                w_diff = w_diff / (w_diff.sum(dim=-1, keepdim=True) + 1e-9)
                # Recompute output with smoothed weights.
                # NOTE(review): the einsum sums over the head axis and uses the un-projected
                # pattern_projection — confirm this is intended when num_heads > 1.
                V = pattern_projection.permute(1, 0, 2)  # (B, S, d)
                out_smooth = torch.einsum("bhls,bsd->bld", w_diff, V)
                out_smooth = out_smooth.permute(1, 0, 2)  # (L, B, d)
                result = (out_smooth,) + result[1:]

        return result

    # ------------------------------------------------------------------
    # Config dict API
    # ------------------------------------------------------------------

    def get_config(self) -> Dict[str, object]:
        """Return a JSON-serialisable dict of all diffusion hyperparameters."""
        return self._diff_cfg.to_dict()
|
|
427
|
+
|