difflayers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,395 @@
1
+ """
2
+ Graph diffusion operators for Graph-Regularized Hopfield attention.
3
+
4
+ Responsibility: Apply graph diffusion to pattern tensors.
5
+
6
+ Architecture: Strategy pattern — each mode is a subclass of
7
+ ``DiffusionOperator`` that precomputes an operator from (L, eta) and
8
+ applies it in __call__. The operator D = I - eta*L is precomputed once
9
+ and reused across all diffusion steps and forward passes.
10
+
11
+ Four diffusion modes:
12
+
13
+ 1. **simple** — One explicit Euler step using precomputed D = I - eta*L.
14
+ 2. **iterative** — T applications of D, giving (I - eta*L)^T X.
15
+ 3. **spectral** — Exact heat-kernel via eigendecomposition:
16
+ X' = U exp(-eta*Λ) U^T X.
17
+ 4. **factored** — Memory-optimal Laplacian-free form:
18
+ x' = (1 - η·deg) ⊙ x + η · W @ x
19
+ Stores only (W_sparse, deg). O(kN) memory.
20
+
21
+ All operators support 2-D (N, d) and 3-D (S, B, d) inputs.
22
+
23
+ Backward-compatible functional API (``apply_diffusion``) is preserved.
24
+
25
+ Complexity:
26
+ DiffusionOperator.precompute : O(N²) [simple/iterative]
27
+ O(N³) [spectral — eigendecomp]
28
+ DiffusionOperator.__call__ : O(N²) [dense matmul, all modes]
29
+ With sparse L and sparse matmul: O(kN) per step [simple/iterative]
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ from abc import ABC, abstractmethod
35
+ from typing import Optional
36
+
37
+ import torch
38
+ from torch import Tensor
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Internal helper
43
+ # ---------------------------------------------------------------------------
44
+
45
+ def _matmul(M: Tensor, X: Tensor) -> Tensor:
46
+ """
47
+ M @ X for 2-D or 3-D X, supporting sparse M.
48
+
49
+ Args:
50
+ M: (N, N) dense or sparse_coo operator.
51
+ X: (N, d) or (S, B, d) input.
52
+
53
+ Returns:
54
+ Result of M @ X, same shape as X.
55
+ """
56
+ if X.dim() == 2:
57
+ if M.is_sparse:
58
+ return torch.sparse.mm(M, X)
59
+ return M @ X
60
+ # 3-D: (S, B, d) — reshape to (S, B*d), matmul, reshape back.
61
+ S, B, d = X.shape
62
+ X_flat = X.reshape(S, B * d)
63
+ if M.is_sparse:
64
+ out = torch.sparse.mm(M, X_flat)
65
+ else:
66
+ out = M @ X_flat
67
+ return out.reshape(S, B, d)
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # Strategy base
72
+ # ---------------------------------------------------------------------------
73
+
74
class DiffusionOperator(ABC):
    """
    Abstract strategy interface for graph diffusion operators.

    A concrete operator supplies two hooks:

    * ``_build_operator(L)`` — derive whatever it needs from the Laplacian;
    * ``__call__(X)``        — apply the cached operator to a pattern tensor.

    ``precompute`` stores the result of ``_build_operator`` in ``self._op``.
    Nothing in this class rebuilds it; it only changes when ``precompute``
    is invoked again.

    Args:
        eta:   Diffusion strength / time.
        steps: Number of iterations (used by the multi-step operators).
    """

    def __init__(self, eta: float, steps: int = 1) -> None:
        self.eta = eta
        self.steps = steps
        # Cached operator; stays None until precompute() has run.
        self._op: Optional[Tensor] = None

    def precompute(self, L: Tensor) -> "DiffusionOperator":
        """
        Build the operator for Laplacian ``L`` and cache it on the instance.

        Args:
            L: (N, N) graph Laplacian (dense or sparse_coo).

        Returns:
            self, so calls can be chained:
            ``op = SimpleDiffusion(eta).precompute(L)``.
        """
        built = self._build_operator(L)
        self._op = built
        return self

    @abstractmethod
    def _build_operator(self, L: Tensor) -> Tensor:
        """Return the operator derived from L that ``precompute`` caches."""

    @abstractmethod
    def __call__(self, X: Tensor) -> Tensor:
        """
        Diffuse the pattern tensor X with the cached operator.

        Args:
            X: (N, d) or (S, B, d) input patterns.

        Returns:
            Diffused tensor with the same shape as X.
        """

    def _check_ready(self) -> None:
        """Raise RuntimeError when no operator has been cached yet."""
        if self._op is not None:
            return
        raise RuntimeError(
            f"{type(self).__name__}.precompute(L) must be called before __call__."
        )

    # ------------------------------------------------------------------
    # Factory
    # ------------------------------------------------------------------

    @staticmethod
    def create(mode: str, eta: float, steps: int = 3) -> "DiffusionOperator":
        """
        Instantiate the operator class registered under ``mode``.

        Args:
            mode:  'simple', 'iterative', 'spectral' or 'factored'.
            eta:   Diffusion strength.
            steps: Iteration count forwarded to the operator's constructor.

        Returns:
            A new operator on which ``precompute`` has not been called.

        Raises:
            ValueError: if ``mode`` is not one of the registered names.
        """
        registry = {
            "simple": SimpleDiffusion,
            "iterative": IterativeDiffusion,
            "spectral": SpectralDiffusion,
            "factored": FactoredDiffusion,
        }
        if mode not in registry:
            raise ValueError(
                f"Unknown diffusion mode '{mode}'. Choose from {list(registry.keys())}."
            )
        operator_cls = registry[mode]
        return operator_cls(eta=eta, steps=steps)
161
+
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # Concrete strategies
165
+ # ---------------------------------------------------------------------------
166
+
167
class SimpleDiffusion(DiffusionOperator):
    """
    Single explicit Euler step: X' = D @ X with D = I - eta*L.

    D is built once in ``precompute`` as a dense matrix; each call performs
    one matrix product against it.

    Complexity:
        precompute : O(N²) — the stored D is dense
        __call__   : O(N²d)
    """

    def _build_operator(self, L: Tensor) -> Tensor:
        # Densify first: the cached D = I - eta*L is always a dense (N, N).
        laplacian = L.to_dense() if L.is_sparse else L
        identity = torch.eye(
            laplacian.shape[0],
            dtype=laplacian.dtype,
            device=laplacian.device,
        )
        return identity - self.eta * laplacian

    def __call__(self, X: Tensor) -> Tensor:
        self._check_ready()
        # eta == 0 means D == I, so the input is handed back untouched.
        return X if self.eta == 0.0 else _matmul(self._op, X)
190
+
191
+
192
class IterativeDiffusion(DiffusionOperator):
    """
    Multi-step explicit Euler: X' = D^t @ X, D = I - eta*L, t <= steps.

    D is precomputed once and applied up to ``steps`` times per call.

    Over-smoothing guard: iteration stops as soon as
    ||X_t - X_{t-1}|| < early_stop_tol; the iterate X_t that triggered the
    stop is the one returned.

    Args:
        eta:            Diffusion strength.
        steps:          Maximum number of applications of D per call.
        early_stop_tol: Norm threshold below which iteration stops.

    Complexity:
        precompute : O(N²)
        __call__   : O(steps * N²d)
    """

    def __init__(self, eta: float, steps: int = 3,
                 early_stop_tol: float = 1e-6) -> None:
        super().__init__(eta=eta, steps=steps)
        self.early_stop_tol = early_stop_tol

    def _build_operator(self, L: Tensor) -> Tensor:
        # D = I - eta*L, stored dense.
        L_dense = L.to_dense() if L.is_sparse else L
        N = L_dense.shape[0]
        I = torch.eye(N, dtype=L_dense.dtype, device=L_dense.device)
        return I - self.eta * L_dense

    def __call__(self, X: Tensor) -> Tensor:
        self._check_ready()
        if self.eta == 0.0 or self.steps == 0:
            return X
        for _ in range(self.steps):
            X_new = _matmul(self._op, X)
            # .item() pulls the norm to the host so the check is a plain
            # Python comparison.
            converged = torch.norm(X_new - X).item() < self.early_stop_tol
            # Keep the iterate that was just computed even when stopping,
            # instead of discarding it and returning the previous one.
            X = X_new
            if converged:
                break
        return X
228
+
229
+
230
class SpectralDiffusion(DiffusionOperator):
    """
    Heat-kernel diffusion: X' = H @ X,
    H = U exp(-eta*Λ) U^T  (eigendecomposition of L).

    The kernel is built once from a symmetric eigendecomposition and then
    applied with a single matmul per call.

    Complexity:
        precompute : O(N³) — full symmetric eigendecomposition
        __call__   : O(N²d)
    """

    def _build_operator(self, L: Tensor) -> Tensor:
        # The decomposition is always carried out on a dense float32 copy.
        L_dense = (L.to_dense() if L.is_sparse else L).float()
        eigenvalues, U = torch.linalg.eigh(L_dense)     # (N,), (N, N)
        decay = torch.exp(-self.eta * eigenvalues)      # (N,)
        # U @ diag(decay) only rescales column j of U by decay[j];
        # broadcasting does that without building an N×N diagonal matrix
        # or paying for an extra dense matmul.
        return (U * decay) @ U.t()                      # (N, N)

    def __call__(self, X: Tensor) -> Tensor:
        self._check_ready()
        if self.eta == 0.0:
            return X
        # The kernel was built in float32; align it with the input.
        H = self._op.to(dtype=X.dtype, device=X.device)
        return _matmul(H, X)
255
+
256
+
257
class FactoredDiffusion(DiffusionOperator):
    """
    Diffusion in factored form — the operator D = I - η*L is never formed.

    The instance keeps only:
        * W   — (N, N) adjacency, dense or sparse_coo
        * deg — (N,)   degree vector

    and applies one diffusion step as

        x' = (1 - η·deg) ⊙ x  +  η · W @ x

    which equals (I - η*L) @ x when L = D - A (unnormalised Laplacian).

    Initialisation
    --------------
    Preferred path — ``precompute_from_graph(W, deg)``: stores the given
    tensors as-is, so a sparse W stays sparse.

    Fallback path  — ``precompute(L)``: recovers (W, deg) from L.  Note that
    this path densifies L, so the recovered W is a dense (N, N) tensor.

    Args:
        eta:   Diffusion strength η.
        steps: Number of applications per forward call.

    Time complexity: O(kNd) per step (sparse W) or O(N²d) (dense W)
    """

    def __init__(self, eta: float, steps: int = 1) -> None:
        super().__init__(eta=eta, steps=steps)
        self._W: Optional[Tensor] = None    # sparse or dense adjacency
        self._deg: Optional[Tensor] = None  # (N,) degree

    def precompute_from_graph(
        self, W: Tensor, deg: Tensor
    ) -> "FactoredDiffusion":
        """
        Initialise directly from adjacency W and degree vector — no L needed.

        Args:
            W:   (N, N) adjacency matrix (dense or sparse_coo).
            deg: (N,) degree vector.

        Returns:
            self — enables chaining.
        """
        self._W = W
        self._deg = deg
        self._op = True     # sentinel: marks the operator as precomputed
        return self

    def _build_operator(self, L: Tensor) -> bool:
        """
        Fallback: reconstruct (W, deg) from an unnormalised Laplacian.

        For L = D - A:  deg = diag(L)  and  A = diag(deg) - L.
        The identity only holds for the unnormalised Laplacian, so an input
        whose largest diagonal entry is <= 1.5 is treated as normalised and
        rejected.  An empty (0, 0) Laplacian skips the check because there
        is no diagonal to inspect.

        Returns:
            True — a sentinel stored in ``_op``; the real state lives in
            ``_W`` and ``_deg``.

        Raises:
            ValueError: if L looks like a normalised Laplacian.
        """
        L_dense = L.to_dense() if L.is_sparse else L
        diag_vals = L_dense.diagonal()
        if diag_vals.numel() > 0:
            max_diag = diag_vals.max().item()
            if max_diag <= 1.5:
                raise ValueError(
                    "FactoredDiffusion.precompute(L) was called with a normalised "
                    "Laplacian (max diagonal value {:.4f} <= 1.5). "
                    "The factored identity A = diag(deg) - L only holds for the "
                    "unnormalised Laplacian. "
                    "Use precompute_from_graph(W, deg) instead, or switch to "
                    "SimpleDiffusion / IterativeDiffusion.".format(max_diag)
                )
        self._deg = diag_vals.clone()                  # (N,)
        self._W = torch.diag(self._deg) - L_dense      # A = D - L
        return True

    def __call__(self, X: Tensor) -> Tensor:
        """
        Apply ``steps`` factored diffusion steps to X.

        Args:
            X: (N, d) or (S, B, d) input patterns with S == N.

        Returns:
            Diffused tensor, same shape as X.  X itself is returned when
            eta == 0 or no steps are requested.

        Raises:
            RuntimeError: if neither precompute path has been called.
        """
        if self._W is None:
            raise RuntimeError(
                "FactoredDiffusion: call precompute(L) or "
                "precompute_from_graph(W, deg) before __call__."
            )
        if self.eta == 0.0 or self.steps <= 0:
            return X

        W = self._W
        eta = self.eta
        # (1 - η·deg) does not depend on X, so compute it once per call
        # rather than once per step.
        keep = 1.0 - eta * self._deg                   # (N,)

        if X.dim() == 2:
            scale = keep.unsqueeze(-1)                 # (N, 1)
            for _ in range(self.steps):
                Wx = torch.sparse.mm(W, X) if W.is_sparse else W @ X
                X = scale * X + eta * Wx
            return X

        # 3-D: (S, B, d) where S indexes the N graph nodes.
        S, B, d = X.shape
        scale = keep.view(S, 1, 1)                     # (S, 1, 1)
        for _ in range(self.steps):
            X_flat = X.reshape(S, B * d)
            Wx_flat = (
                torch.sparse.mm(W, X_flat) if W.is_sparse else W @ X_flat
            )
            X = scale * X + eta * Wx_flat.reshape(S, B, d)
        return X
366
+
367
+
368
+ # ---------------------------------------------------------------------------
369
+ # Backward-compatible functional API
370
+ # ---------------------------------------------------------------------------
371
+
372
def apply_diffusion(X: Tensor, L: Tensor, eta: float,
                    mode: str = "simple", steps: int = 3) -> Tensor:
    """
    One-shot functional wrapper around the operator classes.

    Creates the operator for ``mode``, precomputes it from ``L`` and applies
    it to ``X`` in a single call.  Because the operator is rebuilt on every
    invocation, the class-based API is the better choice when the same graph
    is reused.

    Args:
        X:     (N, d) or (S, B, d) input patterns.
        L:     (N, N) graph Laplacian.
        eta:   Diffusion strength / time.
        mode:  'simple', 'iterative', 'spectral' or 'factored'.
        steps: Iteration count forwarded to the operator.

    Returns:
        Diffused tensor, same shape as X.  When eta == 0 the input is
        returned as-is and no operator is constructed.
    """
    if eta == 0.0:
        # Nothing to diffuse; skip operator construction entirely.
        return X
    operator = DiffusionOperator.create(mode=mode, eta=eta, steps=steps)
    operator.precompute(L)
    diffused = operator(X)
    return diffused