cida-plugin 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cida_plugin/__init__.py +62 -0
- cida_plugin/agent.py +198 -0
- cida_plugin/config.py +228 -0
- cida_plugin/consensus.py +178 -0
- cida_plugin/core.py +451 -0
- cida_plugin/deliberation.py +376 -0
- cida_plugin/diagnostics.py +124 -0
- cida_plugin/hf.py +194 -0
- cida_plugin/liquid_dynamics.py +291 -0
- cida_plugin/losses.py +174 -0
- cida_plugin/translator.py +74 -0
- cida_plugin/ttt.py +348 -0
- cida_plugin/vision_backbone.py +43 -0
- cida_plugin-1.0.0.dist-info/METADATA +167 -0
- cida_plugin-1.0.0.dist-info/RECORD +18 -0
- cida_plugin-1.0.0.dist-info/WHEEL +5 -0
- cida_plugin-1.0.0.dist-info/licenses/LICENSE +201 -0
- cida_plugin-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
class LatentSemanticTranslator(nn.Module):
    """
    Module 3 of Kairos OS: Latent-to-Semantic Translator.

    Translates CIDA's latent superposition (Prosecutor vs Defender beliefs)
    into structured semantic concept embeddings that an LLM can consume.

    Parameters
    ----------
    d_hidden : int
        Width expected by the projection MLP. Belief vectors narrower than
        this are zero-padded on the right; wider ones are rejected.
    vocab_size : int
        Currently unused; kept for interface compatibility.
    d_model_llm : int
        Output embedding width. ``evidence_attention`` uses 8 heads, so this
        must be divisible by 8.
    """

    def __init__(self, d_hidden: int = 128, vocab_size: int = 50000, d_model_llm: int = 4096):
        super().__init__()
        self.d_hidden = d_hidden
        self.d_model_llm = d_model_llm

        # Non-linear projection from CIDA latent space to LLM semantic space
        self.latent_to_semantic = nn.Sequential(
            nn.Linear(d_hidden, d_hidden * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden * 4, d_model_llm),
        )

        # Self-attention over the hypothesis tokens to summarize evidence
        self.evidence_attention = nn.MultiheadAttention(
            embed_dim=d_model_llm, num_heads=8, batch_first=True
        )

    def _fit_to_hidden(self, belief: torch.Tensor, name: str) -> torch.Tensor:
        """
        Right-pad ``belief`` with zeros along its last axis up to ``d_hidden``.

        Each belief is fitted independently, so prosecutor and defender may
        arrive with different widths (e.g. raw hidden states vs class
        probabilities). The result is cast to the projection's weight dtype.

        Raises
        ------
        ValueError
            If the last axis is wider than ``d_hidden`` (no lossless fit).
        """
        width = belief.size(-1)
        if width > self.d_hidden:
            raise ValueError(
                f"{name} has last dimension {width}, which exceeds d_hidden={self.d_hidden}"
            )
        if width < self.d_hidden:
            # Zero-padding keeps the original values in the leading slots,
            # matching the previous "copy into a zero buffer" behavior.
            belief = F.pad(belief, (0, self.d_hidden - width))
        return belief.to(dtype=self.latent_to_semantic[0].weight.dtype)

    def forward(self, superposition_state: dict, class_names: list = None):
        """
        Args:
            superposition_state: Dict containing 'prosecutor_belief' and
                'defender_belief'. Each tensor is of shape (B, num_classes)
                or (B, d_hidden); anything narrower than d_hidden is padded.
            class_names: List of class names for semantic grounding
                (currently unused).

        Returns:
            semantic_embeddings: (B, 2, d_model_llm) ready to be prepended to
            an LLM prompt, or None when ``superposition_state`` is None.

        Raises:
            KeyError: if either belief is missing from ``superposition_state``.
            ValueError: if a belief is wider than ``d_hidden``.
        """
        if superposition_state is None:
            return None

        prosecutor_b = superposition_state.get("prosecutor_belief")
        defender_b = superposition_state.get("defender_belief")
        if prosecutor_b is None or defender_b is None:
            raise KeyError(
                "superposition_state must contain 'prosecutor_belief' and 'defender_belief'"
            )

        prosecutor_b = self._fit_to_hidden(prosecutor_b, "prosecutor_belief")
        defender_b = self._fit_to_hidden(defender_b, "defender_belief")

        # Map each hypothesis to the LLM dimension: (B, 1, d_llm) each
        sem_prosecutor = self.latent_to_semantic(prosecutor_b).unsqueeze(1)
        sem_defender = self.latent_to_semantic(defender_b).unsqueeze(1)

        # Combine hypotheses into a sequence: (B, 2, d_llm)
        semantic_sequence = torch.cat([sem_prosecutor, sem_defender], dim=1)

        # Self-attention lets each hypothesis attend to the other (conflicts)
        attn_out, _ = self.evidence_attention(
            semantic_sequence, semantic_sequence, semantic_sequence
        )

        # Residual connection
        return semantic_sequence + attn_out
