difflayers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,540 @@
1
+ """
2
+ Dynamical Memory Engine for iterative diffusion-attention systems.
3
+
4
+ Responsibility: Orchestrate the interleaved diffusion → attention loop.
5
+
6
+ This module provides four focused classes following Single Responsibility:
7
+
8
+ GraphCache — builds and caches (W, deg, adj_indices, L, op) once per input
9
+ DynamicsEngine — runs T steps of diffusion → attention on Q, K, V patterns
10
+ EnergyTracker — computes and tracks Hopfield energy across steps
11
+ DiffusionConfig — single serialisable config object shared by all classes
12
+
13
+ Design notes
14
+ ------------
15
+ * ``GraphCache`` is the single source of truth for all graph objects.
16
+ No other module rebuilds W, deg, L, or the diffusion operator.
17
+ * ``DynamicsEngine`` holds references to a ``DiffusionOperator`` and an
18
+ ``AttentionOperator``; it does NOT rebuild either per step or per call.
19
+ * ``FactoredDiffusion`` is the default for sparse mode — it never forms L.
20
+ ``SpectralDiffusion`` and ``SimpleDiffusion`` use L when required.
21
+ * ``EnergyTracker`` is optional and zero-cost when disabled.
22
+
23
+ Full dynamics loop (Section 4.1 of spec)
24
+ -----------------------------------------
25
+ for t in range(T):
26
+ Q = diffusion(Q)
27
+ K = diffusion(K)
28
+ Q = attention(Q, K, V)
29
+
30
+ Each iteration costs:
31
+ diffusion : O(kNd) with FactoredDiffusion + sparse W
32
+ attention : O(N²d) dense OR O(kNd) graph mode
33
+
34
+ Complexity (N patterns, d features, T steps, k graph neighbours)
35
+ -----------------------------------------------------------------
36
+ GraphCache.get build : O(N²d) sim-matrix + O(N²) kNN + O(N³) if spectral
37
+ GraphCache.get hit : O(Nd) input checksum for the cache key; no graph work
38
+ DynamicsEngine.run : O(T * kNd) sparse factored diffusion + graph attn
39
+ O(T * N^2d) dense diffusion + dense attn
40
+ EnergyTracker.step : O(N²d)
41
+ """
42
+
43
+ from __future__ import annotations
44
+
45
+ from dataclasses import dataclass, field
46
+ from typing import Dict, List, Optional, Tuple
47
+
48
+ import torch
49
+ from torch import Tensor
50
+
51
+ from .attention_operator import AttentionOperator
52
+ from .diffusion import DiffusionOperator, FactoredDiffusion
53
+ from .graph.builder import GraphBuilder
54
+ from .graph.laplacian_builder import LaplacianBuilder
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Config dataclass
59
+ # ---------------------------------------------------------------------------
60
+
61
@dataclass
class DiffusionConfig:
    """
    One flat, serialisable bundle of settings for the diffusion system.

    Every field has a default, so ``DiffusionConfig()`` is a valid config.

    Attributes:
        eta:                      Diffusion strength.
        beta:                     Hopfield scaling / attention
                                  inverse-temperature.
        steps:                    Number of diffusion+attention iterations (T).
        diffusion_mode:           'simple', 'iterative', 'spectral', or
                                  'factored'.
        attention_mode:           'dense' (O(N^2)) or 'graph' (O(kN)).
        k_neighbors:              kNN graph degree.
        use_normalized_laplacian: Use the symmetric-normalised Laplacian.
        use_sparse:               Store adjacency as sparse_coo.
        diffuse_key:              Diffuse the stored patterns (keys).
        diffuse_query:            Diffuse the state patterns (queries).
        use_logit_diffusion:      Also smooth post-softmax attention weights.
        logit_eta:                eta for logit-level diffusion; ``None``
                                  means "fall back to ``eta``".
        adaptive_eta:             Scale eta by attention entropy.
        adaptive_temperature:     Sigmoid temperature for adaptive eta.
        adaptive_threshold:       Entropy midpoint for adaptive eta.
        cache_graph:              Cache the graph across forward passes.
        energy_stop_tol:          Early-stop when |Delta E| < tol
                                  (0 disables early stopping).
    """

    # --- core dynamics ---
    eta: float = 0.1
    beta: float = 1.0
    steps: int = 3

    # --- operator selection ---
    diffusion_mode: str = "factored"
    attention_mode: str = "dense"

    # --- graph construction ---
    k_neighbors: int = 5
    use_normalized_laplacian: bool = True
    use_sparse: bool = False

    # --- what gets diffused ---
    diffuse_key: bool = True
    diffuse_query: bool = False
    use_logit_diffusion: bool = False
    logit_eta: Optional[float] = None

    # --- entropy-gated eta ---
    adaptive_eta: bool = False
    adaptive_temperature: float = 5.0
    adaptive_threshold: float = 1.0

    # --- runtime behaviour ---
    cache_graph: bool = True
    energy_stop_tol: float = 0.0

    def to_dict(self) -> Dict[str, object]:
        """Return the config as a plain ``{field name: value}`` dict."""
        from dataclasses import asdict

        as_mapping = asdict(self)
        return as_mapping
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Graph Cache
111
+ # ---------------------------------------------------------------------------
112
+
113
class GraphCache:
    """
    Builds and caches all graph objects for a given pattern tensor.

    Responsibility: graph construction and operator lifecycle — nothing else.

    Exactly one entry is cached: the objects built for the most recent input.
    An input is recognised by the composite key
    ``(data_ptr, shape, stride, dtype, device, sum of elements)``.

    Cached objects:
        * W           — (N, N) adjacency (dense or sparse_coo)
        * deg         — (N,) degree vector
        * adj_indices — (N, k) kNN neighbor indices (needed by graph attention)
        * L           — (N, N) Laplacian (``None`` in 'factored' mode)
        * op          — Precomputed DiffusionOperator

    Call ``invalidate()`` to force a full rebuild on the next ``get()``.

    Args:
        config: DiffusionConfig controlling graph and diffusion settings.
    """

    def __init__(self, config: DiffusionConfig) -> None:
        self._cfg = config
        self._graph_builder = GraphBuilder(
            k=config.k_neighbors, use_sparse=config.use_sparse
        )
        self._lap_builder = LaplacianBuilder(
            normalized=config.use_normalized_laplacian
        )
        self._reset()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(
        self, patterns: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], DiffusionOperator]:
        """
        Return (W, deg, adj_indices, L, op) for the given patterns.

        With ``cache_graph`` disabled the objects are rebuilt on every call
        and no cache key is computed at all.

        With ``cache_graph`` enabled:
          * hit  — one reduction over ``patterns`` (the checksum in the key),
                   then the cached objects are returned as-is;
          * miss — everything is rebuilt and replaces the cached entry.

        Args:
            patterns: (N, d) representative patterns.

        Returns:
            W:           (N, N) adjacency.
            deg:         (N,) degree vector.
            adj_indices: (N, k) neighbor indices for graph attention.
            L:           (N, N) Laplacian — None for 'factored' mode.
            op:          Precomputed DiffusionOperator.
        """
        if not self._cfg.cache_graph:
            # Nothing will be stored, so skip the checksum (a full reduction
            # plus a host sync via float()).
            return self._build(patterns)

        cache_key = self._make_key(patterns)
        if cache_key == self._cached_ptr:
            return (
                self._cached_W,
                self._cached_deg,
                self._cached_adj,
                self._cached_L,
                self._cached_op,
            )

        W, deg, adj_idx, L, op = self._build(patterns)

        self._cached_ptr = cache_key
        self._cached_W = W
        self._cached_deg = deg
        self._cached_adj = adj_idx
        self._cached_L = L
        self._cached_op = op

        return W, deg, adj_idx, L, op

    def invalidate(self) -> None:
        """Force a full rebuild on the next call to ``get``."""
        self._reset()

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _make_key(patterns: Tensor) -> Tuple[object, ...]:
        """
        Build the composite cache key for ``patterns``.

        The data pointer alone is not enough: PyTorch can hand the same
        address to a different tensor once the original is freed, hence the
        shape and checksum (BUG-4 fix). Stride is included because a
        transposed view of a square tensor shares pointer, shape and sum with
        its base while holding different rows. dtype and device are included
        so reinterpretations of the same address do not collide.
        """
        return (
            patterns.data_ptr(),
            tuple(patterns.shape),
            tuple(patterns.stride()),
            patterns.dtype,
            patterns.device,
            float(patterns.sum()),
        )

    def _reset(self) -> None:
        """Drop the cached entry so the next ``get`` rebuilds everything."""
        # Holds the composite key from ``_make_key`` (name kept for
        # compatibility), or None when nothing is cached.
        self._cached_ptr: Optional[Tuple[object, ...]] = None
        self._cached_W: Optional[Tensor] = None
        self._cached_deg: Optional[Tensor] = None
        self._cached_adj: Optional[Tensor] = None
        self._cached_L: Optional[Tensor] = None
        self._cached_op: Optional[DiffusionOperator] = None

    def _build(
        self, patterns: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], DiffusionOperator]:
        """
        Build (W, deg, adj_indices, L, op) from scratch for ``patterns``.

        The graph is built from a detached float32 copy of the input, so no
        autograd history flows into the graph objects.
        """
        cfg = self._cfg
        X = patterns.detach().float()  # (N, d)

        # All graph construction through GraphBuilder (single responsibility)
        W, deg, adj_idx = self._graph_builder.build(X)

        # Move results next to the input. deg is also cast to the input dtype.
        # NOTE(review): W is moved to the device only, so it keeps whatever
        # dtype GraphBuilder returned — confirm this is intended when
        # ``patterns`` is not float32.
        W = W.to(device=patterns.device)
        deg = deg.to(dtype=patterns.dtype, device=patterns.device)
        adj_idx = adj_idx.to(device=patterns.device)

        # FactoredDiffusion: no L required — avoids dense N*N matrix
        if cfg.diffusion_mode == "factored":
            op = FactoredDiffusion(eta=cfg.eta, steps=cfg.steps)
            op.precompute_from_graph(W, deg)
            return W, deg, adj_idx, None, op

        # All other modes need L (LaplacianBuilder — single responsibility)
        L = self._lap_builder.build(W).to(
            dtype=patterns.dtype, device=patterns.device
        )
        op = DiffusionOperator.create(
            mode=cfg.diffusion_mode,
            eta=cfg.eta,
            steps=cfg.steps,
        )
        op.precompute(L)
        return W, deg, adj_idx, L, op
