difflayers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- difflayers/__init__.py +965 -0
- difflayers/activation.py +339 -0
- difflayers/attention_operator.py +157 -0
- difflayers/auxiliary/__init__.py +0 -0
- difflayers/auxiliary/data.py +252 -0
- difflayers/diffused_attention.py +427 -0
- difflayers/diffusion.py +395 -0
- difflayers/dynamics_engine.py +540 -0
- difflayers/functional.py +459 -0
- difflayers/graph/__init__.py +18 -0
- difflayers/graph/build_graph.py +77 -0
- difflayers/graph/builder.py +120 -0
- difflayers/graph/laplacian.py +76 -0
- difflayers/graph/laplacian_builder.py +64 -0
- difflayers/transformer.py +212 -0
- difflayers-0.1.0.dist-info/METADATA +210 -0
- difflayers-0.1.0.dist-info/RECORD +20 -0
- difflayers-0.1.0.dist-info/WHEEL +5 -0
- difflayers-0.1.0.dist-info/licenses/LICENSE +79 -0
- difflayers-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,540 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dynamical Memory Engine for iterative diffusion-attention systems.
|
|
3
|
+
|
|
4
|
+
Responsibility: Orchestrate the interleaved diffusion → attention loop.
|
|
5
|
+
|
|
6
|
+
This module provides four focused classes following Single Responsibility:
|
|
7
|
+
|
|
8
|
+
GraphCache — builds and caches (W, deg, adj_indices, L, op) once per input
|
|
9
|
+
DynamicsEngine — runs T steps of diffusion → attention on Q, K, V patterns
|
|
10
|
+
EnergyTracker — computes and tracks Hopfield energy across steps
|
|
11
|
+
DiffusionConfig — single serialisable config object shared by all classes
|
|
12
|
+
|
|
13
|
+
Design notes
|
|
14
|
+
------------
|
|
15
|
+
* ``GraphCache`` is the single source of truth for all graph objects.
|
|
16
|
+
No other module rebuilds W, deg, L, or the diffusion operator.
|
|
17
|
+
* ``DynamicsEngine`` holds references to a ``DiffusionOperator`` and an
|
|
18
|
+
``AttentionOperator``; it does NOT rebuild either per step or per call.
|
|
19
|
+
* ``FactoredDiffusion`` is the default for sparse mode — it never forms L.
|
|
20
|
+
``SpectralDiffusion`` and ``SimpleDiffusion`` use L when required.
|
|
21
|
+
* ``EnergyTracker`` is optional and zero-cost when disabled.
|
|
22
|
+
|
|
23
|
+
Full dynamics loop (Section 4.1 of spec)
|
|
24
|
+
-----------------------------------------
|
|
25
|
+
for t in range(T):
|
|
26
|
+
Q = diffusion(Q)
|
|
27
|
+
K = diffusion(K)
|
|
28
|
+
Q = attention(Q, K, V)
|
|
29
|
+
|
|
30
|
+
Each iteration costs:
|
|
31
|
+
diffusion : O(kNd) with FactoredDiffusion + sparse W
|
|
32
|
+
attention : O(N²d) dense OR O(kNd) graph mode
|
|
33
|
+
|
|
34
|
+
Complexity (N patterns, d features, T steps, k graph neighbours)
|
|
35
|
+
-----------------------------------------------------------------
|
|
36
|
+
GraphCache.get build : O(N²d) sim-matrix + O(N²) kNN + O(N³) if spectral
|
|
37
|
+
GraphCache.get hit : O(Nd) checksum of the input (no graph rebuild)
|
|
38
|
+
DynamicsEngine.run : O(T * kNd) sparse factored diffusion + graph attn
|
|
39
|
+
O(T * N^2d) dense diffusion + dense attn
|
|
40
|
+
EnergyTracker.step : O(N²)
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
from __future__ import annotations
|
|
44
|
+
|
|
45
|
+
from dataclasses import dataclass, field
|
|
46
|
+
from typing import Dict, List, Optional, Tuple
|
|
47
|
+
|
|
48
|
+
import torch
|
|
49
|
+
from torch import Tensor
|
|
50
|
+
|
|
51
|
+
from .attention_operator import AttentionOperator
|
|
52
|
+
from .diffusion import DiffusionOperator, FactoredDiffusion
|
|
53
|
+
from .graph.builder import GraphBuilder
|
|
54
|
+
from .graph.laplacian_builder import LaplacianBuilder
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ---------------------------------------------------------------------------
|
|
58
|
+
# Config dataclass
|
|
59
|
+
# ---------------------------------------------------------------------------
|
|
60
|
+
|
|
61
|
+
@dataclass
class DiffusionConfig:
    """
    One settings object shared by every class of the diffusion system.

    All fields have defaults, so ``DiffusionConfig()`` is a valid
    configuration.  The declaration order below is also the positional
    order of the generated constructor, so it must not be rearranged.

    Attributes:
        eta:                      Diffusion strength eta.
        beta:                     Hopfield scaling / attention inverse-temperature.
        steps:                    Number of diffusion+attention iterations (T).
        diffusion_mode:           'simple', 'iterative', 'spectral', or 'factored'.
        attention_mode:           'dense' (O(N^2)) or 'graph' (O(kN)).
        k_neighbors:              kNN graph degree.
        use_normalized_laplacian: Use symmetric-normalised L (recommended).
        use_sparse:               Store adjacency as sparse_coo (O(kN) storage).
        diffuse_key:              Apply diffusion to stored patterns (keys).
        diffuse_query:            Apply diffusion to state patterns (queries).
        use_logit_diffusion:      Also smooth post-softmax attention weights.
        logit_eta:                eta for logit-level diffusion; defaults to eta.
        adaptive_eta:             Scale eta by attention entropy.
        adaptive_temperature:     Sigmoid temperature for adaptive eta.
        adaptive_threshold:       Entropy midpoint for adaptive eta.
        cache_graph:              Cache graph across forward passes.
        energy_stop_tol:          Early-stop if |Delta E| < tol (0 = disabled).
    """

    # -- core dynamics -----------------------------------------------------
    eta: float = 0.1
    beta: float = 1.0
    steps: int = 3

    # -- operator selection ------------------------------------------------
    diffusion_mode: str = "factored"
    attention_mode: str = "dense"

    # -- graph construction ------------------------------------------------
    k_neighbors: int = 5
    use_normalized_laplacian: bool = True
    use_sparse: bool = False

    # -- which tensors get diffused ----------------------------------------
    diffuse_key: bool = True
    diffuse_query: bool = False

    # -- logit-level diffusion ---------------------------------------------
    use_logit_diffusion: bool = False
    logit_eta: Optional[float] = None

    # -- entropy-gated eta -------------------------------------------------
    adaptive_eta: bool = False
    adaptive_temperature: float = 5.0
    adaptive_threshold: float = 1.0

    # -- caching / early stopping ------------------------------------------
    cache_graph: bool = True
    energy_stop_tol: float = 0.0

    def to_dict(self) -> Dict[str, object]:
        """Return every field as a ``{name: value}`` dict (JSON-serialisable)."""
        from dataclasses import asdict

        return asdict(self)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# ---------------------------------------------------------------------------
