cida-plugin 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cida_plugin/core.py ADDED
@@ -0,0 +1,451 @@
1
+ """
2
+ CIDAPlugin — Universal Multi-Agent Deliberation Layer (v3 Simplified).
3
+
4
+ Architecture:
5
+ 1. Input Projection: d_input → d_hidden
6
+ 2. Agent Initialization with Role Priors (structural disagreement)
7
+ 3. Deliberation Loop (evidence → messages → communication → update)
8
+ 4. Weighted Mean Consensus (suited to correlated agents)
9
+ 5. Uncertainty = f(disagreement) — observed, not learned
10
+ """
11
+ import os
12
+ import warnings
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from .config import CIDAPluginConfig
18
+ from .agent import AgentState, apply_role_prior, apply_role_priors_batched, compute_role_orthogonality_loss
19
+ from .deliberation import (
20
+ AgentEvidenceExtractor,
21
+ MessageFormulator,
22
+ CounterargumentCommunication,
23
+ AgentUpdater,
24
+ PerspectiveProjector,
25
+ )
26
+ from .consensus import ConsensusAggregator, HaltingPredictor
27
+ from .ttt import TestTimeTrainer
28
+ from .liquid_dynamics import LiquidDeliberationSolver
29
+
30
+
31
class CIDAPlugin(nn.Module):
    """
    Universal CIDA-Plugin.

    Accepts the pooled output of ANY encoder (projected by ``input_proj``),
    runs multi-agent deliberation with structural role priors, and returns
    the agents' consensus class distribution together with
    disagreement-derived uncertainty.

    Example:
        cfg = CIDAPluginConfig(d_input=768, num_classes=2)
        plugin = CIDAPlugin(cfg)
        out = plugin(pooled)          # (B, 768) → dict
        probs = out["p_final"]        # (B, 2) consensus distribution
    """

    def __init__(self, config: CIDAPluginConfig):
        super().__init__()
        self.config = config
        cfg = config

        # ── Input projection: d_input → d_hidden ──────────────────────────────
        self.input_proj = nn.Sequential(
            nn.Linear(cfg.d_input, cfg.d_hidden),
            nn.LayerNorm(cfg.d_hidden),
            nn.SiLU(),
        )
        if cfg.freeze_input_proj:
            for p in self.input_proj.parameters():
                p.requires_grad = False

        # Projection for sequence output (evidence pointers); identity when the
        # encoder width already equals d_hidden.
        self.seq_proj = (
            nn.Linear(cfg.d_input, cfg.d_hidden)
            if cfg.d_input != cfg.d_hidden
            else nn.Identity()
        )

        # ── Perspective Projector (optional per-agent projections) ─────────────
        if cfg.use_perspective_projector:
            self.perspective_projector = PerspectiveProjector(
                cfg.num_agents, cfg.d_input, cfg.d_hidden
            )
        else:
            self.perspective_projector = None

        # ── Deliberation components ───────────────────────────────────────────
        self.evidence_extractor = AgentEvidenceExtractor(cfg.d_hidden)
        self.message_formulator = MessageFormulator(
            cfg.d_hidden, cfg.num_classes, cfg.d_message,
            anonymous=cfg.anonymous_messages,
        )
        self.communication = CounterargumentCommunication(
            cfg.d_hidden, cfg.d_message
        )
        self.updater = AgentUpdater(
            cfg.d_hidden,
            cfg.num_classes,
            num_heads=cfg.num_attn_heads,
            multi_label=cfg.multi_label,
        )

        # ── Consensus (Weighted Mean for correlated agents) ───────────────────
        self.consensus = ConsensusAggregator(
            num_agents=cfg.num_agents,
            multi_label=cfg.multi_label,
            use_dynamic_trust=cfg.use_dynamic_trust,
            trust_ema_gamma=cfg.trust_ema_gamma,
        )
        self.halter = HaltingPredictor()

        # ── Gated communication warm-up ───────────────────────────────────────
        # The gate scales the communication signal in each discrete round.
        # The negative bias shifts its initial value below 0.5, so messages
        # start damped.
        self.comm_gate = nn.Sequential(
            nn.Linear(cfg.d_hidden, 1),
            nn.Sigmoid(),
        )
        nn.init.constant_(self.comm_gate[0].bias, -0.5)

        # ── Test-Time Training (optional) ─────────────────────────────────────
        if cfg.use_ttt:
            self.ttt_trainer = TestTimeTrainer(
                d_hidden=cfg.d_hidden,
                ttt_steps=cfg.ttt_steps,
                ttt_lr=cfg.ttt_lr,
                mask_ratio=cfg.ttt_mask_ratio,
            )

        # ── Liquid Neural ODE Dynamics (optional) ─────────────────────────────
        # The solver shares the deliberation submodules above (same parameters).
        if cfg.use_liquid_dynamics:
            self.liquid_solver = LiquidDeliberationSolver(
                extractor=self.evidence_extractor,
                message_formulator=self.message_formulator,
                communication=self.communication,
                updater=self.updater,
                d_hidden=cfg.d_hidden,
                solver=cfg.liquid_solver,
                atol=cfg.liquid_atol,
                rtol=cfg.liquid_rtol,
                T=float(cfg.max_rounds),
                trajectory_save_every=cfg.trajectory_save_every,
            )

    # ── Hugging Face Serialization ────────────────────────────────────────────

    def save_pretrained(self, save_directory: str):
        """Save config and weights (``pytorch_model.bin``) to ``save_directory``.

        The directory is created if needed; the output can be reloaded with
        :meth:`from_pretrained` or published to the HF Hub.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        weights_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), weights_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        """Load a pretrained plugin (config + weights).

        Parameters
        ----------
        pretrained_model_name_or_path : str
            A local directory containing ``pytorch_model.bin`` (as written by
            :meth:`save_pretrained`), or otherwise a Hugging Face Hub repo id
            from which ``pytorch_model.bin`` is downloaded.
        """
        import inspect

        config = CIDAPluginConfig.from_pretrained(
            pretrained_model_name_or_path
        )
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            weights_path = os.path.join(
                pretrained_model_name_or_path, "pytorch_model.bin"
            )
        else:
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="pytorch_model.bin",
            )

        # The checkpoint may come from the Hub, i.e. untrusted input, and a
        # plain torch.load unpickles arbitrary objects (code execution risk).
        # save_pretrained only writes a tensor state_dict, so restrict loading
        # to tensors whenever the installed torch supports ``weights_only``.
        load_kwargs = {"map_location": "cpu"}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = True
        state_dict = torch.load(weights_path, **load_kwargs)
        model.load_state_dict(state_dict)
        return model

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _apply_role_priors(self, b: torch.Tensor) -> torch.Tensor:
        """Apply the configured batched structural role priors to beliefs ``b``."""
        cfg = self.config
        return apply_role_priors_batched(
            b,
            cfg.num_agents,
            multi_label=cfg.multi_label,
            positive_class_idx=cfg.positive_class_idx,
        )

    def _init_agent_state(self, pooled, h_cls, H):
        """Build the initial AgentState and apply the first belief update.

        Unless a perspective projector is configured, all agents start from
        the same projected representation ``h_cls``. Diversity then comes from
        the role priors applied to the initial beliefs.
        """
        cfg = self.config
        B = pooled.size(0)
        M = cfg.num_agents
        device = pooled.device

        if self.perspective_projector is not None:
            # Per-agent projections for diverse initial views.
            s_t = self.perspective_projector(pooled)  # (B, M, d_hidden)
        else:
            # clone() so the agents do not share storage through expand().
            s_t = h_cls.unsqueeze(1).expand(B, M, cfg.d_hidden).clone()

        seq_len = H.size(1) if H is not None else 1
        e_init = torch.zeros(B, M, cfg.d_hidden, device=device)
        r_init = torch.zeros_like(e_init)

        if cfg.multi_label:
            b_init = torch.full((B, M, cfg.num_classes), 0.5, device=device)
        else:
            b_init = torch.zeros(B, M, cfg.num_classes, device=device)

        state_t = AgentState(
            s=s_t,
            b=b_init,
            u=torch.ones(B, M, 1, device=device),
            p=torch.zeros(B, M, seq_len, device=device),
            e=e_init,
            alpha=torch.ones(B, M, cfg.num_classes, device=device),
        )

        # Initial update to get first beliefs from the data, then role priors.
        state_t = self.updater(state_t, r_init, e_init)
        state_t.b = self._apply_role_priors(state_t.b)
        return state_t

    def _run_liquid_deliberation(self, state_t, H, mask):
        """Continuous (ODE) deliberation; requires ``H`` to be non-None.

        Returns ``(state, (b_all, u_all, m_all, p_ptr_all), nfe, ponder_cost)``,
        where ``nfe`` is the solver's ``traj["nfe"]``.
        """
        cfg = self.config
        B = state_t.s.size(0)
        M = cfg.num_agents
        device = state_t.s.device

        s_final, traj = self.liquid_solver(state_t.s, H, mask)
        b_all = traj["b_all"]
        u_all = traj["u_all"]
        p_ptr_all = traj["p_ptr_all"]
        m_all = traj["m_all"]

        if len(b_all) > 0:
            # Final update to get beliefs from the final continuous state.
            e_final = s_final
            if not cfg.abl_no_pointers:
                _, e_final = self.evidence_extractor(s_final, H, mask)

            r_final = torch.zeros_like(s_final)
            if not cfg.abl_no_messages and not cfg.abl_no_communication:
                # NOTE(review): unlike the discrete loop, this pass uses
                # placeholder beliefs/uncertainty, no comm_gate and no
                # sparse_k. Confirm that this asymmetry is intended.
                num_classes = self.updater.g[-1].out_features
                dummy_b = torch.zeros(B, M, num_classes, device=device)
                dummy_u = torch.ones(B, M, 1, device=device)
                m_final = self.message_formulator(
                    s_final, e_final, dummy_b, dummy_u
                )
                r_final = self.communication(
                    s_final, m_final, dummy_b, e_final
                )

            final_state = AgentState(
                s=s_final,
                b=b_all[-1],
                u=u_all[-1] if len(u_all) > 0 else state_t.u,
                p=torch.zeros(B, M, 1, device=device),
                e=e_final,
                alpha=torch.ones(B, M, cfg.num_classes, device=device),
            )
            state_t = self.updater(final_state, r_final, e_final)
            state_t.b = self._apply_role_priors(state_t.b)
            # Report the prior-adjusted final beliefs as the last history entry.
            b_all[-1] = state_t.b

        nfe = traj["nfe"]
        ponder_cost = torch.full((B,), float(nfe), device=device)
        return state_t, (b_all, u_all, m_all, p_ptr_all), nfe, ponder_cost

    def _run_discrete_deliberation(self, state_t, H, mask):
        """Discrete round-based deliberation (``H`` may be None).

        Returns ``(state, (b_all, u_all, m_all, p_ptr_all), rounds_used)``.
        """
        from torch.utils.checkpoint import checkpoint

        cfg = self.config
        B = state_t.s.size(0)
        M = cfg.num_agents
        device = state_t.s.device
        seq_len = H.size(1) if H is not None else 1

        # Defined once: it closes only over loop-invariant values.
        def deliberation_round(s_in, b_in, u_in, p_in, e_in):
            state = AgentState(
                s=s_in, b=b_in, u=u_in, p=p_in, e=e_in,
                alpha=torch.ones_like(b_in),
            )

            # Evidence pointers (uniform over positions when disabled/unavailable).
            if cfg.abl_no_pointers or H is None:
                p_new = torch.ones(B, M, seq_len, device=device) / seq_len
                e_new = state.s
            else:
                p_new, e_new = self.evidence_extractor(state.s, H, mask)
            state.p = p_new
            state.e = e_new

            # Message formulation
            if cfg.abl_no_messages:
                m_new = torch.zeros(B, M, cfg.d_message, device=device)
            else:
                m_new = self.message_formulator(
                    state.s, state.e, state.b, state.u
                )

            # Communication, scaled by the learned gate
            if cfg.abl_no_communication:
                r_new = torch.zeros_like(state.s)
            else:
                r_new = self.communication(
                    state.s, m_new, state.b, state.e,
                    sparse_k=cfg.sparse_comm_k,
                )
                gate = self.comm_gate(state.s)
                r_new = gate * r_new
                if self.training and cfg.comm_dropout > 0:
                    r_new = F.dropout(r_new, p=cfg.comm_dropout, training=True)

            # Agent update, then role priors on the updated beliefs
            state = self.updater(state, r_new, e_new)
            state.b = self._apply_role_priors(state.b)

            return state.s, state.b, state.u, state.p, state.e, m_new

        b_all, u_all, m_all, p_ptr_all = [], [], [], []
        rounds_used = 0
        prev_disagreement = float('inf')

        for rounds_used in range(1, cfg.max_rounds + 1):
            round_inputs = (state_t.s, state_t.b, state_t.u, state_t.p, state_t.e)
            # Activation checkpointing only pays off for CUDA training.
            if self.training and state_t.s.requires_grad and state_t.s.is_cuda:
                s_t, b_t, u_t, p_t, e_t, m_t = checkpoint(
                    deliberation_round, *round_inputs, use_reentrant=False
                )
            else:
                s_t, b_t, u_t, p_t, e_t, m_t = deliberation_round(*round_inputs)

            # ── Kairos Dynamic Deliberation: Time-Travel (Rollback) ────────
            # Inference only: if mean disagreement grew, discard this round,
            # perturb the representation and retry from a fresh baseline.
            if getattr(cfg, 'enable_rollback', False) and not self.training:
                with torch.no_grad():
                    _, _, current_d = self.consensus(b_t)
                    current_d_mean = current_d.mean().item()

                if current_d_mean > prev_disagreement:
                    noise_std = getattr(cfg, 'rollback_noise_std', 0.05)
                    noise = torch.randn_like(state_t.s) * noise_std
                    state_t.s = state_t.s + noise
                    prev_disagreement = float('inf')  # Reset to allow new path
                    continue  # Skip update and try next round

                prev_disagreement = current_d_mean

            # Commit the round.
            state_t.s, state_t.b, state_t.u, state_t.p, state_t.e = (
                s_t, b_t, u_t, p_t, e_t
            )
            p_ptr_all.append(state_t.p)
            m_all.append(m_t)
            b_all.append(state_t.b)
            u_all.append(state_t.u)

            # Confidence-based early stopping (inference only)
            if not self.training and cfg.early_stop_threshold is not None:
                with torch.no_grad():
                    tmp_p, _, tmp_d = self.consensus(state_t.b)
                    if self.halter.should_halt(tmp_d).all():
                        break
                    if (
                        tmp_p.max(dim=-1).values.mean().item()
                        >= cfg.early_stop_threshold
                    ):
                        break

        return state_t, (b_all, u_all, m_all, p_ptr_all), rounds_used

    # ── Forward ───────────────────────────────────────────────────────────────

    def forward(
        self,
        pooled: torch.Tensor,
        seq_output: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> dict:
        """
        Run input projection, agent initialization, optional TTT, deliberation
        and consensus.

        Parameters
        ----------
        pooled : (B, d_input) pooled representation from encoder
        seq_output : (B, seq_len, d_input) optional — enables evidence
            pointers and, if configured, continuous (liquid) dynamics
        mask : (B, seq_len) optional padding mask, forwarded to the evidence
            extractor / liquid solver

        Returns dict:
            p_final      : (B, K) — final consensus distribution
            b_all        : list — belief history per round
            u_all        : list — uncertainty history
            m_all        : list — message history
            p_ptr_all    : list — evidence pointer history
            uncertainty  : (B,) — final uncertainty
            disagreement : (B,) — final agent disagreement
            u_epi        : ``u_epi`` of the final AgentState (presumably
                           epistemic uncertainty — confirm in agent.py)
            u_alea       : ``u_alea`` of the final AgentState (presumably
                           aleatoric uncertainty — confirm in agent.py)
            L_orth       : scalar — orthogonality loss (representations)
            state_final  : AgentState — final agent state
            rounds_used  : rounds executed (discrete mode) or solver NFE
                           (liquid mode)
            ponder_cost  : (B,) or None — NFE for liquid dynamics
            ttt_info     : dict or None — TTT diagnostics
        """
        cfg = self.config

        # ── Project input ─────────────────────────────────────────────────────
        h_cls = self.input_proj(pooled)  # (B, d_hidden)
        H = self.seq_proj(seq_output) if seq_output is not None else None

        # ── Initialize agents ─────────────────────────────────────────────────
        state_t = self._init_agent_state(pooled, h_cls, H)

        # ── TTT: adapt before deliberation ────────────────────────────────────
        ttt_info = None
        saved_updater = None
        saved_head = None
        if cfg.use_ttt:
            ttt_info = self.ttt_trainer.adapt_and_get_info(
                self.updater, state_t.s
            )
            saved_updater = ttt_info.pop("saved_updater_state")
            saved_head = ttt_info.pop("saved_head_state")

        # The TTT-adapted updater weights must be restored even if anything
        # below raises; otherwise the module stays mutated for later calls.
        try:
            ponder_cost = None
            if cfg.use_liquid_dynamics and H is not None:
                # ── Continuous dynamics via ODE ────────────────────────────
                state_t, history, rounds_used, ponder_cost = (
                    self._run_liquid_deliberation(state_t, H, mask)
                )
            else:
                if cfg.use_liquid_dynamics:
                    warnings.warn(
                        "use_liquid_dynamics=True but seq_output is None. "
                        "Falling back to discrete deliberation mode. "
                        "Provide seq_output for continuous ODE dynamics.",
                        UserWarning,
                    )
                state_t, history, rounds_used = self._run_discrete_deliberation(
                    state_t, H, mask
                )
            b_all, u_all, m_all, p_ptr_all = history

            # ── Final Consensus ───────────────────────────────────────────────
            p_final, uncertainty, disagreement = self.consensus(
                state_t.b, state_t.u
            )

            # ── Orthogonality loss (on representations, not weights) ──────────
            L_orth = compute_role_orthogonality_loss(state_t.s)
        finally:
            # ── TTT: restore weights ──────────────────────────────────────────
            if cfg.use_ttt and saved_updater is not None:
                self.ttt_trainer.restore_weights(
                    self.updater, saved_updater, saved_head
                )

        return {
            "p_final": p_final,
            "b_all": b_all,
            "u_all": u_all,
            "m_all": m_all,
            "p_ptr_all": p_ptr_all,
            "uncertainty": uncertainty,
            "disagreement": disagreement,
            "u_epi": state_t.u_epi,
            "u_alea": state_t.u_alea,
            "L_orth": L_orth,
            "state_final": state_t,
            "rounds_used": rounds_used,
            "ponder_cost": ponder_cost,
            "ttt_info": ttt_info,
        }