238
+
239
+
240
+ # ---------------------------------------------------------------------------
241
+ # Energy Tracker
242
+ # ---------------------------------------------------------------------------
243
+
244
class EnergyTracker:
    """
    Tracks Hopfield energy across diffusion steps.

        E = -(beta * Q @ K^T).mean() + eta * trace(K^T L K) / N

    Responsibility: measuring energy; providing early-stop signal.

    Each call to ``step`` / ``step_factored`` appends one energy value to the
    history. The early-stop signal compares the two most recent values.

    Args:
        beta: Hopfield scaling factor.
        eta:  Diffusion regularisation strength.
        tol:  Stop if |E_t - E_{t-1}| < tol.  0 disables early stopping.
    """

    def __init__(self, beta: float, eta: float, tol: float = 0.0) -> None:
        self.beta = beta
        self.eta = eta
        self.tol = tol
        self._history: List[float] = []

    @property
    def history(self) -> List[float]:
        """List of per-step energy values (a copy; safe to mutate)."""
        return list(self._history)

    def reset(self) -> None:
        """Forget all recorded energies."""
        self._history.clear()

    @torch.no_grad()
    def step(self, Q: Tensor, K: Tensor, L: Tensor) -> Tuple[float, bool]:
        """
        Compute energy for the current step; return (energy, should_stop).

        The quadratic form is evaluated as ``sum(K ⊙ (L K))``, which equals
        ``trace(K^T L K)`` without materialising the (d, d) product. A
        sparse_coo ``L`` is multiplied with ``torch.sparse.mm``.

        Args:
            Q: (N, d) or (S, B, d) query patterns.
            K: (N, d) or (S, B, d) key patterns (after diffusion at this step).
            L: (N, N) graph Laplacian (dense or sparse).

        Returns:
            energy:      Scalar energy value.
            should_stop: True if early-stop criterion is met.
        """
        Q_2d = self._as_2d(Q)
        K_2d = self._as_2d(K)
        LK = torch.sparse.mm(L, K_2d) if L.is_sparse else L @ K_2d
        quad_form = (K_2d * LK).sum()
        return self._record(Q_2d, K_2d, quad_form)

    @torch.no_grad()
    def step_factored(
        self, Q: Tensor, K: Tensor, W: Tensor, deg: Tensor,
    ) -> Tuple[float, bool]:
        """
        Compute energy without L, using the factored identity:

            tr(K^T L K) = (deg ⊙ ||k_i||²).sum() - (W@K ⊙ K).sum()

        The identity is the expansion of ``L = diag(deg) - W``.
        Enables energy tracking with FactoredDiffusion (L=None).

        Args:
            Q:   (N, d) or (S, B, d) query patterns.
            K:   (N, d) or (S, B, d) key patterns.
            W:   (N, N) adjacency (dense or sparse).
            deg: (N,) degree vector.

        Returns:
            energy:      Scalar energy value.
            should_stop: True if early-stop criterion is met.
        """
        Q_2d = self._as_2d(Q)
        K_2d = self._as_2d(K)
        K_norms_sq = (K_2d * K_2d).sum(dim=-1)  # (N,)
        deg_term = (deg * K_norms_sq).sum()
        WK = torch.sparse.mm(W, K_2d) if W.is_sparse else W @ K_2d
        quad_form = deg_term - (WK * K_2d).sum()
        return self._record(Q_2d, K_2d, quad_form)

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _as_2d(X: Tensor) -> Tensor:
        """Collapse a 3-D tensor to 2-D by averaging over dim 1."""
        return X.mean(dim=1) if X.dim() == 3 else X

    def _record(
        self, Q_2d: Tensor, K_2d: Tensor, quad_form: Tensor
    ) -> Tuple[float, bool]:
        """Combine affinity and smoothness, log the energy, test early stop."""
        affinity = -(self.beta * Q_2d @ K_2d.t()).mean()
        smoothness = self.eta * quad_form / K_2d.shape[0]
        energy = (affinity + smoothness).item()
        self._history.append(energy)

        should_stop = (
            self.tol > 0.0
            and len(self._history) >= 2
            and abs(self._history[-1] - self._history[-2]) < self.tol
        )
        return energy, should_stop
