difflayers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,252 @@
1
+ import torch
2
+
3
+ from math import ceil
4
+ from torch.utils.data import Dataset
5
+ from typing import Dict, Optional, Sequence, Tuple, Union
6
+
7
+
8
class BitPatternSet(Dataset):
    """
    Binary multiple instance learning (MIL) data set comprising bit patterns as instances,
    with implanted bit patterns unique to one of the classes.

    All randomness is drawn from private ``torch.Generator`` objects seeded with
    ``seed_signals`` / ``seed_data``. Construction is therefore reproducible and does
    not reseed PyTorch's global random number generator.
    """

    def __init__(self, num_bags: int, num_instances: int, num_signals: int, num_signals_per_bag: int = 1,
                 fraction_targets: float = 0.5, num_bits: int = 8, dtype: torch.dtype = torch.float32,
                 seed_signals: int = 43, seed_data: int = 44):
        """
        Create new binary bit pattern data set conforming to the specified properties.

        :param num_bags: amount of bags
        :param num_instances: amount of instances per bag
        :param num_signals: amount of unique signals used to distinguish both classes
        :param num_signals_per_bag: amount of unique signals to be implanted per bag
        :param fraction_targets: fraction of targets
        :param num_bits: amount of bits per instance
        :param dtype: data type of instances
        :param seed_signals: random seed used to generate the signals of the data set (excl. samples)
        :param seed_data: random seed used to generate the samples of the data set (excl. signals)
        :raises AssertionError: if any argument violates the constraints checked below
        """
        super(BitPatternSet, self).__init__()
        assert (type(num_bags) == int) and (num_bags > 0), r'"num_bags" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        assert (type(num_signals) == int) and (num_signals > 0), r'"num_signals" must be a positive integer!'
        assert (type(num_signals_per_bag) == int) and (num_signals_per_bag >= 0) and (
                num_signals_per_bag <= num_instances), \
            r'"num_signals_per_bag" must be a non-negative integer not larger than "num_instances"!'
        assert (type(fraction_targets) == float) and (fraction_targets > 0) and (
                fraction_targets < 1), r'"fraction_targets" must be in interval (0; 1)!'
        assert (type(num_bits) == int) and (num_bits > 0), r'"num_bits" must be a positive integer!'
        # There are 2**num_bits - 1 non-zero patterns. At least one must remain free for
        # non-signal instances, hence the strict inequality.
        assert ((2 ** num_bits) - 1) > num_signals, r'"num_signals" must be smaller than "2 ** num_bits - 1"!'
        assert type(seed_signals) == int, r'"seed_signals" must be an integer!'
        assert type(seed_data) == int, r'"seed_data" must be an integer!'

        self.__num_bags = num_bags
        self.__num_instances = num_instances
        self.__num_signals = num_signals
        self.__num_signals_per_bag = num_signals_per_bag
        self.__fraction_targets = fraction_targets
        # At least one (and at most all) bags are positive.
        self.__num_targets = min(self.__num_bags, max(1, ceil(self.__num_bags * self.__fraction_targets)))
        self.__num_bits = num_bits
        self.__dtype = dtype
        self.__seed_signals = seed_signals
        self.__seed_data = seed_data
        self.__data, self.__targets, self.__signals = self._generate_bit_pattern_set()

    def __len__(self) -> int:
        """
        Fetch amount of bags.

        :return: amount of bags
        """
        return self.__num_bags

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific bag.

        :param item_index: specific bag to fetch
        :return: specific bag as dictionary of tensors, cast to the configured dtype
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_bit_pattern_set(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate underlying bit pattern data set.

        Private generators are used instead of ``torch.random.manual_seed``. This keeps
        the caller's global RNG state untouched.

        :return: tuple containing generated bags, targets and signals
        """
        # Generate unique, non-zero signal patterns.
        signal_generator = torch.Generator().manual_seed(self.__seed_signals)
        generated_signals = torch.randint(
            low=0, high=2, size=(self.__num_signals, self.__num_bits), generator=signal_generator)
        while True:
            generated_signals = torch.unique(input=generated_signals, dim=0)
            generated_signals = generated_signals[generated_signals.sum(dim=1) != 0]
            missing_signals = self.__num_signals - generated_signals.shape[0]
            if missing_signals <= 0:
                break
            generated_signals = torch.cat(tensors=(generated_signals, torch.randint(
                low=0, high=2, size=(missing_signals, self.__num_bits), generator=signal_generator)), dim=0)

        # Generate data and target tensors (the first `num_targets` bags are positive).
        data_generator = torch.Generator().manual_seed(self.__seed_data)
        generated_data = torch.randint(
            low=0, high=2, size=(self.__num_bags, self.__num_instances, self.__num_bits), generator=data_generator)
        generated_targets = torch.zeros(size=(self.__num_bags,), dtype=generated_data.dtype)
        generated_targets[:self.__num_targets] = 1

        # Re-sample invalid instances (all-zero or coinciding with a signal) until none remain.
        while True:
            is_zero = generated_data.sum(dim=2) == 0
            is_signal = torch.stack(
                [(generated_data == signal).all(dim=2) for signal in generated_signals]).any(dim=0)
            invalid_instances = is_zero | is_signal
            num_invalid = int(invalid_instances.sum())
            if num_invalid == 0:
                break
            generated_data[invalid_instances] = torch.randint(
                low=0, high=2, size=(num_invalid, self.__num_bits), generator=data_generator)

        # Implant signals into the positive bags at distinct instance positions.
        for data_index in range(self.__num_targets):
            implantation_indices = set()
            for _ in range(self.__num_signals_per_bag):
                while True:
                    current_implantation_index = int(torch.randint(
                        low=0, high=generated_data.shape[1], size=(1,), generator=data_generator))
                    if current_implantation_index not in implantation_indices:
                        implantation_indices.add(current_implantation_index)
                        break
                current_signal_index = int(torch.randint(
                    low=0, high=generated_signals.shape[0], size=(1,), generator=data_generator))
                generated_data[data_index, current_implantation_index] = generated_signals[current_signal_index]

        return generated_data, generated_targets, generated_signals

    @property
    def num_bags(self) -> int:
        """Amount of bags."""
        return self.__num_bags

    @property
    def num_instances(self) -> int:
        """Amount of instances per bag."""
        return self.__num_instances

    @property
    def num_bits(self) -> int:
        """Amount of bits per instance."""
        return self.__num_bits

    @property
    def num_targets_high(self) -> int:
        """Amount of positive bags."""
        return self.__num_targets

    @property
    def num_targets_low(self) -> int:
        """Amount of negative bags."""
        return self.__num_bags - self.__num_targets

    @property
    def num_signals(self) -> int:
        """Amount of unique signal patterns."""
        return self.__num_signals

    @property
    def num_signals_per_bag(self) -> int:
        """Amount of signals implanted per positive bag."""
        return self.__num_signals_per_bag

    @property
    def initial_seed(self) -> int:
        """Seed used to generate the signal patterns."""
        return self.__seed_signals

    @property
    def bags(self) -> torch.Tensor:
        """Copy of all bags (integer-valued bit tensor)."""
        return self.__data.clone()

    @property
    def targets(self) -> torch.Tensor:
        """Copy of all bag targets."""
        return self.__targets.clone()

    @property
    def signals(self) -> torch.Tensor:
        """Copy of the signal patterns."""
        return self.__signals.clone()
168
+
169
+
170
class LatchSequenceSet(Dataset):
    """
    Latch data set comprising patterns as one-hot encoded instances.

    Characters 0 and 1 are reserved for the signal placed at the first position of each
    sample; the sample's target equals that signal. All remaining positions hold
    characters from ``[2, num_characters)``. Randomness comes from a private
    ``torch.Generator``, so the global RNG is not reseeded.
    """

    def __init__(self, num_samples: int, num_instances: int = 20, num_characters: int = 6,
                 dtype: torch.dtype = torch.float32, seed: int = 43):
        """
        Create new latch sequence data set conforming to the specified properties.

        :param num_samples: amount of samples
        :param num_instances: amount of instances per sample
        :param num_characters: amount of different characters (at least 3: two signal characters
                               plus at least one filler character)
        :param dtype: data type of samples
        :param seed: random seed used to generate the samples of the data set
        :raises AssertionError: if any argument violates the constraints checked below
        """
        super(LatchSequenceSet, self).__init__()
        assert (type(num_samples) == int) and (num_samples > 0), r'"num_samples" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        # Filler characters are drawn from [2, num_characters); an empty range would make
        # torch.randint fail with an opaque error, so require at least one filler character.
        assert (type(num_characters) == int) and (num_characters > 2), \
            r'"num_characters" must be an integer larger than 2!'
        assert type(seed) == int, r'"seed" must be an integer!'

        self.__num_samples = num_samples
        self.__num_instances = num_instances
        self.__num_characters = num_characters
        self.__dtype = dtype
        self.__seed = seed
        self.__data, self.__targets = self._generate_latch_sequences()

    def __len__(self) -> int:
        """
        Fetch amount of samples.

        :return: amount of samples
        """
        return self.__num_samples

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific sample.

        :param item_index: specific sample to fetch
        :return: specific sample as dictionary of tensors, cast to the configured dtype
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_latch_sequences(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate underlying latch sequence data set.

        :return: tuple containing one-hot encoded data of shape
                 (num_samples, num_instances, num_characters) and targets of shape (num_samples,)
        """
        generator = torch.Generator().manual_seed(self.__seed)

        # Filler characters in [2, num_characters); the signal (0 or 1) overwrites position 0.
        generated_data = torch.randint(
            low=2, high=self.__num_characters, size=(self.__num_samples, self.__num_instances), generator=generator)
        generated_signal = torch.randint(low=0, high=2, size=(self.__num_samples,), generator=generator)
        generated_data[:, 0] = generated_signal
        generated_data = torch.nn.functional.one_hot(input=generated_data, num_classes=self.__num_characters)

        return generated_data, generated_signal

    @property
    def num_samples(self) -> int:
        """Amount of samples."""
        return self.__num_samples

    @property
    def num_instances(self) -> int:
        """Amount of instances per sample."""
        return self.__num_instances

    @property
    def num_characters(self) -> int:
        """Amount of different characters."""
        return self.__num_characters

    @property
    def initial_seed(self) -> int:
        """Seed used to generate the samples."""
        return self.__seed

    @property
    def targets(self) -> torch.Tensor:
        """Copy of all sample targets."""
        return self.__targets.clone()
@@ -0,0 +1,427 @@
1
+ """
2
+ Graph-Regularized (Diffusion-Augmented) Hopfield Attention.
3
+
4
+ Core idea
5
+ ---------
6
+ Core model / dynamics loop
7
+ --------------------------
8
+
9
+ x_{t+1} = Attention(D · x_t) — spec Section 0
10
+
11
+ In practice the stored key patterns (and optionally queries) are diffused
12
+ before the Hopfield attention layer:
13
+
14
+ for t in range(T): # via DynamicsEngine.run_diffusion
15
+ K' = D @ K # D = I - η*L (or factored form)
16
+ Q' = D @ Q
17
+ Attention = softmax(β Q' K'ᵀ) V — dense O(N²) or graph O(kN)
18
+
19
+ Optionally, post-softmax attention weights are also smoothed over the graph
20
+ (logit-level diffusion):
21
+ weights' = diffuse(weights, L_K, η_logit)
22
+ output = weights' @ V
23
+
24
+ Four diffusion modes (implemented in ``DiffusionOperator`` subclasses):
25
+ * **factored** — x' = (1-η·deg)⊙x + η·W@x. O(kNd), no L formed.
26
+ * **simple** — D = I - η*L, applied once. O(N²d).
27
+ * **iterative** — D^T X. Deeper smoothing; early-stop guard. O(T·N²d).
28
+ * **spectral** — H = U exp(-η Λ) U^T. Exact heat kernel. O(N³) precompute.
29
+
30
+ Two attention modes (implemented in ``AttentionOperator``):
31
+ * **dense** (default) — full softmax(β Q Kᵀ) V. O(N²d). Exact baseline.
32
+ * **graph** — attend only to kNN neighbors. O(kNd). Faster.
33
+
34
+ Design
35
+ ------
36
+ ``DiffusedHopfield`` is a drop-in replacement for ``Hopfield``. It inherits
37
+ the full constructor and only overrides ``_associate`` to splice in diffusion.
38
+
39
+ Internally it delegates all graph/diffusion work to:
40
+ * ``GraphCache`` — builds and caches (W, deg, adj_idx, L, op) once.
41
+ * ``DynamicsEngine`` — runs T-step loop; no rebuild inside loop.
42
+ * ``AttentionOperator`` — dense or graph-constrained attention.
43
+ * ``EnergyTracker`` — (optional) per-step energy + early-stop.
44
+
45
+ This satisfies Open-Closed: new diffusion modes are added by subclassing
46
+ ``DiffusionOperator`` in ``diffusion.py``; no changes needed here.
47
+ """
48
+
49
+ from typing import Dict, Optional, Tuple, Union
50
+
51
+ import copy
52
+
53
+ import torch
54
+ import torch.nn.functional as F
55
+ from torch import Tensor
56
+
57
+ from . import Hopfield
58
+ from .attention_operator import AttentionOperator
59
+ from .diffusion import DiffusionOperator
60
+ from .dynamics_engine import DiffusionConfig, DynamicsEngine, EnergyTracker, GraphCache
61
+
62
+
63
class DiffusedHopfield(Hopfield):
    """
    Hopfield association module augmented with Laplacian graph diffusion.

    Adds diffusion-specific parameters on top of the full ``Hopfield`` API.
    All graph construction, caching, and diffusion are delegated to
    ``GraphCache`` and ``DynamicsEngine``; this class only orchestrates the
    pre-processing hook inside ``_associate``.

    New parameters (beyond the standard Hopfield signature)
    -------------------------------------------------------
    eta                  : float — Diffusion strength η. Default: 0.1.
    k_neighbors          : int   — kNN graph degree. Default: 5.
    use_normalized_laplacian : bool — Symmetric-normalised L. Default: True.
    diffuse_query        : bool  — Diffuse query patterns. Default: False.
    diffuse_key          : bool  — Diffuse key patterns. Default: True.
    diffusion_mode       : str   — 'factored'|'simple'|'iterative'|'spectral'.
    diffusion_steps      : int   — Number of outer dynamics iterations (each applies
                                   one diffusion step); also the step count of
                                   operators rebuilt for adaptive/logit diffusion.
    attention_mode       : str   — 'dense' (O(N²)) or 'graph' (O(kN)).
    use_sparse           : bool  — Sparse adjacency for O(kN) products.
    use_logit_diffusion  : bool  — Smooth post-softmax weights.
    logit_eta            : float — η for logit-level diffusion (defaults to ``eta``).
    adaptive_eta         : bool  — Entropy-gated η scaling.
    adaptive_temperature : float — Sigmoid temperature for adaptive η.
    adaptive_threshold   : float — Entropy midpoint for adaptive η.
    cache_graph          : bool  — Cache graph/operator between passes.
    energy_stop_tol      : float — Early-stop tolerance; 0 = off.

    All remaining kwargs are forwarded unchanged to ``Hopfield``.
    """

    def __init__(
        self,
        input_size: Optional[int] = None,
        hidden_size: Optional[int] = None,
        output_size: Optional[int] = None,
        pattern_size: Optional[int] = None,
        num_heads: int = 1,
        scaling: Optional[Union[float, Tensor]] = None,
        update_steps_max: Optional[Union[int, Tensor]] = 0,
        update_steps_eps: Union[float, Tensor] = 1e-4,

        normalize_stored_pattern: bool = True,
        normalize_stored_pattern_affine: bool = True,
        normalize_stored_pattern_eps: float = 1e-5,
        normalize_state_pattern: bool = True,
        normalize_state_pattern_affine: bool = True,
        normalize_state_pattern_eps: float = 1e-5,
        normalize_pattern_projection: bool = True,
        normalize_pattern_projection_affine: bool = True,
        normalize_pattern_projection_eps: float = 1e-5,
        normalize_hopfield_space: bool = False,
        normalize_hopfield_space_affine: bool = False,
        normalize_hopfield_space_eps: float = 1e-5,
        stored_pattern_as_static: bool = False,
        state_pattern_as_static: bool = False,
        pattern_projection_as_static: bool = False,
        pattern_projection_as_connected: bool = False,
        stored_pattern_size: Optional[int] = None,
        pattern_projection_size: Optional[int] = None,

        batch_first: bool = True,
        association_activation: Optional[str] = None,
        dropout: float = 0.0,
        input_bias: bool = True,
        concat_bias_pattern: bool = False,
        add_zero_association: bool = False,
        disable_out_projection: bool = False,

        # --- Diffusion-specific parameters ---
        eta: float = 0.1,
        k_neighbors: int = 5,
        use_normalized_laplacian: bool = True,
        diffuse_query: bool = False,
        diffuse_key: bool = True,
        diffusion_mode: str = "factored",
        diffusion_steps: int = 3,
        attention_mode: str = "dense",
        use_sparse: bool = False,
        use_logit_diffusion: bool = False,
        logit_eta: Optional[float] = None,
        adaptive_eta: bool = False,
        adaptive_temperature: float = 5.0,
        adaptive_threshold: float = 1.0,
        cache_graph: bool = True,
        energy_stop_tol: float = 0.0,
    ):
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            pattern_size=pattern_size,
            num_heads=num_heads,
            scaling=scaling,
            update_steps_max=update_steps_max,
            update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_stored_pattern_eps=normalize_stored_pattern_eps,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            normalize_state_pattern_eps=normalize_state_pattern_eps,
            normalize_pattern_projection=normalize_pattern_projection,
            normalize_pattern_projection_affine=normalize_pattern_projection_affine,
            normalize_pattern_projection_eps=normalize_pattern_projection_eps,
            normalize_hopfield_space=normalize_hopfield_space,
            normalize_hopfield_space_affine=normalize_hopfield_space_affine,
            normalize_hopfield_space_eps=normalize_hopfield_space_eps,
            stored_pattern_as_static=stored_pattern_as_static,
            state_pattern_as_static=state_pattern_as_static,
            pattern_projection_as_static=pattern_projection_as_static,
            pattern_projection_as_connected=pattern_projection_as_connected,
            stored_pattern_size=stored_pattern_size,
            pattern_projection_size=pattern_projection_size,
            batch_first=batch_first,
            association_activation=association_activation,
            dropout=dropout,
            input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern,
            add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection,
        )

        # Build unified config.
        # NOTE(review): when ``scaling`` is None or a Tensor, beta falls back to 1.0 for the
        # diffusion/attention operators, while the association core is called with
        # ``self.scaling`` — confirm this mismatch is intended.
        _beta = float(scaling) if isinstance(scaling, (int, float)) else 1.0
        resolved_logit_eta = logit_eta if logit_eta is not None else eta
        self._diff_cfg = DiffusionConfig(
            eta=eta,
            beta=_beta,
            steps=diffusion_steps,
            diffusion_mode=diffusion_mode,
            attention_mode=attention_mode,
            k_neighbors=k_neighbors,
            use_normalized_laplacian=use_normalized_laplacian,
            use_sparse=use_sparse,
            diffuse_key=diffuse_key,
            diffuse_query=diffuse_query,
            use_logit_diffusion=use_logit_diffusion,
            logit_eta=resolved_logit_eta,
            adaptive_eta=adaptive_eta,
            adaptive_temperature=adaptive_temperature,
            adaptive_threshold=adaptive_threshold,
            cache_graph=cache_graph,
            energy_stop_tol=energy_stop_tol,
        )

        # Separate cache per role (key vs query patterns may differ in shape).
        self._key_cache = GraphCache(self._diff_cfg)
        self._query_cache = GraphCache(self._diff_cfg)

        # Attention operator — dense O(N²) or graph-constrained O(kN).
        self._attn_op = AttentionOperator(beta=_beta, mode=attention_mode)

        # Energy tracker (only created when early stopping is enabled).
        self._energy_tracker: Optional[EnergyTracker] = (
            EnergyTracker(
                beta=_beta,
                eta=eta,
                tol=energy_stop_tol,
            )
            if energy_stop_tol > 0.0 else None
        )

        # Expose scalar hypers for backward-compatible attribute access.
        self.eta = eta
        self.k_neighbors = k_neighbors
        self.use_normalized_laplacian = use_normalized_laplacian
        self.diffuse_query = diffuse_query
        self.diffuse_key = diffuse_key
        self.diffusion_mode = diffusion_mode
        self.diffusion_steps = diffusion_steps
        self.attention_mode = attention_mode
        self.use_sparse = use_sparse
        self.use_logit_diffusion = use_logit_diffusion
        self.logit_eta = resolved_logit_eta
        self.adaptive_eta = adaptive_eta
        self.adaptive_temperature = adaptive_temperature
        self.adaptive_threshold = adaptive_threshold
        self.cache_graph = cache_graph
        self.energy_stop_tol = energy_stop_tol

        # Keep the update-step settings for the association_core call in _associate.
        self._d_update_steps_max = update_steps_max
        self._d_update_steps_eps = update_steps_eps

    # ------------------------------------------------------------------
    # Cache management
    # ------------------------------------------------------------------

    def invalidate_cache(self) -> None:
        """Clear cached Laplacians and operators. Call when patterns change."""
        self._key_cache.invalidate()
        self._query_cache.invalidate()

    # ------------------------------------------------------------------
    # Operator construction
    # ------------------------------------------------------------------

    def _make_diffusion_op(self, eta, W, deg, L):
        """
        Build a fresh diffusion operator for the configured mode at strength ``eta``.

        Factored mode is built from the adjacency ``W`` and degree ``deg``; every other
        mode is built from the Laplacian ``L``. The step count is ``cfg.steps``.
        """
        from .diffusion import FactoredDiffusion  # local import, as in the original code path

        cfg = self._diff_cfg
        if cfg.diffusion_mode == "factored":
            return FactoredDiffusion(eta=eta, steps=cfg.steps).precompute_from_graph(W, deg)
        return DiffusionOperator.create(cfg.diffusion_mode, eta, cfg.steps).precompute(L)

    # ------------------------------------------------------------------
    # Override _associate to inject diffusion
    # ------------------------------------------------------------------

    def _associate(
        self,
        data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
        return_raw_associations: bool = False,
        return_projected_patterns: bool = False,
        stored_pattern_padding_mask: Optional[Tensor] = None,
        association_mask: Optional[Tensor] = None,
    ) -> Tuple[Optional[Tensor], ...]:
        """
        Hopfield association with graph-diffusion pre-processing.

        Follows the structure of ``Hopfield._associate``. It inserts a DynamicsEngine
        diffusion pass on stored / state patterns after the optional LayerNorm and
        before the association core. When logit diffusion is enabled, the returned
        attention weights are smoothed over the key graph with strength ``logit_eta``,
        and the first output is recomputed from them.

        No graph is rebuilt if the same patterns are passed again
        (GraphCache returns the cached DiffusionOperator).
        """
        assert (type(data) == Tensor) or (
            (type(data) == tuple) and (len(data) == 3)
        ), (
            "either one tensor or a 3-tuple "
            "(stored_pattern, state_pattern, pattern_projection) must be provided."
        )

        if type(data) == Tensor:
            stored_pattern = state_pattern = pattern_projection = data
        else:
            stored_pattern, state_pattern, pattern_projection = data

        # --- batch_first transpose ---
        stored_pattern, state_pattern, pattern_projection = self._maybe_transpose(
            stored_pattern, state_pattern, pattern_projection
        )

        # --- Optional LayerNorm (mirroring Hopfield._associate) ---
        if self.norm_stored_pattern is not None:
            stored_pattern = self.norm_stored_pattern(
                input=stored_pattern.reshape(-1, stored_pattern.shape[2])
            ).reshape(stored_pattern.shape)

        if self.norm_state_pattern is not None:
            state_pattern = self.norm_state_pattern(
                input=state_pattern.reshape(-1, state_pattern.shape[2])
            ).reshape(state_pattern.shape)

        if self.norm_pattern_projection is not None:
            pattern_projection = self.norm_pattern_projection(
                input=pattern_projection.reshape(-1, pattern_projection.shape[2])
            ).reshape(pattern_projection.shape)

        # --- Full dynamics loop: interleaved diffusion + attention ---
        cfg = self._diff_cfg
        _W_k = _deg_k = _adj_k = L_k = op_k = None
        _W_q = _deg_q = _adj_q = L_q = op_q = None
        # η that op_k was built with; operators from GraphCache are presumed to use cfg.eta.
        op_k_eta = cfg.eta
        if cfg.eta > 0.0 and (cfg.diffuse_key or cfg.diffuse_query):
            key_repr = stored_pattern.detach().mean(dim=1).float()
            _W_k, _deg_k, _adj_k, L_k, op_k = self._key_cache.get(key_repr)

            # Build a query-specific graph so Q is diffused over its own topology.
            if cfg.diffuse_query:
                q_repr = state_pattern.detach().mean(dim=1).float()
                _W_q, _deg_q, _adj_q, L_q, op_q = self._query_cache.get(q_repr)

            # Adaptive η: scale by attention entropy before the dynamics loop.
            if cfg.adaptive_eta:
                raw_logits = torch.bmm(
                    state_pattern.permute(1, 0, 2),
                    stored_pattern.permute(1, 2, 0),
                )
                engine_tmp = DynamicsEngine(op_k)
                eta_eff = engine_tmp.compute_adaptive_eta(
                    raw_logits, cfg.eta,
                    cfg.adaptive_temperature, cfg.adaptive_threshold,
                )
                # Only rebuild operators when η moved by more than 5 % (relative).
                if abs(eta_eff - cfg.eta) / (cfg.eta + 1e-9) > 0.05:
                    op_k = self._make_diffusion_op(eta_eff, _W_k, _deg_k, L_k)
                    op_k_eta = eta_eff
                    if op_q is not None:
                        op_q = self._make_diffusion_op(eta_eff, _W_q, _deg_q, L_q)

            # Single-step copies for dynamics (outer loop controls iteration count).
            op_dyn = copy.copy(op_k)
            op_dyn.steps = 1

            op_q_dyn = None
            if op_q is not None:
                op_q_dyn = copy.copy(op_q)
                op_q_dyn.steps = 1

            engine = DynamicsEngine(
                diffusion_op=op_dyn,
                attention_op=self._attn_op,
                steps=self.diffusion_steps,
                energy_tracker=self._energy_tracker,
                query_diffusion_op=op_q_dyn,
            )
            state_pattern, stored_pattern = engine.run_dynamics(
                Q=state_pattern, K=stored_pattern, V=pattern_projection,
                adj_indices=_adj_k, L=L_k, W=_W_k, deg=_deg_k,
                diffuse_query=cfg.diffuse_query, diffuse_key=cfg.diffuse_key,
            )

        result = self.association_core(
            query=state_pattern,
            key=stored_pattern,
            value=pattern_projection,
            key_padding_mask=stored_pattern_padding_mask,
            need_weights=cfg.use_logit_diffusion,  # weights are needed for smoothing
            attn_mask=association_mask,
            scaling=self.scaling,
            update_steps_max=self._d_update_steps_max,
            update_steps_eps=self._d_update_steps_eps,
            return_raw_associations=return_raw_associations,
            return_pattern_projections=return_projected_patterns,
        )

        # --- Logit-level diffusion (post-softmax weight smoothing) ---
        if cfg.use_logit_diffusion and cfg.logit_eta > 0.0:
            # result[1] is expected to be the attention weights (B, H, L, S), or None.
            attn_weights = result[1]
            if attn_weights is not None:
                if op_k is None:
                    key_repr = stored_pattern.detach().mean(dim=1).float()
                    _W_k, _deg_k, _adj_k, L_k, op_k = self._key_cache.get(key_repr)
                    op_k_eta = cfg.eta
                # Smooth with logit_eta. Reuse op_k only when it was built with that same η
                # (the default case logit_eta == eta); otherwise build a dedicated operator.
                if op_k_eta == cfg.logit_eta:
                    op_logit = op_k
                else:
                    op_logit = self._make_diffusion_op(cfg.logit_eta, _W_k, _deg_k, L_k)

                # Treat each query's (S,) distribution as the diffusion signal.
                B, H, L_q_len, S = attn_weights.shape
                w_flat = attn_weights.permute(3, 0, 1, 2).reshape(S, B * H * L_q_len)
                w_diff = op_logit(w_flat)  # diffuse along S
                w_diff = w_diff.reshape(S, B, H, L_q_len).permute(1, 2, 3, 0)
                w_diff = w_diff.clamp(min=0.0)
                w_diff = w_diff / (w_diff.sum(dim=-1, keepdim=True) + 1e-9)

                # Each head's row is a distribution over S. Average the heads so the output
                # stays a convex combination of V rows; summing would scale it by H.
                # NOTE(review): V is the (normalised) pattern_projection input, not the
                # value-projected tensor used inside association_core — confirm intended.
                V = pattern_projection.permute(1, 0, 2)  # (B, S, d)
                out_smooth = torch.einsum("bls,bsd->bld", w_diff.mean(dim=1), V)
                out_smooth = out_smooth.permute(1, 0, 2)  # (L, B, d)
                result = (out_smooth,) + result[1:]

        return result

    # ------------------------------------------------------------------
    # Config dict API
    # ------------------------------------------------------------------

    def get_config(self) -> Dict[str, object]:
        """Return a JSON-serialisable dict of all diffusion hyperparameters."""
        return self._diff_cfg.to_dict()
+