|
cida_plugin/ttt.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test-Time Training (TTT) для CIDA-Plugin.
|
|
3
|
+
|
|
4
|
+
Концепция (Apple 2024, MIT 2024):
|
|
5
|
+
Обычно: обучение = оффлайн, инференс = заморожен.
|
|
6
|
+
TTT: агент делает K шагов градиентного спуска на каждом входе,
|
|
7
|
+
адаптируясь к конкретному примеру ПЕРЕД ответом.
|
|
8
|
+
|
|
9
|
+
θ*_i = θ_i − α · ∇L_self(θ_i, x) ← K шагов Adam
|
|
10
|
+
b_i = Agent_i(x; θ*_i) ← belief после адаптации
|
|
11
|
+
θ_i ← θ_i ← восстановление весов
|
|
12
|
+
|
|
13
|
+
Почему это важно для CIDA:
|
|
14
|
+
- Первая архитектура где агенты РЕАЛЬНО «думают» о каждом входе
|
|
15
|
+
- Добавляем вычисление (K шагов), не параметры
|
|
16
|
+
- Маленькая модель адаптируется как большая
|
|
17
|
+
|
|
18
|
+
Self-supervised loss:
|
|
19
|
+
Masked reconstruction — маскируем часть компонент скрытого состояния
|
|
20
|
+
агента и заставляем его восстанавливать замаскированное.
|
|
21
|
+
Это заставляет агента «понять» структуру конкретного входа.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import io
|
|
25
|
+
from contextlib import contextmanager
|
|
26
|
+
|
|
27
|
+
import torch
|
|
28
|
+
import torch.nn as nn
|
|
29
|
+
import torch.nn.functional as F
|
|
30
|
+
|
|
31
|
+
from .agent import AgentState
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class MaskedReconstructionHead(nn.Module):
    """
    Auxiliary head for the self-supervised test-time loss.

    Receives agent hidden states in which some components were zeroed and
    predicts the original values of those components.

    Architecture: a two-layer MLP with a bottleneck
    (d_hidden -> d_hidden // 2 -> d_hidden). The bottleneck discourages
    trivial copying and pushes the model to infer the masked components
    from the visible ones.
    """

    def __init__(self, d_hidden: int):
        super().__init__()
        bottleneck = d_hidden // 2
        layers = [
            nn.Linear(d_hidden, bottleneck),
            nn.SiLU(),
            nn.Linear(bottleneck, d_hidden),
        ]
        # Kept as a Sequential named ``net`` so checkpoint keys remain
        # ``net.0.*`` / ``net.2.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, masked_states: torch.Tensor) -> torch.Tensor:
        """
        masked_states: (B, M, d_hidden) -- states with some components zeroed.
        Returns: (B, M, d_hidden) -- predicted original values.
        """
        reconstructed = self.net(masked_states)
        return reconstructed
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class TestTimeTrainer(nn.Module):
    """
    Test-Time Training for CIDA agents.

    Each adaptation call:
      1. builds a mask (a ``mask_ratio`` fraction of hidden components is zeroed);
      2. passes the masked states through the AgentUpdater;
      3. computes the reconstruction loss on the masked positions;
      4. takes K Adam steps on updater + reconstruction_head;
      5. (``adapt_context`` only) restores the original weights on exit.

    Parameters
    ----------
    d_hidden : int
        Hidden-state width of the agents.
    ttt_steps : int
        Number of inner gradient steps (K). Default 3.
        Trade-off: more steps adapt better but cost more.
    ttt_lr : float
        Learning rate of the inner Adam. Default 1e-3.
        It must be large enough to adapt within K steps, but not so large
        that it breaks the learned representations.
    mask_ratio : float
        Fraction of components to mask. Default 0.15 (as in BERT).
    """

    def __init__(
        self,
        d_hidden: int,
        ttt_steps: int = 3,
        ttt_lr: float = 1e-3,
        mask_ratio: float = 0.15,
    ):
        super().__init__()
        self.ttt_steps = ttt_steps
        self.ttt_lr = ttt_lr
        self.mask_ratio = mask_ratio
        self.d_hidden = d_hidden

        # Reconstruction head -- trained together with the main model
        self.reconstruction_head = MaskedReconstructionHead(d_hidden)

    def _create_mask(self, shape: tuple, device: torch.device) -> torch.Tensor:
        """
        Build a binary mask: 1 = keep, 0 = masked.

        There is one random mask per example, shared by all agents of that
        example. Every agent therefore sees the same "gaps" and must agree
        on how to fill them.

        shape: (B, M, d_hidden)
        Returns: (B, 1, d_hidden) -- broadcast over agents.
        """
        batch, d = shape[0], shape[2]
        num_masked = min(d, max(1, int(d * self.mask_ratio)))

        # argsort of i.i.d. uniform noise is an independent random permutation
        # per row. Its first num_masked entries are the masked components.
        # This replaces a per-example Python loop over torch.randperm.
        noise = torch.rand(batch, d, device=device)
        masked_idx = noise.argsort(dim=1)[:, :num_masked]
        mask = torch.ones(batch, d, device=device)
        mask.scatter_(1, masked_idx, 0.0)
        return mask.unsqueeze(1)

    @staticmethod
    def _masked_mse(
        predicted: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Mean squared error over the masked positions only.

        ``mask`` (1 = keep, 0 = masked) may be (B, 1, d) and is broadcast over
        the agent axis. The denominator counts masked entries AFTER
        broadcasting, so the result is a true per-entry mean.
        """
        inverted = (1.0 - mask).expand_as(predicted)  # 1 = masked position
        masked_sq_err = (predicted - target) ** 2 * inverted
        return masked_sq_err.sum() / inverted.sum().clamp(min=1.0)

    def compute_reconstruction_loss(
        self,
        updater: nn.Module,
        masked_states: torch.Tensor,
        original_states: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Self-supervised loss: MSE on the masked positions.

        updater         : nn.Module -- AgentUpdater the masked states are passed through
        masked_states   : (B, M, d_hidden) -- masked states (input)
        original_states : (B, M, d_hidden) -- original states (target)
        mask            : (B, 1, d_hidden) -- mask (0 = masked)

        Returns: scalar -- mean squared error over all masked
        (agent, component) entries.
        """
        B, M, d = masked_states.shape
        device = masked_states.device

        # Placeholder AgentState for the updater. Presumably only s, e and r
        # matter to its forward -- confirm against AgentUpdater.
        dummy_state = AgentState(
            s=masked_states,
            b=torch.zeros(B, M, 1, device=device),
            u=torch.ones(B, M, 1, device=device),
            p=torch.zeros(B, M, 1, device=device),
            e=torch.zeros(B, M, d, device=device),
            alpha=torch.zeros(B, M, 1, device=device),
        )
        dummy_r = torch.zeros(B, M, d, device=device)
        dummy_e = torch.zeros(B, M, d, device=device)

        # Route through the updater so that its weights receive gradients.
        updated_state = updater(dummy_state, dummy_r, dummy_e)

        predicted = self.reconstruction_head(updated_state.s)  # (B, M, d)
        return self._masked_mse(predicted, original_states, mask)

    def _run_inner_loop(self, updater: nn.Module, agent_states: torch.Tensor) -> list:
        """
        Take ``ttt_steps`` Adam steps on updater + reconstruction_head in place.

        Returns the list of per-step loss values. Any ``.grad`` held by the
        adapted parameters before the call is restored afterwards, so the
        inner optimization neither wipes nor leaks into outer gradients.
        """
        original_states = agent_states.detach()
        mask = self._create_mask(original_states.shape, original_states.device)
        # The input is constant across steps (detached states, fixed mask).
        masked_states = original_states * mask

        inner_params = [*updater.parameters(), *self.reconstruction_head.parameters()]
        stashed_grads = [p.grad for p in inner_params]
        inner_optimizer = torch.optim.Adam(inner_params, lr=self.ttt_lr)

        losses = []
        try:
            # TTT usually runs at inference time, often under torch.no_grad();
            # the inner loop needs autograd regardless.
            with torch.enable_grad():
                for _ in range(self.ttt_steps):
                    # set_to_none so the stashed grad tensors are never zeroed in place
                    inner_optimizer.zero_grad(set_to_none=True)
                    loss = self.compute_reconstruction_loss(
                        updater, masked_states, original_states, mask
                    )
                    # The graph is rebuilt every step, so retain_graph is unnecessary.
                    loss.backward()
                    inner_optimizer.step()
                    losses.append(loss.item())
        finally:
            for param, grad in zip(inner_params, stashed_grads):
                param.grad = grad
        return losses

    @staticmethod
    def _param_shift(module: nn.Module, saved_buf: io.BytesIO) -> float:
        """
        Global L2 norm ||theta* - theta|| between ``module``'s current
        state_dict and the one stored in ``saved_buf``.

        The comparison is done on CPU, so it works for modules on any device.
        ``saved_buf`` is rewound afterwards so it can still be used for restoring.
        """
        saved_buf.seek(0)
        saved = torch.load(saved_buf, map_location="cpu", weights_only=True)
        saved_buf.seek(0)

        current = module.state_dict()
        total_sq = 0.0
        for key, old in saved.items():
            new = current.get(key)
            if new is None:
                continue
            delta = new.detach().float().cpu() - old.float()
            total_sq += delta.pow(2).sum().item()
        return total_sq ** 0.5

    @contextmanager
    def adapt_context(
        self,
        updater: nn.Module,
        agent_states: torch.Tensor,
    ):
        """
        Context manager for TTT adaptation.

        Inside the context:
          - updater holds the adapted weights (theta*)
          - reconstruction_head is adapted as well
        On exit, including when adaptation itself raises:
          - all weights are restored to theta

        Usage:
            with ttt.adapt_context(updater, s_t) as ttt_info:
                # updater is already adapted; run deliberation
                state_t = updater(state_t, r_t, e_t)

        Yields a diagnostics dict:
          - 'losses': list[float] -- loss per step
          - 'param_shift': float -- ||theta* - theta|| (global L2 norm)
          - 'steps_done': int -- number of steps actually taken

        NOTE(review): restoring goes through ``load_state_dict``, which copies
        into the parameters in place. Presumably any outer backward through a
        graph built inside the context must run before the context exits --
        confirm against callers.

        Parameters
        ----------
        updater : nn.Module
            AgentUpdater whose weights are adapted.
        agent_states : torch.Tensor
            (B, M, d_hidden) -- current agent states (reconstruction target).
        """
        # Snapshot parameters (BytesIO instead of deepcopy)
        saved_updater_state = self._save_to_buffer(updater)
        saved_head_state = self._save_to_buffer(self.reconstruction_head)

        try:
            losses = self._run_inner_loop(updater, agent_states)
            ttt_info = {
                "losses": losses,
                "param_shift": self._param_shift(updater, saved_updater_state),
                "steps_done": len(losses),
            }
            yield ttt_info
        finally:
            self._load_from_buffer(updater, saved_updater_state)
            self._load_from_buffer(self.reconstruction_head, saved_head_state)

    def adapt_and_get_info(
        self,
        updater: nn.Module,
        agent_states: torch.Tensor,
    ) -> dict:
        """
        Non-context variant: adapt the updater and return diagnostics.

        On success the weights are NOT restored; the caller must call
        ``restore_weights`` with the returned buffers. If adaptation raises,
        the weights are restored here before the exception propagates,
        because the caller never receives the buffers.

        This is useful for fine-grained control, or where a context manager
        is awkward (e.g. inside an ODE solver).

        Returns: dict with 'losses', 'param_shift', 'steps_done',
        'saved_updater_state', 'saved_head_state'.
        """
        saved_updater_state = self._save_to_buffer(updater)
        saved_head_state = self._save_to_buffer(self.reconstruction_head)

        try:
            losses = self._run_inner_loop(updater, agent_states)
        except BaseException:
            self.restore_weights(updater, saved_updater_state, saved_head_state)
            raise

        return {
            "losses": losses,
            "param_shift": self._param_shift(updater, saved_updater_state),
            "steps_done": len(losses),
            "saved_updater_state": saved_updater_state,
            "saved_head_state": saved_head_state,
        }

    def restore_weights(
        self,
        updater: nn.Module,
        saved_updater_state,
        saved_head_state,
    ):
        """Restore the original weights after TTT adaptation."""
        self._load_from_buffer(updater, saved_updater_state)
        self._load_from_buffer(self.reconstruction_head, saved_head_state)

    @staticmethod
    def _save_to_buffer(module: nn.Module) -> io.BytesIO:
        """Serialize ``module.state_dict()`` into an in-memory buffer (rewound to 0)."""
        buf = io.BytesIO()
        torch.save(module.state_dict(), buf)
        buf.seek(0)
        return buf

    @staticmethod
    def _load_from_buffer(module: nn.Module, buf: io.BytesIO):
        """Load a state_dict previously written by ``_save_to_buffer`` into ``module``."""
        buf.seek(0)
        state = torch.load(buf, map_location="cpu", weights_only=True)
        module.load_state_dict(state)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from torchvision import models
|
|
4
|
+
from torchvision.models.feature_extraction import create_feature_extractor
|
|
5
|
+
|
|
6
|
+
class VisionFeatureExtractor(nn.Module):
    """
    Extracts both pooled and spatial features from a ResNet vision backbone.

    For an input of shape (B, 3, H, W):
      pooled features:  (B, d_output)
      spatial features: (B, L, d_output), where L = h * w of the final
                        feature map (49 = 7x7 for 224x224 inputs)

    ``d_output`` is 2048 for resnet50/101/152 and 512 for resnet18/34.

    Note: this module needs ``torchvision``, which is not listed in the
    package's Requires-Dist.
    """

    # model name -> (torchvision weights enum name, feature width)
    _RESNET_SPECS = {
        "resnet18": ("ResNet18_Weights", 512),
        "resnet34": ("ResNet34_Weights", 512),
        "resnet50": ("ResNet50_Weights", 2048),
        "resnet101": ("ResNet101_Weights", 2048),
        "resnet152": ("ResNet152_Weights", 2048),
    }

    def __init__(self, model_name="resnet50", pretrained=True):
        super().__init__()
        if model_name not in self._RESNET_SPECS:
            raise ValueError(f"Unsupported model: {model_name}")
        weights_enum_name, d_output = self._RESNET_SPECS[model_name]

        weights = getattr(models, weights_enum_name).IMAGENET1K_V1 if pretrained else None
        base_model = getattr(models, model_name)(weights=weights)

        # layer4 is the last feature map before global average pooling
        return_nodes = {
            "layer4": "spatial_features",
            "avgpool": "pooled_features",
        }
        self.feature_extractor = create_feature_extractor(base_model, return_nodes=return_nodes)
        self.d_output = d_output

    def forward(self, x):
        """
        x: (B, 3, H, W)
        Returns:
            pooled: (B, d_output)
            spatial: (B, L, d_output) where L is the spatial grid size
        """
        features = self.feature_extractor(x)

        # Pooled features come as (B, C, 1, 1) -> (B, C)
        pooled = features["pooled_features"].flatten(1)

        # Spatial features come as (B, C, h, w) -> (B, h*w, C): one token per cell.
        # flatten (unlike view) also works if the tensor is not viewable.
        spatial = features["spatial_features"].flatten(2).transpose(1, 2)

        return pooled, spatial
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cida-plugin
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Universal Evidence-Grounded Multi-Agent Deliberation Layer for any encoder
|
|
5
|
+
Author-email: Kairat Zhaksylykov <zhaksylykov.k06@gmail.com>
|
|
6
|
+
License: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/Kairatzh/CIDA-plugin
|
|
8
|
+
Project-URL: Repository, https://github.com/Kairatzh/CIDA-plugin.git
|
|
9
|
+
Project-URL: Documentation, https://github.com/Kairatzh/CIDA-plugin#readme
|
|
10
|
+
Project-URL: Issues, https://github.com/Kairatzh/CIDA-plugin/issues
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: numpy>=1.24
|
|
24
|
+
Requires-Dist: torch>=2.1
|
|
25
|
+
Requires-Dist: transformers>=4.0
|
|
26
|
+
Requires-Dist: huggingface_hub>=0.14.0
|
|
27
|
+
Requires-Dist: torchdiffeq>=0.2.3
|
|
28
|
+
Provides-Extra: dev
|
|
29
|
+
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
30
|
+
Requires-Dist: pytest-cov>=5.0; extra == "dev"
|
|
31
|
+
Requires-Dist: ruff>=0.3.0; extra == "dev"
|
|
32
|
+
Requires-Dist: black>=24.0.0; extra == "dev"
|
|
33
|
+
Dynamic: license-file
|
|
34
|
+
|
|
35
|
+
# CIDA-Plugin v3: Universal Evidence-Grounded Multi-Agent Deliberation Layer
|
|
36
|
+
|
|
37
|
+
> *"What if a neural network could argue with itself — and reach a better answer?"*
|
|
38
|
+
|
|
39
|
+
**CIDA-Plugin** is a drop-in architectural layer that can be added on top of **any** pre-trained Transformer encoder (BERT, DistilBERT, RoBERTa, etc.) or Vision Backbone (ResNet, DenseNet, etc.). Instead of a simple Linear Head, CIDA-Plugin introduces a **Multi-Agent Deliberation Protocol**.
|
|
40
|
+
|
|
41
|
+
It forces the model to form independent perspectives (agents), exchange arguments, and reach a consensus weighted by each agent's uncertainty.
|
|
42
|
+
|
|
43
|
+
**Result:** Massive reductions in Expected Calibration Error (ECE), robust uncertainty estimation, and better-reasoned predictions without relying on post-hoc calibration methods like Temperature Scaling.
|
|
44
|
+
|
|
45
|
+
---
|
|
46
|
+
|
|
47
|
+
## ⚡ What's New in v3 (Simplified Architecture)
|
|
48
|
+
We transitioned from a highly complex, statistical embedding-based formulation (v2) to a streamlined, theoretically grounded Bayesian-inspired architecture (v3):
|
|
49
|
+
|
|
50
|
+
* **Bayesian Role Priors**: Removed the hyperparameter-heavy `debate_loss`, role-specific serialization losses, and learnable `RoleEmbeddings`. Instead, each agent is assigned a fixed prior belief tied to its role (Prosecutor, Defender, Skeptic, Integrator). This enforces structural disagreement at all times and prevents the agents' representations from collapsing.
|
|
51
|
+
* **Weighted Mean Consensus**: Replaced the Product of Experts (PoE) aggregator, which incorrectly assumed agent independence and amplified shared bias (causing overconfidence). We now aggregate beliefs via a Weighted Mean, which does not assume independence and therefore does not compound evidence that the agents share.
|
|
52
|
+
* **Disagreement-as-Uncertainty**: Replaced the circular and unstable `ReliabilityTracker` with an observable, non-learned uncertainty quantification based on the variance (standard deviation) between expert beliefs: $U = f(\text{std}(b))$.
|
|
53
|
+
* **3-Component Loss System**: Reduced the loss system from 11 components to 3 core components: Task Loss (CE/BCE), Calibration Loss (Brier Score), and Anti-Collapse Loss (used only as a safety net).
|
|
54
|
+
|
|
55
|
+
---
|
|
56
|
+
|
|
57
|
+
## ⚡ Comparison: Legacy vs. v3 Simplified
|
|
58
|
+
|
|
59
|
+
| Feature | Legacy CIDA (v2/Omega) | CIDA v3 (Simplified) |
|
|
60
|
+
| :--- | :--- | :--- |
|
|
61
|
+
| **Agent Diversity** | Additive embeddings + `debate_loss` | Fixed Role Priors + `anti_collapse_loss` |
|
|
62
|
+
| **Consensus Mech** | Product of Experts (PoE) | Weighted Mean (Correlation-Aware) |
|
|
63
|
+
| **Reliability/Uncertainty** | Learned EMA Reliability Tracker | Observed Disagreement ($std(b)$) |
|
|
64
|
+
| **Loss System** | 11 components (hard to tune) | 3 components (extremely stable) |
|
|
65
|
+
| **Hyperparameters** | 11 (lambda schedules, temp, etc.) | 2 (lambda_cal, lambda_ac) |
|
|
66
|
+
|
|
67
|
+
---
|
|
68
|
+
|
|
69
|
+
## 📦 Installation
|
|
70
|
+
|
|
71
|
+
```bash
|
|
72
|
+
pip install cida-plugin
# or, from a source checkout:
pip install .
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
---
|
|
76
|
+
|
|
77
|
+
## ⚡ Quickstart
|
|
78
|
+
|
|
79
|
+
CIDA-Plugin is designed to be as easy to use as a standard Hugging Face model.
|
|
80
|
+
|
|
81
|
+
### 1. Training with any Encoder
|
|
82
|
+
```python
|
|
83
|
+
import torch
|
|
84
|
+
from transformers import AutoModel
|
|
85
|
+
from cida_plugin import CIDAPlugin, CIDAPluginConfig, CIDALoss
|
|
86
|
+
|
|
87
|
+
# 1. Load any pre-trained encoder (call encoder.requires_grad_(False) to freeze it)
|
|
88
|
+
encoder = AutoModel.from_pretrained("distilbert-base-uncased")
|
|
89
|
+
d_model = encoder.config.hidden_size
|
|
90
|
+
|
|
91
|
+
# 2. Initialize the plugin config
|
|
92
|
+
config = CIDAPluginConfig(
|
|
93
|
+
d_input=d_model, # Match encoder output dimension
|
|
94
|
+
d_hidden=128, # Internal plugin dimension
|
|
95
|
+
num_classes=2,
|
|
96
|
+
max_rounds=3, # Deliberation rounds
|
|
97
|
+
early_stop_threshold=0.90
|
|
98
|
+
)
|
|
99
|
+
plugin = CIDAPlugin(config)
|
|
100
|
+
loss_fn = CIDALoss(lambda_cal=0.4, lambda_ac=0.2)
|
|
101
|
+
|
|
102
|
+
# 3. Forward pass
|
|
103
|
+
input_ids = torch.randint(0, 1000, (4, 128))
|
|
104
|
+
out = encoder(input_ids)
|
|
105
|
+
pooled = out.last_hidden_state[:, 0, :]
|
|
106
|
+
|
|
107
|
+
# The plugin takes the pooled representation and deliberates
|
|
108
|
+
plugin_out = plugin(pooled, seq_output=out.last_hidden_state)
|
|
109
|
+
|
|
110
|
+
logits = plugin_out["p_final"] # (Batch, Num_Classes)
|
|
111
|
+
targets = torch.randint(0, 2, (4,))  # dummy labels for this example
loss, loss_components = loss_fn(logits, targets, plugin_out["b_all"])
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
### 2. Saving and Loading (Hugging Face style)
|
|
115
|
+
```python
|
|
116
|
+
# Save to disk
|
|
117
|
+
plugin.save_pretrained("./my-cida-plugin")
|
|
118
|
+
|
|
119
|
+
# Load from disk
|
|
120
|
+
loaded_plugin = CIDAPlugin.from_pretrained("./my-cida-plugin")
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
---
|
|
124
|
+
|
|
125
|
+
## 🛠️ Architecture Overview
|
|
126
|
+
|
|
127
|
+
The plugin takes the output of your encoder and processes it through the following steps:
|
|
128
|
+
|
|
129
|
+
1. **Input Projection:** Maps the arbitrary `d_input` of the encoder to the internal `d_hidden` of the agents.
|
|
130
|
+
2. **Agent Initialization:** Creates $M$ distinct agents initialized with the pooled representation.
|
|
131
|
+
3. **Deliberation Loop ($R$ rounds):**
|
|
132
|
+
- **Evidence Extraction:** Agents attend to the input sequence to gather distinct evidence.
|
|
133
|
+
- **Message Formulation:** Agents compress their beliefs and evidence into theses.
|
|
134
|
+
- **Cross-Attention Communication:** Agents listen to others, explicitly weighting disagreement.
|
|
135
|
+
- **Gated Update:** Agents update their internal states.
|
|
136
|
+
- **Role Prior Blending:** Enforces structural roles (Prosecutor, Defender, Skeptic, Integrator) on updated beliefs.
|
|
137
|
+
4. **Consensus Aggregation:** A final Weighted Mean consensus calculation.
|
|
138
|
+
|
|
139
|
+
---
|
|
140
|
+
|
|
141
|
+
## 🧪 Liquid Dynamics & TTT (v5 Additions)
|
|
142
|
+
* **Liquid Neural ODE**: Set `use_liquid_dynamics=True` to replace the discrete iteration loop with a continuous-time deliberation solver ($ds/dt = -s/\tau(x) + F(s,r,e)$).
|
|
143
|
+
* **Test-Time Training (TTT)**: Set `use_ttt=True` to allow agents to adapt their weights to a specific input using self-supervised masked state reconstruction steps before answering.
|
|
144
|
+
|
|
145
|
+
---
|
|
146
|
+
|
|
147
|
+
## 🎮 Interactive Demo (Hugging Face Spaces)
|
|
148
|
+
|
|
149
|
+
See how 4 agents deliberate before answering — with agent vote charts and uncertainty gauges.
|
|
150
|
+
|
|
151
|
+
```bash
|
|
152
|
+
# Train demo checkpoints (~5 min on CPU)
|
|
153
|
+
python demo/train_demo.py
|
|
154
|
+
|
|
155
|
+
# Launch Gradio locally
|
|
156
|
+
python demo/app.py
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
Deploy to [Hugging Face Spaces](https://huggingface.co/spaces): create a Gradio Space pointing to the `demo/` folder (see `demo/README.md`).
|
|
160
|
+
|
|
161
|
+
---
|
|
162
|
+
|
|
163
|
+
## ⚡ Running Tests
|
|
164
|
+
To verify the installation and execution:
|
|
165
|
+
```bash
|
|
166
|
+
pytest tests/ -v
|
|
167
|
+
```
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
cida_plugin/__init__.py,sha256=2tKyADKoGtZ_mkm9pBBoJ-0XlNxeN9zLTRVpsGUmXww,1801
|
|
2
|
+
cida_plugin/agent.py,sha256=YVEFLIGBiM-w2hEMmz1oE7C_M8_hvXWxZAkvKQMLZMM,8537
|
|
3
|
+
cida_plugin/config.py,sha256=7_YALMO3vDu6E7Ah1ODgcmPdK2DO_VIHqMBUz3hPzQE,11156
|
|
4
|
+
cida_plugin/consensus.py,sha256=WDokEu5Ybr4j5l6cMOddM8Y7eTkdV_4v3iaeybl_-R8,8832
|
|
5
|
+
cida_plugin/core.py,sha256=9oNih5KXtuJcGhj7oBDF_0tJvx01cSfxoEJEHCfcQHQ,20143
|
|
6
|
+
cida_plugin/deliberation.py,sha256=ABedI-Le0BFjP4GGEpfzYMZ-4A469yx-vUwM-xsUAYY,18470
|
|
7
|
+
cida_plugin/diagnostics.py,sha256=PeMjHMnihO_Qu1nTG-PFTphAL8b8h44Sl-wjlbC8VnA,5242
|
|
8
|
+
cida_plugin/hf.py,sha256=hn5AOxVNFlOB73kQyZfpStZ3a7-nJe6AoqdCerIL3aw,7929
|
|
9
|
+
cida_plugin/liquid_dynamics.py,sha256=iLFV6bPllhDW4wimTtIXd7n-_S9FgCpen-ct56DExSw,10489
|
|
10
|
+
cida_plugin/losses.py,sha256=F6Cj-uQfKMJbP1O_pJ9C8H-XRXGBuG6gNU-_09s6Irs,8493
|
|
11
|
+
cida_plugin/translator.py,sha256=3ZHzOM1B7_xnqNSwUmf_dKgnETV62X0YWhZMlB5yO2I,3296
|
|
12
|
+
cida_plugin/ttt.py,sha256=ILE8NxzKo5bT1OgS-ro18RouOeopmcWTzbrLthPDRSk,15449
|
|
13
|
+
cida_plugin/vision_backbone.py,sha256=T4GdlZTFlh2xCBnvYJBgmZmIh6MoDQ8JUS30UGemlZo,1649
|
|
14
|
+
cida_plugin-1.0.0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
15
|
+
cida_plugin-1.0.0.dist-info/METADATA,sha256=MB0rqdFaAZMyoFmW0I3CLIA95z6QGCBAYBJXuIr4b6Y,7543
|
|
16
|
+
cida_plugin-1.0.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
17
|
+
cida_plugin-1.0.0.dist-info/top_level.txt,sha256=IsOo3gdhzBYnw_n7e2zwHEUL0cQAq9GYrR4v9ubjqy4,12
|
|
18
|
+
cida_plugin-1.0.0.dist-info/RECORD,,
|