335
+
336
+
337
+ # ---------------------------------------------------------------------------
338
+ # Dynamics Engine
339
+ # ---------------------------------------------------------------------------
340
+
341
+ class DynamicsEngine:
342
+ """
343
+ Core iterative dynamical system: x_{t+1} = Attention(D * x_t).
344
+
345
+ Responsibility: run the T-step loop, alternating diffusion and attention.
346
+ This class NEVER builds or rebuilds the graph, Laplacian, or operators.
347
+ All precomputed objects are injected at construction time.
348
+
349
+ Full loop (spec Section 4.1)::
350
+
351
+ for t in range(T):
352
+ Q = diffusion(Q)
353
+ K = diffusion(K)
354
+ Q = attention(Q, K, V)
355
+
356
+ Each step costs:
357
+ diffusion : O(kNd) factored/sparse or O(N^2 d) dense
358
+ attention : O(kNd) graph mode or O(N^2 d) dense mode
359
+
360
+ The engine also exposes ``run_diffusion`` for single-tensor use (keys or
361
+ queries independently) to maintain backward compatibility with
362
+ ``DiffusedHopfield._associate``.
363
+
364
+ Args:
365
+ diffusion_op: Precomputed DiffusionOperator (must be callable).
366
+ attention_op: AttentionOperator in 'dense' or 'graph' mode.
367
+ Required only for ``run_dynamics``; optional for
368
+ ``run_diffusion`` (backward-compat path).
369
+ steps: Number of dynamics iterations T.
370
+ energy_tracker: Optional EnergyTracker; enables per-step early-stop.
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ diffusion_op: DiffusionOperator,
376
+ attention_op: Optional[AttentionOperator] = None,
377
+ steps: Optional[int] = None,
378
+ energy_tracker: Optional[EnergyTracker] = None,
379
+ query_diffusion_op: Optional[DiffusionOperator] = None,
380
+ ) -> None:
381
+ self._diff_op = diffusion_op
382
+ self._query_diff_op = query_diffusion_op # separate op for Q; falls back to _diff_op if None
383
+ self._attn_op = attention_op
384
+ self._steps = steps if steps is not None else diffusion_op.steps
385
+ self._tracker = energy_tracker
386
+
387
+ # ------------------------------------------------------------------
388
+ # Full dynamics loop Q_T = run_dynamics(Q, K, V)
389
+ # ------------------------------------------------------------------
390
+
391
+ def run_dynamics(
392
+ self,
393
+ Q: Tensor,
394
+ K: Tensor,
395
+ V: Tensor,
396
+ adj_indices: Optional[Tensor] = None,
397
+ L: Optional[Tensor] = None,
398
+ W: Optional[Tensor] = None,
399
+ deg: Optional[Tensor] = None,
400
+ diffuse_query: bool = True,
401
+ diffuse_key: bool = True,
402
+ ) -> Tuple[Tensor, Tensor]:
403
+ """
404
+ Run the full T-step diffusion-attention loop.
405
+
406
+ Inside the loop — ZERO redundant computation:
407
+ * No graph rebuild
408
+ * No Laplacian recompute
409
+ * No new memory allocation (reuses Q, K in-place conceptually)
410
+
411
+ Energy tracking works with both L-based and factored (W, deg) forms.
412
+
413
+ Args:
414
+ Q: (N, d) or (S, B, d) query patterns.
415
+ K: (N, d) or (S, B, d) key patterns.
416
+ V: (N, d) or (S, B, d) value patterns.
417
+ adj_indices: (N, k) neighbor indices — required for graph attention.
418
+ L: (N, N) Laplacian — for L-based energy tracking.
419
+ W: (N, N) adjacency — for factored energy tracking.
420
+ deg: (N,) degree vector — for factored energy tracking.
421
+ diffuse_query: Whether to diffuse Q each iteration.
422
+ diffuse_key: Whether to diffuse K each iteration.
423
+
424
+ Returns:
425
+ (Q, K): Tuple of updated queries and diffused keys after T steps.
426
+
427
+ Complexity per step:
428
+ diffusion : O(kNd) factored sparse or O(N^2 d) dense
429
+ attention : O(kNd) graph or O(N^2 d) dense
430
+ """
431
+ if self._attn_op is None:
432
+ raise RuntimeError(
433
+ "DynamicsEngine.run_dynamics requires an AttentionOperator. "
434
+ "Pass attention_op= at construction time."
435
+ )
436
+
437
+ # Clean slate for energy tracking each dynamics run
438
+ if self._tracker is not None:
439
+ self._tracker.reset()
440
+
441
+ use_energy_L = self._tracker is not None and L is not None
442
+ use_energy_fac = (
443
+ self._tracker is not None
444
+ and W is not None and deg is not None
445
+ and not use_energy_L
446
+ )
447
+
448
+ for _ in range(self._steps):
449
+ if diffuse_key:
450
+ K = self._diff_op(K)
451
+ if diffuse_query:
452
+ Q = (self._query_diff_op if self._query_diff_op is not None
453
+ else self._diff_op)(Q)
454
+
455
+ # Attention update — dense O(N^2) or graph O(kN)
456
+ Q = self._attn_op(Q, K, V, adj_indices=adj_indices)
457
+
458
+ # Optional: energy check for early stop
459
+ if use_energy_L:
460
+ _, stop = self._tracker.step(Q, K, L)
461
+ if stop:
462
+ break
463
+ elif use_energy_fac:
464
+ _, stop = self._tracker.step_factored(Q, K, W, deg)
465
+ if stop:
466
+ break
467
+
468
+ return Q, K
469
+
470
+ # ------------------------------------------------------------------
471
+ # Single-tensor diffusion (backward-compatible)
472
+ # ------------------------------------------------------------------
473
+
474
+ def run_diffusion(
475
+ self,
476
+ X: Tensor,
477
+ L: Optional[Tensor] = None,
478
+ Q_ref: Optional[Tensor] = None,
479
+ ) -> Tensor:
480
+ """
481
+ Apply the diffusion operator for ``steps`` iterations to a single tensor.
482
+
483
+ Preserves the original API used by ``DiffusedHopfield._associate``
484
+ for feature-level key/query diffusion before Hopfield attention.
485
+
486
+ No graph rebuild per step.
487
+
488
+ Args:
489
+ X: (N, d) or (S, B, d) patterns to diffuse.
490
+ L: (N, N) Laplacian — required only for energy tracking.
491
+ Q_ref: (N, d) query reference — required only for energy tracking.
492
+
493
+ Returns:
494
+ X': Diffused patterns, same shape as X.
495
+ """
496
+ use_tracking = (
497
+ self._tracker is not None
498
+ and L is not None
499
+ and Q_ref is not None
500
+ and X.dim() == 2
501
+ )
502
+
503
+ if use_tracking:
504
+ for _ in range(self._steps):
505
+ X = self._diff_op(X)
506
+ _, stop = self._tracker.step(Q_ref, X, L)
507
+ if stop:
508
+ break
509
+ else:
510
+ X = self._diff_op(X)
511
+
512
+ return X
513
+
514
+ # ------------------------------------------------------------------
515
+ # Adaptive eta utility
516
+ # ------------------------------------------------------------------
517
+
518
+ def compute_adaptive_eta(
519
+ self, logits: Tensor, base_eta: float,
520
+ temperature: float = 5.0, threshold: float = 1.0,
521
+ ) -> float:
522
+ """
523
+ Compute entropy-gated adaptive eta.
524
+
525
+ eta_eff = base_eta * sigmoid(temperature * (H(attn) - threshold))
526
+
527
+ Args:
528
+ logits: (..., S) raw attention logits or weights.
529
+ base_eta: Maximum eta value.
530
+ temperature: Sigmoid steepness.
531
+ threshold: Entropy midpoint.
532
+
533
+ Returns:
534
+ eta_eff: Scalar float in [0, base_eta].
535
+ """
536
+ with torch.no_grad():
537
+ probs = torch.softmax(logits, dim=-1)
538
+ H = -(probs * (probs + 1e-9).log()).sum(dim=-1).mean()
539
+ scale = torch.sigmoid(temperature * (H - threshold))
540
+ return base_eta * scale.item()