cida-plugin 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class LatentSemanticTranslator(nn.Module):
    """
    Module 3 of Kairos OS: Latent-to-Semantic Translator.

    Translates CIDA's latent superposition (Prosecutor vs Defender beliefs)
    into structured semantic concept embeddings that an LLM can understand.

    Args:
        d_hidden: Width of the latent vectors fed into the projection MLP.
        vocab_size: Accepted for interface compatibility; not used by this module.
        d_model_llm: Width of the produced embeddings. The attention layer is
            built with 8 heads, so this must be divisible by 8.
    """

    def __init__(self, d_hidden: int = 128, vocab_size: int = 50000, d_model_llm: int = 4096):
        super().__init__()
        self.d_hidden = d_hidden
        self.d_model_llm = d_model_llm

        # Non-linear projection from CIDA latent space to LLM semantic space
        self.latent_to_semantic = nn.Sequential(
            nn.Linear(d_hidden, d_hidden * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden * 4, d_model_llm),
        )

        # Attention over the two hypotheses to summarize evidence
        self.evidence_attention = nn.MultiheadAttention(
            embed_dim=d_model_llm, num_heads=8, batch_first=True
        )

    def _fit_to_hidden(self, belief: torch.Tensor, name: str) -> torch.Tensor:
        """
        Return ``belief`` as a (B, d_hidden) tensor, zero-padding narrower inputs.

        Inputs that already have width ``d_hidden`` are returned untouched.
        Narrower inputs (e.g. class probabilities) are copied into the leading
        columns of a zero tensor. Wider inputs cannot be represented and raise.

        Raises:
            ValueError: if the last dimension of ``belief`` exceeds ``d_hidden``.
        """
        width = belief.size(-1)
        if width == self.d_hidden:
            return belief
        if width > self.d_hidden:
            raise ValueError(
                f"'{name}' has width {width}, which exceeds d_hidden={self.d_hidden}"
            )

        # Allocate the buffer in the projection's dtype so the padded tensor is
        # always accepted by the first Linear layer, on the belief's own device.
        proj_dtype = self.latent_to_semantic[0].weight.dtype
        padded = belief.new_zeros((belief.size(0), self.d_hidden), dtype=proj_dtype)
        padded[:, :width] = belief
        return padded

    def forward(self, superposition_state: dict, class_names: list = None):
        """
        Args:
            superposition_state: Dict containing 'prosecutor_belief' and 'defender_belief'.
                Each tensor is of shape (B, num_classes) or (B, d_hidden). The two
                entries are normalised to width d_hidden independently of each other.
            class_names: List of class names for semantic grounding
                (currently unused by this skeleton).

        Returns:
            semantic_embeddings: (B, 2, d_model_llm) ready to be prepended to LLM prompt,
            or None when ``superposition_state`` is None.

        Raises:
            KeyError: if either belief entry is missing (or None).
            ValueError: if a belief is wider than d_hidden.
        """
        if superposition_state is None:
            return None

        # Validate and normalise each hypothesis on its own: previously only the
        # prosecutor's width was inspected, so a defender belief of a different
        # width reached the projection unpadded and crashed there.
        fitted = []
        for key in ("prosecutor_belief", "defender_belief"):
            belief = superposition_state.get(key)
            if belief is None:
                raise KeyError(f"superposition_state is missing required entry '{key}'")
            fitted.append(self._fit_to_hidden(belief, key))
        prosecutor_b, defender_b = fitted

        # Map to LLM dimension
        sem_prosecutor = self.latent_to_semantic(prosecutor_b).unsqueeze(1)  # (B, 1, d_llm)
        sem_defender = self.latent_to_semantic(defender_b).unsqueeze(1)  # (B, 1, d_llm)

        # Combine hypotheses into a sequence
        semantic_sequence = torch.cat([sem_prosecutor, sem_defender], dim=1)  # (B, 2, d_llm)

        # Self-attention between the two hypotheses, added back as a residual
        attn_out, _ = self.evidence_attention(
            semantic_sequence, semantic_sequence, semantic_sequence
        )
        return semantic_sequence + attn_out