|
|
110
|
+
# Graph Cache
|
|
111
|
+
# ---------------------------------------------------------------------------
|
|
112
|
+
|
|
113
|
+
class GraphCache:
    """
    Builds and caches all graph objects for a given pattern tensor.

    Responsibility: graph construction and operator lifecycle — nothing else.

    Only the most recently built input is remembered.  It is identified by a
    composite key ``(data_ptr, shape, sum of elements)`` and the cached
    objects are:
        * W           — (N, N) adjacency (dense or sparse_coo)
        * deg         — (N,)   degree vector
        * adj_indices — (N, k) kNN neighbor indices (needed by graph attention)
        * L           — (N, N) Laplacian (only when required by diffusion mode)
        * op          — Precomputed DiffusionOperator

    When ``config.cache_graph`` is False nothing is stored and every
    ``get()`` rebuilds.  Call ``invalidate()`` to force a full rebuild on the
    next ``get()``.

    Args:
        config: DiffusionConfig controlling graph and diffusion settings.
    """

    def __init__(self, config: DiffusionConfig) -> None:
        self._cfg = config
        self._graph_builder = GraphBuilder(
            k=config.k_neighbors, use_sparse=config.use_sparse
        )
        self._lap_builder = LaplacianBuilder(
            normalized=config.use_normalized_laplacian
        )
        self._reset()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(
        self, patterns: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], DiffusionOperator]:
        """
        Return (W, deg, adj_indices, L, op) for the given patterns.

        Caching enabled:
            * hit  — the cache key matches the stored one; the stored objects
              are returned.  Computing the key reduces over every element of
              ``patterns`` (``patterns.sum()``), so a hit is O(N d), not O(1).
            * miss — everything is rebuilt via ``_build`` and stored.
        Caching disabled (``cache_graph`` is False):
            * the key is not computed at all and ``_build`` is called
              directly; nothing is stored.

        Args:
            patterns: (N, d) float32 representative patterns.

        Returns:
            W:           (N, N) adjacency.
            deg:         (N,)   degree vector.
            adj_indices: (N, k) neighbor indices for graph attention.
            L:           (N, N) Laplacian — None for 'factored' mode.
            op:          Precomputed DiffusionOperator.
        """
        if not self._cfg.cache_graph:
            # The key is only ever used for the cache comparison; skipping it
            # here avoids a full reduction over ``patterns`` on every call.
            return self._build(patterns)

        cache_key = self._cache_key(patterns)
        if cache_key == self._cached_ptr:
            return (
                self._cached_W,
                self._cached_deg,
                self._cached_adj,
                self._cached_L,
                self._cached_op,
            )

        W, deg, adj_idx, L, op = self._build(patterns)

        self._cached_ptr = cache_key
        self._cached_W = W
        self._cached_deg = deg
        self._cached_adj = adj_idx
        self._cached_L = L
        self._cached_op = op

        return W, deg, adj_idx, L, op

    def invalidate(self) -> None:
        """Force a full rebuild on the next call to ``get``."""
        self._reset()

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _cache_key(patterns: Tensor) -> Tuple[int, Tuple[int, ...], float]:
        """
        Build the composite key ``(data_ptr, shape, checksum)`` for ``patterns``.

        The shape and the sum of all elements are included next to the data
        pointer to guard against PyTorch reusing the same memory address for
        a different tensor after the original is freed (BUG-4 fix).

        NOTE(review): a change to ``patterns`` that keeps address, shape and
        element sum identical is not detected by this key.
        """
        return (
            patterns.data_ptr(),
            tuple(patterns.shape),
            float(patterns.sum()),
        )

    def _reset(self) -> None:
        """Drop the stored key and every cached graph object."""
        # NB: despite its historical name this holds the full composite key,
        # not just the data pointer.
        self._cached_ptr: Optional[Tuple[int, Tuple[int, ...], float]] = None
        self._cached_W: Optional[Tensor] = None
        self._cached_deg: Optional[Tensor] = None
        self._cached_adj: Optional[Tensor] = None
        self._cached_L: Optional[Tensor] = None
        self._cached_op: Optional[DiffusionOperator] = None

    def _build(
        self, patterns: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], DiffusionOperator]:
        """
        Construct (W, deg, adj_indices, L, op) from scratch.

        The graph is built from a detached float32 copy of ``patterns``.
        In 'factored' mode no Laplacian is formed and ``L`` is returned as
        None; in every other mode ``L`` is built from ``W`` and handed to
        the operator created by ``DiffusionOperator.create``.
        """
        cfg = self._cfg
        X = patterns.detach().float()   # (N, d)

        # All graph construction through GraphBuilder (single responsibility)
        W, deg, adj_idx = self._graph_builder.build(X)

        # Move to the input's device; only ``deg`` is also cast to its dtype.
        W = W.to(device=patterns.device)
        deg = deg.to(dtype=patterns.dtype, device=patterns.device)
        adj_idx = adj_idx.to(device=patterns.device)

        # FactoredDiffusion: no L required — avoids dense N*N matrix
        if cfg.diffusion_mode == "factored":
            op = FactoredDiffusion(eta=cfg.eta, steps=cfg.steps)
            op.precompute_from_graph(W, deg)
            return W, deg, adj_idx, None, op

        # All other modes need L (LaplacianBuilder — single responsibility)
        L = self._lap_builder.build(W).to(
            dtype=patterns.dtype, device=patterns.device
        )
        op = DiffusionOperator.create(
            mode=cfg.diffusion_mode,
            eta=cfg.eta,
            steps=cfg.steps,
        )
        op.precompute(L)
        return W, deg, adj_idx, L, op
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# ---------------------------------------------------------------------------
|
|
241
|
+
# Energy Tracker
|
|
242
|
+
# ---------------------------------------------------------------------------
|
|
243
|
+
|
|
244
|
+
class EnergyTracker:
    """
    Tracks Hopfield energy across diffusion steps.

        E = -(beta * Q @ K^T).mean() + eta * trace(K^T L K) / N

    Responsibility: measuring energy; providing early-stop signal.

    Every call to ``step`` / ``step_factored`` appends one value to the
    history; the early-stop test compares the two most recent entries, so
    call ``reset()`` before starting an unrelated run.

    Args:
        beta: Hopfield scaling factor.
        eta:  Diffusion regularisation strength.
        tol:  Stop if |E_t - E_{t-1}| < tol.  0 disables early stopping.
    """

    def __init__(self, beta: float, eta: float, tol: float = 0.0) -> None:
        self.beta = beta
        self.eta = eta
        self.tol = tol
        self._history: List[float] = []

    @property
    def history(self) -> List[float]:
        """List of per-step energy values (a copy; mutating it is harmless)."""
        return list(self._history)

    def reset(self) -> None:
        """Forget all recorded energies."""
        self._history.clear()

    @torch.no_grad()
    def step(self, Q: Tensor, K: Tensor, L: Tensor) -> Tuple[float, bool]:
        """
        Compute energy for the current step; return (energy, should_stop).

        The quadratic form is evaluated as ``sum(K ⊙ (L @ K))``, which equals
        ``trace(K^T L K)`` without materialising the (d, d) product, and
        which also accepts a sparse_coo ``L``.

        Args:
            Q: (N, d) or (S, B, d) query patterns.
            K: (N, d) or (S, B, d) key patterns (after diffusion at this step).
            L: (N, N) graph Laplacian.

        Returns:
            energy:      Scalar energy value.
            should_stop: True if early-stop criterion is met.
        """
        Q_2d = self._as_2d(Q)
        K_2d = self._as_2d(K)
        LK = torch.sparse.mm(L, K_2d) if L.is_sparse else L @ K_2d
        quad_form = (LK * K_2d).sum()
        return self._record(Q_2d, K_2d, quad_form)

    @torch.no_grad()
    def step_factored(
        self, Q: Tensor, K: Tensor, W: Tensor, deg: Tensor,
    ) -> Tuple[float, bool]:
        """
        Compute energy without L, using the factored identity:

            tr(K^T L K) = (deg ⊙ ||k_i||²).sum() - (W@K ⊙ K).sum()

        Enables energy tracking with FactoredDiffusion (L=None).

        NOTE(review): the identity above is exact for L = diag(deg) - W.
        Whether it is meant to match a symmetric-normalised Laplacian as
        well should be confirmed against the diffusion operator.

        Args:
            Q:   (N, d) or (S, B, d) query patterns.
            K:   (N, d) or (S, B, d) key patterns.
            W:   (N, N) adjacency (dense or sparse).
            deg: (N,)   degree vector.

        Returns:
            energy:      Scalar energy value.
            should_stop: True if early-stop criterion is met.
        """
        Q_2d = self._as_2d(Q)
        K_2d = self._as_2d(K)
        K_norms_sq = (K_2d * K_2d).sum(dim=-1)  # (N,)
        deg_term = (deg * K_norms_sq).sum()
        WK = torch.sparse.mm(W, K_2d) if W.is_sparse else W @ K_2d
        quad_form = deg_term - (WK * K_2d).sum()
        return self._record(Q_2d, K_2d, quad_form)

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _as_2d(X: Tensor) -> Tensor:
        """Average a 3-D tensor over dim 1; pass any other rank through."""
        return X.mean(dim=1) if X.dim() == 3 else X

    def _record(
        self, Q_2d: Tensor, K_2d: Tensor, quad_form: Tensor
    ) -> Tuple[float, bool]:
        """Combine affinity and smoothness, log the energy, test early stop."""
        affinity = -(self.beta * Q_2d @ K_2d.t()).mean()
        smoothness = self.eta * quad_form / K_2d.shape[0]
        energy = (affinity + smoothness).item()
        self._history.append(energy)

        # Early stop needs a positive tolerance and a previous value.
        should_stop = (
            self.tol > 0.0
            and len(self._history) >= 2
            and abs(self._history[-1] - self._history[-2]) < self.tol
        )
        return energy, should_stop
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# ---------------------------------------------------------------------------
|
|
338
|
+
# Dynamics Engine
|
|
339
|
+
# ---------------------------------------------------------------------------
|
|
340
|
+
|
|
341
|
+
class DynamicsEngine:
    """
    Iterative dynamical system: x_{t+1} = Attention(D * x_t).

    Responsibility: drive the T-step loop that alternates diffusion and
    attention.  The engine only *calls* the operators it was given; it never
    builds a graph, a Laplacian, or an operator itself.

    Loop executed by ``run_dynamics`` (spec Section 4.1)::

        for t in range(T):
            K = diffusion(K)          # if diffuse_key
            Q = diffusion(Q)          # if diffuse_query
            Q = attention(Q, K, V)

    ``run_diffusion`` is kept for single-tensor use (keys or queries on their
    own) for backward compatibility with ``DiffusedHopfield._associate``.

    Args:
        diffusion_op:       Precomputed DiffusionOperator (must be callable).
        attention_op:       AttentionOperator in 'dense' or 'graph' mode.
                            Needed by ``run_dynamics`` only.
        steps:              Number of dynamics iterations T; when None the
                            value of ``diffusion_op.steps`` is used.
        energy_tracker:     Optional EnergyTracker; enables per-step early-stop.
        query_diffusion_op: Optional separate operator applied to Q; when
                            None, ``diffusion_op`` is used for Q as well.
    """

    def __init__(
        self,
        diffusion_op: DiffusionOperator,
        attention_op: Optional[AttentionOperator] = None,
        steps: Optional[int] = None,
        energy_tracker: Optional[EnergyTracker] = None,
        query_diffusion_op: Optional[DiffusionOperator] = None,
    ) -> None:
        if steps is None:
            steps = diffusion_op.steps
        self._steps = steps
        self._diff_op = diffusion_op
        self._query_diff_op = query_diffusion_op
        self._attn_op = attention_op
        self._tracker = energy_tracker

    # ------------------------------------------------------------------
    # Full dynamics loop   Q_T = run_dynamics(Q, K, V)
    # ------------------------------------------------------------------

    def run_dynamics(
        self,
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        adj_indices: Optional[Tensor] = None,
        L: Optional[Tensor] = None,
        W: Optional[Tensor] = None,
        deg: Optional[Tensor] = None,
        diffuse_query: bool = True,
        diffuse_key: bool = True,
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the diffusion → attention loop for up to ``steps`` iterations.

        If an energy tracker was supplied it is reset first, then consulted
        after every attention update: with ``L`` given the L-based energy is
        used; otherwise, with both ``W`` and ``deg`` given, the factored
        form is used; with neither, no energy is measured.  The loop ends
        early as soon as the tracker reports a stop.

        Args:
            Q:             (N, d) or (S, B, d) query patterns.
            K:             (N, d) or (S, B, d) key patterns.
            V:             (N, d) or (S, B, d) value patterns.
            adj_indices:   (N, k) neighbor indices, forwarded to the
                           attention operator (required for graph attention).
            L:             (N, N) Laplacian — for L-based energy tracking.
            W:             (N, N) adjacency — for factored energy tracking.
            deg:           (N,)   degree vector — for factored energy tracking.
            diffuse_query: Whether to diffuse Q each iteration.
            diffuse_key:   Whether to diffuse K each iteration.

        Returns:
            (Q, K): the queries and keys as they stand after the last
            executed iteration.

        Raises:
            RuntimeError: if the engine was built without ``attention_op``.
        """
        attend = self._attn_op
        if attend is None:
            raise RuntimeError(
                "DynamicsEngine.run_dynamics requires an AttentionOperator. "
                "Pass attention_op= at construction time."
            )

        # Pick the energy probe once; L takes precedence over (W, deg).
        tracker = self._tracker
        probe = None
        if tracker is not None:
            tracker.reset()  # clean slate for each dynamics run
            if L is not None:
                def probe(q: Tensor, k: Tensor) -> Tuple[float, bool]:
                    return tracker.step(q, k, L)
            elif W is not None and deg is not None:
                def probe(q: Tensor, k: Tensor) -> Tuple[float, bool]:
                    return tracker.step_factored(q, k, W, deg)

        diffuse_k = self._diff_op
        diffuse_q = self._query_diff_op
        if diffuse_q is None:
            diffuse_q = diffuse_k

        for _ in range(self._steps):
            if diffuse_key:
                K = diffuse_k(K)
            if diffuse_query:
                Q = diffuse_q(Q)

            # Attention update — dense O(N^2) or graph O(kN)
            Q = attend(Q, K, V, adj_indices=adj_indices)

            if probe is not None:
                _, should_stop = probe(Q, K)
                if should_stop:
                    break

        return Q, K

    # ------------------------------------------------------------------
    # Single-tensor diffusion (backward-compatible)
    # ------------------------------------------------------------------

    def run_diffusion(
        self,
        X: Tensor,
        L: Optional[Tensor] = None,
        Q_ref: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Diffuse a single tensor with the engine's diffusion operator.

        Two paths exist:
            * tracked   — taken only when the engine has an energy tracker,
              ``L`` and ``Q_ref`` are both given and ``X`` is 2-D.  The
              operator is applied up to ``steps`` times, recording the
              energy after each application and stopping early when the
              tracker says so.  The tracker is *not* reset here.
            * untracked — in every other case the operator is called
              exactly once.

        Args:
            X:     (N, d) or (S, B, d) patterns to diffuse.
            L:     (N, N) Laplacian — required only for energy tracking.
            Q_ref: (N, d) query reference — required only for energy tracking.

        Returns:
            X': Diffused patterns, same shape as X.
        """
        tracker = self._tracker
        if tracker is None or L is None or Q_ref is None or X.dim() != 2:
            return self._diff_op(X)

        for _ in range(self._steps):
            X = self._diff_op(X)
            _, should_stop = tracker.step(Q_ref, X, L)
            if should_stop:
                break
        return X

    # ------------------------------------------------------------------
    # Adaptive eta utility
    # ------------------------------------------------------------------

    def compute_adaptive_eta(
        self, logits: Tensor, base_eta: float,
        temperature: float = 5.0, threshold: float = 1.0,
    ) -> float:
        """
        Compute entropy-gated adaptive eta.

            eta_eff = base_eta * sigmoid(temperature * (H(attn) - threshold))

        ``attn`` is the softmax of ``logits`` over the last dimension (the
        softmax is applied unconditionally) and ``H`` is its entropy
        averaged over all leading dimensions.  No gradients are recorded.

        Args:
            logits:      (..., S) raw attention logits or weights.
            base_eta:    Maximum eta value.
            temperature: Sigmoid steepness.
            threshold:   Entropy midpoint.

        Returns:
            eta_eff: ``base_eta`` scaled by a sigmoid gate, as a Python float.
        """
        with torch.no_grad():
            attn = torch.softmax(logits, dim=-1)
            # 1e-9 keeps log() finite where a probability underflows to 0.
            row_entropy = -(attn * (attn + 1e-9).log()).sum(dim=-1)
            gate = torch.sigmoid(temperature * (row_entropy.mean() - threshold))
        return base_eta * gate.item()
|