cida_plugin/ttt.py ADDED
@@ -0,0 +1,348 @@
1
+ """
2
+ Test-Time Training (TTT) для CIDA-Plugin.
3
+
4
+ Концепция (Apple 2024, MIT 2024):
5
+ Обычно: обучение = оффлайн, инференс = заморожен.
6
+ TTT: агент делает K шагов градиентного спуска на каждом входе,
7
+ адаптируясь к конкретному примеру ПЕРЕД ответом.
8
+
9
+ θ*_i = θ_i − α · ∇L_self(θ_i, x) ← K шагов Adam
10
+ b_i = Agent_i(x; θ*_i) ← belief после адаптации
11
+ θ_i ← θ_i ← восстановление весов
12
+
13
+ Почему это важно для CIDA:
14
+ - Первая архитектура где агенты РЕАЛЬНО «думают» о каждом входе
15
+ - Добавляем вычисление (K шагов), не параметры
16
+ - Маленькая модель адаптируется как большая
17
+
18
+ Self-supervised loss:
19
+ Masked reconstruction — маскируем часть компонент скрытого состояния
20
+ агента и заставляем его восстанавливать замаскированное.
21
+ Это заставляет агента «понять» структуру конкретного входа.
22
+ """
23
+
24
+ import io
25
+ from contextlib import contextmanager
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+
31
+ from .agent import AgentState
32
+
33
+
34
class MaskedReconstructionHead(nn.Module):
    """
    Auxiliary head used by the self-supervised (masked reconstruction) loss.

    It receives an agent hidden state whose components were partly masked and
    predicts the original values of the masked positions.

    Architecture: a 2-layer MLP with a bottleneck
    (d_hidden -> d_hidden // 2 -> d_hidden) and a SiLU in between. Per the
    original design note, the bottleneck is there to discourage trivial copying
    so that the model has to rely on the structure of the unmasked positions.
    """

    def __init__(self, d_hidden: int):
        super().__init__()
        bottleneck = d_hidden // 2
        # Layers are created in encoder -> decoder order and stored at
        # Sequential indices 0 and 2, which fixes the state_dict key names.
        squeeze = nn.Linear(d_hidden, bottleneck)
        expand = nn.Linear(bottleneck, d_hidden)
        self.net = nn.Sequential(squeeze, nn.SiLU(), expand)

    def forward(self, masked_states: torch.Tensor) -> torch.Tensor:
        """
        masked_states: (B, M, d_hidden) - states with masked components.
        Returns: (B, M, d_hidden) - prediction of the original values.
        """
        reconstruction = self.net(masked_states)
        return reconstruction
60
+
61
+
62
class TestTimeTrainer(nn.Module):
    """
    Test-Time Training (TTT) for CIDA agents.

    On every adaptation:
        1. A mask is drawn (``mask_ratio`` of the hidden components are zeroed).
        2. The masked states are passed through the AgentUpdater.
        3. A reconstruction loss is computed on the masked positions.
        4. K Adam steps are taken on updater + reconstruction head.
        5. After deliberation the original weights are restored
           (automatically in ``adapt_context``; manually via
           ``restore_weights`` after ``adapt_and_get_info``).

    Parameters
    ----------
    d_hidden : int
        Width of the agents' hidden states.
    ttt_steps : int
        Number of inner gradient steps (K). Default 3. More steps adapt
        better but cost more compute.
    ttt_lr : float
        Learning rate of the inner Adam optimizer. Default 1e-3.
    mask_ratio : float
        Fraction of hidden components that are masked. Default 0.15.
    """

    # The class name starts with "Test"; tell pytest not to try collecting it.
    __test__ = False

    def __init__(
        self,
        d_hidden: int,
        ttt_steps: int = 3,
        ttt_lr: float = 1e-3,
        mask_ratio: float = 0.15,
    ):
        super().__init__()
        self.ttt_steps = ttt_steps
        self.ttt_lr = ttt_lr
        self.mask_ratio = mask_ratio
        self.d_hidden = d_hidden

        # Reconstruction head; registered as a submodule so it is trained and
        # saved together with the rest of the model.
        self.reconstruction_head = MaskedReconstructionHead(d_hidden)

    def _create_mask(self, shape: tuple, device: torch.device) -> torch.Tensor:
        """
        Build a binary mask: 1 = keep, 0 = masked.

        One mask is drawn per example and shared by all agents of that example
        (the result broadcasts over the agent axis). At least one component is
        always masked.

        shape: (B, M, d_hidden)
        Returns: (B, 1, d_hidden)
        """
        B = shape[0]
        d = shape[2]
        num_masked = max(1, int(d * self.mask_ratio))

        mask = torch.ones(B, 1, d, device=device)
        for i in range(B):
            # Independent random subset of components for every example.
            indices = torch.randperm(d, device=device)[:num_masked]
            mask[i, 0, indices] = 0.0

        return mask

    def compute_reconstruction_loss(
        self,
        updater: nn.Module,
        masked_states: torch.Tensor,
        original_states: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Self-supervised loss: MSE restricted to the masked positions.

        updater         : nn.Module - AgentUpdater the masked states are run through
        masked_states   : (B, M, d_hidden) - masked states (input)
        original_states : (B, M, d_hidden) - unmasked states (target)
        mask            : (B, 1, d_hidden) - mask (0 = masked)

        Returns: scalar - mean squared error over all masked elements.
        """
        B, M, d = masked_states.shape

        def zeros(width: int) -> torch.Tensor:
            # Same device and dtype as the states being reconstructed.
            return masked_states.new_zeros(B, M, width)

        # Placeholder inputs for the updater; per the original author's note it
        # only needs s, e and r for its forward pass.
        dummy_state = AgentState(
            s=masked_states,
            b=zeros(1),
            u=masked_states.new_ones(B, M, 1),
            p=zeros(1),
            e=zeros(d),
            alpha=zeros(1),
        )
        dummy_r = zeros(d)
        dummy_e = zeros(d)

        # Go through the updater so that its weights receive gradients.
        updated_state = updater(dummy_state, dummy_r, dummy_e)

        predicted = self.reconstruction_head(updated_state.s)  # (B, M, d)
        inverted_mask = 1.0 - mask  # 1 = masked positions

        diff = (predicted - original_states) ** 2
        masked_diff = diff * inverted_mask  # (B, M, d)

        # The mask is (B, 1, d) but the error is summed over (B, M, d); expand
        # before counting so the denominator matches the number of masked
        # elements actually summed (otherwise the "mean" is M times too large).
        num_masked = inverted_mask.expand_as(diff).sum().clamp(min=1.0)
        return masked_diff.sum() / num_masked

    def _run_inner_loop(self, updater: nn.Module, agent_states: torch.Tensor) -> list:
        """
        Run the K inner Adam steps in place and return the per-step losses.

        The caller is responsible for snapshotting the weights beforehand.
        Gradients that were already stored on the parameters are put back
        afterwards, so an enclosing training step is not disturbed.
        """
        original_states = agent_states.detach()
        mask = self._create_mask(agent_states.shape, agent_states.device)
        mask = mask.to(dtype=original_states.dtype)
        # The mask is fixed for the whole adaptation, so mask once.
        masked_states = original_states * mask  # (B, M, d)

        inner_params = list(updater.parameters()) + list(
            self.reconstruction_head.parameters()
        )
        inner_optimizer = torch.optim.Adam(inner_params, lr=self.ttt_lr)

        # Stash outer gradients: the inner zero_grad()/backward() would
        # otherwise erase accumulated gradients and leave inner ones behind.
        stashed_grads = [p.grad for p in inner_params]
        for p in inner_params:
            p.grad = None

        losses = []
        try:
            # TTT runs at inference time, where callers commonly wrap the model
            # in torch.no_grad(); the inner loop needs autograd regardless.
            with torch.enable_grad():
                for _ in range(self.ttt_steps):
                    inner_optimizer.zero_grad()
                    loss = self.compute_reconstruction_loss(
                        updater, masked_states, original_states, mask
                    )
                    # The graph is rebuilt every step, so no retain_graph.
                    loss.backward()
                    inner_optimizer.step()
                    losses.append(loss.item())
        finally:
            for p, grad in zip(inner_params, stashed_grads):
                p.grad = grad

        return losses

    @staticmethod
    def _param_shift(module: nn.Module, saved_buf: io.BytesIO) -> float:
        """
        Sum of per-tensor L2 distances between ``module`` and its snapshot.

        Every state_dict entry present in both is compared. The comparison is
        done on CPU because the snapshot is loaded there, while the live module
        may sit on another device. The buffer is rewound afterwards so it can
        still be used to restore the weights.
        """
        saved_buf.seek(0)
        saved_state = torch.load(saved_buf, map_location="cpu", weights_only=True)
        saved_buf.seek(0)

        current_state = module.state_dict()
        shift = 0.0
        for key, before in saved_state.items():
            if key in current_state:
                after = current_state[key].detach().float().cpu()
                shift += (after - before.float()).norm(2).item()
        return shift

    @contextmanager
    def adapt_context(
        self,
        updater: nn.Module,
        agent_states: torch.Tensor,
    ):
        """
        Context manager for TTT adaptation.

        Inside the context the updater and the reconstruction head carry the
        adapted weights; on exit (including on error, and including an error
        during the adaptation itself) all weights are restored.

        Usage:
            with ttt.adapt_context(updater, s_t) as ttt_info:
                # updater is already adapted, run deliberation
                state_t = updater(state_t, r_t, e_t)

        Yields a dict with diagnostics:
            - 'losses': list[float] - loss after each inner step
            - 'param_shift': float - summed per-tensor L2 shift of the updater
            - 'steps_done': int - number of inner steps actually taken

        Parameters
        ----------
        updater : nn.Module
            AgentUpdater whose weights are adapted.
        agent_states : torch.Tensor
            (B, M, d_hidden) - current agent states (reconstruction target).
        """
        # Snapshot via an in-memory buffer before anything is modified.
        saved_updater_state = self._save_to_buffer(updater)
        saved_head_state = self._save_to_buffer(self.reconstruction_head)

        try:
            losses = self._run_inner_loop(updater, agent_states)
            ttt_info = {
                "losses": losses,
                "param_shift": self._param_shift(updater, saved_updater_state),
                "steps_done": len(losses),
            }
            yield ttt_info
        finally:
            # Restore the original weights no matter how the block was left.
            self._load_from_buffer(updater, saved_updater_state)
            self._load_from_buffer(self.reconstruction_head, saved_head_state)

    def adapt_and_get_info(
        self,
        updater: nn.Module,
        agent_states: torch.Tensor,
    ) -> dict:
        """
        Non-context variant: adapt the updater and return diagnostics.

        On success the weights are NOT restored; the caller must pass the
        returned buffers to ``restore_weights``. If the adaptation itself
        fails, the weights are restored before the exception propagates,
        because the caller would never receive the buffers.

        Useful when fine-grained control is needed or a context manager is
        inconvenient (e.g. inside an ODE solver).

        Returns: dict with 'losses', 'param_shift', 'steps_done',
                 'saved_updater_state', 'saved_head_state'.
        """
        saved_updater_state = self._save_to_buffer(updater)
        saved_head_state = self._save_to_buffer(self.reconstruction_head)

        try:
            losses = self._run_inner_loop(updater, agent_states)
            param_shift = self._param_shift(updater, saved_updater_state)
        except BaseException:
            self.restore_weights(updater, saved_updater_state, saved_head_state)
            raise

        return {
            "losses": losses,
            "param_shift": param_shift,
            "steps_done": len(losses),
            "saved_updater_state": saved_updater_state,
            "saved_head_state": saved_head_state,
        }

    def restore_weights(
        self,
        updater: nn.Module,
        saved_updater_state,
        saved_head_state,
    ):
        """Restore the original weights after a TTT adaptation."""
        self._load_from_buffer(updater, saved_updater_state)
        self._load_from_buffer(self.reconstruction_head, saved_head_state)

    @staticmethod
    def _save_to_buffer(module: nn.Module) -> io.BytesIO:
        """Serialize the module's state_dict into a rewound in-memory buffer."""
        buf = io.BytesIO()
        torch.save(module.state_dict(), buf)
        buf.seek(0)
        return buf

    @staticmethod
    def _load_from_buffer(module: nn.Module, buf: io.BytesIO):
        """Load a state_dict previously written by ``_save_to_buffer``."""
        buf.seek(0)
        # Loaded onto CPU; load_state_dict copies the values into the module's
        # existing tensors.
        state = torch.load(buf, map_location="cpu", weights_only=True)
        module.load_state_dict(state)
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ from torchvision.models.feature_extraction import create_feature_extractor
5
+
6
class VisionFeatureExtractor(nn.Module):
    """
    Vision backbone wrapper (ResNet50) returning pooled and spatial features
    from a single forward pass.

    Pooled features:  (B, 2048), taken from the ``avgpool`` node.
    Spatial features: (B, L, 2048), taken from the ``layer4`` node, where L is
    the number of grid cells (49 = 7x7 in the original author's example).

    Only ``model_name="resnet50"`` is supported; any other name raises
    ValueError.
    """

    def __init__(self, model_name="resnet50", pretrained=True):
        super().__init__()
        if model_name != "resnet50":
            raise ValueError(f"Unsupported model: {model_name}")

        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.resnet50(weights=weights)

        # Tap the network both before and after global average pooling.
        self.feature_extractor = create_feature_extractor(
            backbone,
            return_nodes={
                'layer4': 'spatial_features',
                'avgpool': 'pooled_features',
            },
        )
        self.d_output = 2048

    def forward(self, x):
        """
        x: (B, 3, H, W)
        Returns:
            pooled: (B, d_output)
            spatial: (B, L, d_output) where L is spatial grid size
        """
        outputs = self.feature_extractor(x)

        # (B, C, 1, 1) -> (B, C)
        pooled = outputs['pooled_features'].flatten(1)

        # (B, C, H', W') -> (B, H'*W', C); the unpack also asserts a 4-D map.
        grid = outputs['spatial_features']
        batch, channels, _rows, _cols = grid.shape
        spatial = grid.view(batch, channels, -1).transpose(1, 2)

        return pooled, spatial
@@ -0,0 +1,167 @@
1
+ Metadata-Version: 2.4
2
+ Name: cida-plugin
3
+ Version: 1.0.0
4
+ Summary: Universal Evidence-Grounded Multi-Agent Deliberation Layer for any encoder
5
+ Author-email: Kairat Zhaksylykov <zhaksylykov.k06@gmail.com>
6
+ License: Apache-2.0
7
+ Project-URL: Homepage, https://github.com/Kairatzh/CIDA-plugin
8
+ Project-URL: Repository, https://github.com/Kairatzh/CIDA-plugin.git
9
+ Project-URL: Documentation, https://github.com/Kairatzh/CIDA-plugin#readme
10
+ Project-URL: Issues, https://github.com/Kairatzh/CIDA-plugin/issues
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3 :: Only
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: License :: OSI Approved :: Apache Software License
20
+ Requires-Python: >=3.10
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: numpy>=1.24
24
+ Requires-Dist: torch>=2.1
25
+ Requires-Dist: transformers>=4.0
26
+ Requires-Dist: huggingface_hub>=0.14.0
27
+ Requires-Dist: torchdiffeq>=0.2.3
28
+ Provides-Extra: dev
29
+ Requires-Dist: pytest>=8.0; extra == "dev"
30
+ Requires-Dist: pytest-cov>=5.0; extra == "dev"
31
+ Requires-Dist: ruff>=0.3.0; extra == "dev"
32
+ Requires-Dist: black>=24.0.0; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ # CIDA-Plugin v3: Universal Evidence-Grounded Multi-Agent Deliberation Layer
36
+
37
+ > *"What if a neural network could argue with itself — and reach a better answer?"*
38
+
39
+ **CIDA-Plugin** is a drop-in architectural layer that can be added on top of **any** pre-trained Transformer encoder (BERT, DistilBERT, RoBERTa, etc.) or Vision Backbone (ResNet, DenseNet, etc.). Instead of a simple Linear Head, CIDA-Plugin introduces a **Multi-Agent Deliberation Protocol**.
40
+
41
+ It forces the model to form independent perspectives (agents), exchange arguments, and reach a consensus weighted by each agent's uncertainty.
42
+
43
+ **Result:** Massive reductions in Expected Calibration Error (ECE), robust uncertainty estimation, and better-reasoned predictions without relying on post-hoc calibration methods like Temperature Scaling.
44
+
45
+ ---
46
+
47
+ ## ⚡ What's New in v3 (Simplified Architecture)
48
+ We transitioned from a highly complex, statistical embedding-based formulation (v2) to a streamlined, theoretically grounded Bayesian-inspired architecture (v3):
49
+
50
+ * **Bayesian Role Priors**: Removed the hyperparameter-heavy `debate_loss`, role-specific serialization losses, and learnable `RoleEmbeddings`. Instead, agents are assigned fixed, mathematically guaranteed prior beliefs (Prosecutor, Defender, Skeptic, Integrator). This enforces structural disagreement at all times, preventing agent representation collapse.
51
+ * **Weighted Mean Consensus**: Replaced the Product of Experts (PoE) aggregator, which incorrectly assumed agent independence and amplified shared bias (causing overconfidence). We now aggregate beliefs via a Weighted Mean, which is mathematically sound for correlated variables.
52
+ * **Disagreement-as-Uncertainty**: Replaced the circular and unstable `ReliabilityTracker` with an observable, non-learned uncertainty estimate based on the spread (standard deviation) of the expert beliefs: $U = f(\text{std}(b))$.
53
+ * **3-Component Loss System**: Reduced the loss system from 11 components to 3 core components: Task Loss (CE/BCE), Calibration Loss (Brier Score), and Anti-Collapse Loss (used only as a safety net).
54
+
55
+ ---
56
+
57
+ ## ⚡ Comparison: Legacy vs. v3 Simplified
58
+
59
+ | Feature | Legacy CIDA (v2/Omega) | CIDA v3 (Simplified) |
60
+ | :--- | :--- | :--- |
61
+ | **Agent Diversity** | Additive embeddings + `debate_loss` | Fixed Role Priors + `anti_collapse_loss` |
62
+ | **Consensus Mech** | Product of Experts (PoE) | Weighted Mean (Correlation-Aware) |
63
+ | **Reliability/Uncertainty** | Learned EMA Reliability Tracker | Observed Disagreement ($\text{std}(b)$) |
64
+ | **Loss System** | 11 components (hard to tune) | 3 components (extremely stable) |
65
+ | **Hyperparameters** | 11 (lambda schedules, temp, etc.) | 2 (lambda_cal, lambda_ac) |
66
+
67
+ ---
68
+
69
+ ## 📦 Installation
70
+
71
+ ```bash
72
+ pip install .
73
+ ```
74
+
75
+ ---
76
+
77
+ ## ⚡ Quickstart
78
+
79
+ CIDA-Plugin is designed to be as easy to use as a standard Hugging Face model.
80
+
81
+ ### 1. Training with any Encoder
82
+ ```python
83
+ import torch
84
+ from transformers import AutoModel
85
+ from cida_plugin import CIDAPlugin, CIDAPluginConfig, CIDALoss
86
+
87
+ # 1. Load any frozen encoder
88
+ encoder = AutoModel.from_pretrained("distilbert-base-uncased")
89
+ d_model = encoder.config.hidden_size
90
+
91
+ # 2. Initialize the plugin config
92
+ config = CIDAPluginConfig(
93
+ d_input=d_model, # Match encoder output dimension
94
+ d_hidden=128, # Internal plugin dimension
95
+ num_classes=2,
96
+ max_rounds=3, # Deliberation rounds
97
+ early_stop_threshold=0.90
98
+ )
99
+ plugin = CIDAPlugin(config)
100
+ loss_fn = CIDALoss(lambda_cal=0.4, lambda_ac=0.2)
101
+
102
+ # 3. Forward pass
103
+ input_ids = torch.randint(0, 1000, (4, 128))
104
+ out = encoder(input_ids)
105
+ pooled = out.last_hidden_state[:, 0, :]
106
+
107
+ # The plugin takes the pooled representation and deliberates
108
+ plugin_out = plugin(pooled, seq_output=out.last_hidden_state)
109
+
110
+ logits = plugin_out["p_final"] # (Batch, Num_Classes)
111
+ loss, loss_components = loss_fn(logits, targets, plugin_out["b_all"])  # `targets`: your ground-truth labels for the batch (not defined in this snippet)
112
+ ```
113
+
114
+ ### 2. Saving and Loading (Hugging Face style)
115
+ ```python
116
+ # Save to disk
117
+ plugin.save_pretrained("./my-cida-plugin")
118
+
119
+ # Load from disk
120
+ loaded_plugin = CIDAPlugin.from_pretrained("./my-cida-plugin")
121
+ ```
122
+
123
+ ---
124
+
125
+ ## 🛠️ Architecture Overview
126
+
127
+ The plugin takes the output of your encoder and processes it through the following steps:
128
+
129
+ 1. **Input Projection:** Maps the arbitrary `d_input` of the encoder to the internal `d_hidden` of the agents.
130
+ 2. **Agent Initialization:** Creates $M$ distinct agents initialized with the pooled representation.
131
+ 3. **Deliberation Loop ($R$ rounds):**
132
+ - **Evidence Extraction:** Agents attend to the input sequence to gather distinct evidence.
133
+ - **Message Formulation:** Agents compress their beliefs and evidence into theses.
134
+ - **Cross-Attention Communication:** Agents listen to others, explicitly weighting disagreement.
135
+ - **Gated Update:** Agents update their internal states.
136
+ - **Role Prior Blending:** Enforces structural roles (Prosecutor, Defender, Skeptic, Integrator) on updated beliefs.
137
+ 4. **Consensus Aggregation:** A final Weighted Mean consensus calculation.
138
+
139
+ ---
140
+
141
+ ## 🧪 Liquid Dynamics & TTT (v5 Additions)
142
+ * **Liquid Neural ODE**: Set `use_liquid_dynamics=True` to replace the discrete iteration loop with a continuous-time deliberation solver ($ds/dt = -s/\tau(x) + F(s,r,e)$).
143
+ * **Test-Time Training (TTT)**: Set `use_ttt=True` to allow agents to adapt their weights to a specific input using self-supervised masked state reconstruction steps before answering.
144
+
145
+ ---
146
+
147
+ ## 🎮 Interactive Demo (Hugging Face Spaces)
148
+
149
+ See how 4 agents deliberate before answering — with agent vote charts and uncertainty gauges.
150
+
151
+ ```bash
152
+ # Train demo checkpoints (~5 min on CPU)
153
+ python demo/train_demo.py
154
+
155
+ # Launch Gradio locally
156
+ python demo/app.py
157
+ ```
158
+
159
+ Deploy to [Hugging Face Spaces](https://huggingface.co/spaces): create a Gradio Space pointing to the `demo/` folder (see `demo/README.md`).
160
+
161
+ ---
162
+
163
+ ## ⚡ Running Tests
164
+ To verify the installation and execution:
165
+ ```bash
166
+ pytest tests/ -v
167
+ ```
@@ -0,0 +1,18 @@
1
+ cida_plugin/__init__.py,sha256=2tKyADKoGtZ_mkm9pBBoJ-0XlNxeN9zLTRVpsGUmXww,1801
2
+ cida_plugin/agent.py,sha256=YVEFLIGBiM-w2hEMmz1oE7C_M8_hvXWxZAkvKQMLZMM,8537
3
+ cida_plugin/config.py,sha256=7_YALMO3vDu6E7Ah1ODgcmPdK2DO_VIHqMBUz3hPzQE,11156
4
+ cida_plugin/consensus.py,sha256=WDokEu5Ybr4j5l6cMOddM8Y7eTkdV_4v3iaeybl_-R8,8832
5
+ cida_plugin/core.py,sha256=9oNih5KXtuJcGhj7oBDF_0tJvx01cSfxoEJEHCfcQHQ,20143
6
+ cida_plugin/deliberation.py,sha256=ABedI-Le0BFjP4GGEpfzYMZ-4A469yx-vUwM-xsUAYY,18470
7
+ cida_plugin/diagnostics.py,sha256=PeMjHMnihO_Qu1nTG-PFTphAL8b8h44Sl-wjlbC8VnA,5242
8
+ cida_plugin/hf.py,sha256=hn5AOxVNFlOB73kQyZfpStZ3a7-nJe6AoqdCerIL3aw,7929
9
+ cida_plugin/liquid_dynamics.py,sha256=iLFV6bPllhDW4wimTtIXd7n-_S9FgCpen-ct56DExSw,10489
10
+ cida_plugin/losses.py,sha256=F6Cj-uQfKMJbP1O_pJ9C8H-XRXGBuG6gNU-_09s6Irs,8493
11
+ cida_plugin/translator.py,sha256=3ZHzOM1B7_xnqNSwUmf_dKgnETV62X0YWhZMlB5yO2I,3296
12
+ cida_plugin/ttt.py,sha256=ILE8NxzKo5bT1OgS-ro18RouOeopmcWTzbrLthPDRSk,15449
13
+ cida_plugin/vision_backbone.py,sha256=T4GdlZTFlh2xCBnvYJBgmZmIh6MoDQ8JUS30UGemlZo,1649
14
+ cida_plugin-1.0.0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
15
+ cida_plugin-1.0.0.dist-info/METADATA,sha256=MB0rqdFaAZMyoFmW0I3CLIA95z6QGCBAYBJXuIr4b6Y,7543
16
+ cida_plugin-1.0.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
17
+ cida_plugin-1.0.0.dist-info/top_level.txt,sha256=IsOo3gdhzBYnw_n7e2zwHEUL0cQAq9GYrR4v9ubjqy4,12
18
+ cida_plugin-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+