cida-plugin 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cida_plugin/__init__.py +62 -0
- cida_plugin/agent.py +198 -0
- cida_plugin/config.py +228 -0
- cida_plugin/consensus.py +178 -0
- cida_plugin/core.py +451 -0
- cida_plugin/deliberation.py +376 -0
- cida_plugin/diagnostics.py +124 -0
- cida_plugin/hf.py +194 -0
- cida_plugin/liquid_dynamics.py +291 -0
- cida_plugin/losses.py +174 -0
- cida_plugin/translator.py +74 -0
- cida_plugin/ttt.py +348 -0
- cida_plugin/vision_backbone.py +43 -0
- cida_plugin-1.0.0.dist-info/METADATA +167 -0
- cida_plugin-1.0.0.dist-info/RECORD +18 -0
- cida_plugin-1.0.0.dist-info/WHEEL +5 -0
- cida_plugin-1.0.0.dist-info/licenses/LICENSE +201 -0
- cida_plugin-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from .agent import AgentState
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MessageFormulator(nn.Module):
|
|
8
|
+
"""
|
|
9
|
+
Creates a structured message out of the explicit agent components.
|
|
10
|
+
|
|
11
|
+
v4: Anonymous mode (Choi et al. ACL 2025).
|
|
12
|
+
Убирает s_i из сообщения → агент слышит аргумент, не авторитет.
|
|
13
|
+
Доказано устраняет 100% identity bias в debate.
|
|
14
|
+
|
|
15
|
+
Было: m = φ([s; e; b; u]) — агент знает источник → sycophancy/self-bias
|
|
16
|
+
Стало: m = φ([e; b; u]) — role prior только в инициализации b₀
|
|
17
|
+
"""
|
|
18
|
+
def __init__(self, d_hidden: int, num_classes: int, d_message: int, anonymous: bool = True):
|
|
19
|
+
super().__init__()
|
|
20
|
+
self.anonymous = anonymous
|
|
21
|
+
if anonymous:
|
|
22
|
+
in_dim = d_hidden + num_classes + 1 # [e; b; u] — без s
|
|
23
|
+
else:
|
|
24
|
+
in_dim = d_hidden + d_hidden + num_classes + 1 # [s; e; b; u]
|
|
25
|
+
self.proj = nn.Sequential(
|
|
26
|
+
nn.Linear(in_dim, d_message),
|
|
27
|
+
nn.LayerNorm(d_message),
|
|
28
|
+
nn.SiLU(),
|
|
29
|
+
nn.Linear(d_message, d_message)
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
def forward(self, s, e, b, u):
|
|
33
|
+
if self.anonymous:
|
|
34
|
+
concat = torch.cat([e, b, u], dim=-1)
|
|
35
|
+
else:
|
|
36
|
+
concat = torch.cat([s, e, b, u], dim=-1)
|
|
37
|
+
return self.proj(concat)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class AgentEvidenceExtractor(nn.Module):
|
|
41
|
+
"""
|
|
42
|
+
Computes evidence pointers over tokens, giving each agent a distinct focus.
|
|
43
|
+
p_i^t = softmax((H W_p)(W_q s_i^t)^T / sqrt(d))
|
|
44
|
+
e_i^t = sum_k p_{i,k}^t H_k
|
|
45
|
+
"""
|
|
46
|
+
def __init__(self, d_hidden: int):
|
|
47
|
+
super().__init__()
|
|
48
|
+
self.W_q = nn.Linear(d_hidden, d_hidden)
|
|
49
|
+
self.W_p = nn.Linear(d_hidden, d_hidden)
|
|
50
|
+
self.scale = d_hidden ** -0.5
|
|
51
|
+
|
|
52
|
+
def forward(self, s, H, mask=None):
|
|
53
|
+
K = self.W_p(H)
|
|
54
|
+
Q = self.W_q(s)
|
|
55
|
+
V = H
|
|
56
|
+
|
|
57
|
+
attn_mask = None
|
|
58
|
+
if mask is not None:
|
|
59
|
+
# mask: (B, L) -> (B, 1, L) для бродкаста на M агентов
|
|
60
|
+
attn_mask = mask.bool().unsqueeze(1)
|
|
61
|
+
|
|
62
|
+
# F.scaled_dot_product_attention автоматически использует FlashAttention на H100
|
|
63
|
+
# Масштабирование (scale) также применяется автоматически внутри.
|
|
64
|
+
e = F.scaled_dot_product_attention(
|
|
65
|
+
Q, K, V,
|
|
66
|
+
attn_mask=attn_mask,
|
|
67
|
+
dropout_p=0.0,
|
|
68
|
+
is_causal=False,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Возвращаем пустой тензор для p, так как FlashAttention не материализует матрицу внимания.
|
|
72
|
+
# Это экономит O(B*M*L) памяти.
|
|
73
|
+
p_dummy = torch.empty((Q.size(0), Q.size(1), K.size(1)), device=Q.device)
|
|
74
|
+
return p_dummy, e
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class CounterargumentCommunication(nn.Module):
|
|
78
|
+
"""
|
|
79
|
+
Agents collect messages from others, weighted by an attention mechanism
|
|
80
|
+
that explicitly values disagreement while punishing simple similarity.
|
|
81
|
+
|
|
82
|
+
v4: Disagreement-Routed Sparse Communication (Li et al. EMNLP 2024).
|
|
83
|
+
Каждый агент слушает только top-K по несогласию — противоположные
|
|
84
|
+
убеждения дают сильную связь, похожие — разрыв.
|
|
85
|
+
Это защищает minority voice и сокращает вычисления на 75%.
|
|
86
|
+
"""
|
|
87
|
+
def __init__(self, d_hidden: int, d_message: int, lambda_d: float = 1.0, lambda_s: float = 1.0):
|
|
88
|
+
super().__init__()
|
|
89
|
+
self.W_q = nn.Linear(d_hidden, d_hidden)
|
|
90
|
+
self.W_k = nn.Linear(d_message, d_hidden)
|
|
91
|
+
self.W_v = nn.Linear(d_message, d_hidden)
|
|
92
|
+
self.lambda_d = lambda_d
|
|
93
|
+
self.lambda_s = lambda_s
|
|
94
|
+
self.scale = d_hidden ** -0.5
|
|
95
|
+
|
|
96
|
+
def forward(self, s, m, b, e, sparse_k: int = 0):
|
|
97
|
+
"""
|
|
98
|
+
s: (B, M, d), m: (B, M, d_msg), b: (B, M, K), e: (B, M, d)
|
|
99
|
+
sparse_k: если > 0 — включает sparse communication (top-K по disagreement)
|
|
100
|
+
"""
|
|
101
|
+
Q = self.W_q(s)
|
|
102
|
+
K = self.W_k(m)
|
|
103
|
+
V = self.W_v(m)
|
|
104
|
+
|
|
105
|
+
attn = torch.bmm(Q, K.transpose(1, 2)) * self.scale
|
|
106
|
+
|
|
107
|
+
b_i = b.unsqueeze(2)
|
|
108
|
+
b_j = b.unsqueeze(1)
|
|
109
|
+
D = (b_i - b_j).abs().sum(dim=-1)
|
|
110
|
+
|
|
111
|
+
e_norm = F.normalize(e, p=2, dim=-1)
|
|
112
|
+
S = torch.bmm(e_norm, e_norm.transpose(1, 2))
|
|
113
|
+
|
|
114
|
+
scores = attn + self.lambda_d * D - self.lambda_s * S
|
|
115
|
+
|
|
116
|
+
M_agents = s.size(1)
|
|
117
|
+
# Само-маскировка (агент не слушает себя)
|
|
118
|
+
self_mask = torch.eye(M_agents, device=s.device, dtype=torch.bool).unsqueeze(0)
|
|
119
|
+
scores = scores.masked_fill(self_mask, float('-inf'))
|
|
120
|
+
|
|
121
|
+
# v4: Sparse communication — top-K по disagreement
|
|
122
|
+
if sparse_k > 0 and sparse_k < M_agents - 1:
|
|
123
|
+
# Находим top-K наиболее несогласных агентов для каждого
|
|
124
|
+
b_diff = D.clone() # (B, M, M)
|
|
125
|
+
b_diff = b_diff.masked_fill(self_mask, 0.0)
|
|
126
|
+
_, top_idx = b_diff.topk(min(sparse_k, M_agents - 1), dim=-1) # (B, M, k)
|
|
127
|
+
sparse_mask = torch.zeros(scores.shape, device=s.device, dtype=torch.bool)
|
|
128
|
+
sparse_mask.scatter_(-1, top_idx, True)
|
|
129
|
+
# Маскируем всех кроме top-K
|
|
130
|
+
scores = scores.masked_fill(~sparse_mask, float('-inf'))
|
|
131
|
+
|
|
132
|
+
a = F.softmax(scores, dim=-1)
|
|
133
|
+
r = torch.bmm(a, V)
|
|
134
|
+
return r
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# ─── v2: Transformer Agent Updater ───────────────────────────────────────────
|
|
138
|
+
|
|
139
|
+
class AgentUpdater(nn.Module):
|
|
140
|
+
"""
|
|
141
|
+
[v2] Обновляет состояние агентов через Transformer Decoder Step.
|
|
142
|
+
|
|
143
|
+
Замена GRUCell на cross-attention даёт агентам возможность
|
|
144
|
+
самостоятельно решать, НА ЧТО обращать внимание в сообщениях других,
|
|
145
|
+
а не слепо конкатенировать всё подряд.
|
|
146
|
+
|
|
147
|
+
Шаги:
|
|
148
|
+
1. Cross-Attention: каждый агент (query) смотрит на сигналы
|
|
149
|
+
counterargument (key/value) от других агентов.
|
|
150
|
+
2. Gated Evidence Fusion: ворота решают, сколько evidence добавить.
|
|
151
|
+
3. FFN + Residual + LayerNorm.
|
|
152
|
+
4. Dirichlet Belief Update: alpha = softplus(g(s)) + 1.
|
|
153
|
+
"""
|
|
154
|
+
|
|
155
|
+
def __init__(self, d_hidden: int, num_classes: int, num_heads: int = 4, multi_label: bool = False):
|
|
156
|
+
super().__init__()
|
|
157
|
+
self.multi_label = multi_label
|
|
158
|
+
# ── Step 1: Cross-attention over counterarguments ─────────────────────
|
|
159
|
+
self.cross_attn = nn.MultiheadAttention(
|
|
160
|
+
embed_dim=d_hidden,
|
|
161
|
+
num_heads=num_heads,
|
|
162
|
+
batch_first=True,
|
|
163
|
+
dropout=0.1,
|
|
164
|
+
)
|
|
165
|
+
self.norm1 = nn.LayerNorm(d_hidden)
|
|
166
|
+
|
|
167
|
+
# ── Step 2: Gated evidence fusion ─────────────────────────────────────
|
|
168
|
+
self.evidence_gate = nn.Sequential(
|
|
169
|
+
nn.Linear(d_hidden * 2, d_hidden),
|
|
170
|
+
nn.Sigmoid(),
|
|
171
|
+
)
|
|
172
|
+
self.evidence_proj = nn.Linear(d_hidden, d_hidden)
|
|
173
|
+
self.norm2 = nn.LayerNorm(d_hidden)
|
|
174
|
+
|
|
175
|
+
# ── Step 3: FFN ───────────────────────────────────────────────────────
|
|
176
|
+
self.ffn = nn.Sequential(
|
|
177
|
+
nn.Linear(d_hidden, d_hidden * 2),
|
|
178
|
+
nn.SiLU(),
|
|
179
|
+
nn.Linear(d_hidden * 2, d_hidden),
|
|
180
|
+
)
|
|
181
|
+
self.ffn_norm = nn.LayerNorm(d_hidden) # Replaced norm3 for clarity
|
|
182
|
+
|
|
183
|
+
# ── Step 4: Belief generator ──────────────────────────────────────────
|
|
184
|
+
self.g = nn.Sequential(
|
|
185
|
+
nn.Linear(d_hidden, d_hidden),
|
|
186
|
+
nn.SiLU(),
|
|
187
|
+
nn.Linear(d_hidden, num_classes),
|
|
188
|
+
)
|
|
189
|
+
self.K = float(num_classes)
|
|
190
|
+
|
|
191
|
+
def forward(self, state: AgentState, r: torch.Tensor, e: torch.Tensor) -> AgentState:
|
|
192
|
+
"""
|
|
193
|
+
state : AgentState — текущее состояние агентов
|
|
194
|
+
r : (B, M, d_hidden) — aggregated counterargument signals
|
|
195
|
+
e : (B, M, d_hidden) — evidence vectors из AgentEvidenceExtractor
|
|
196
|
+
"""
|
|
197
|
+
s = state.s # (B, M, d_hidden)
|
|
198
|
+
|
|
199
|
+
# ── Step 1: Cross-attention ───────────────────────────────────────────
|
|
200
|
+
# Query: текущие состояния агентов
|
|
201
|
+
# Key/Value: counterargument signals — что говорят другие агенты
|
|
202
|
+
# NOTE: если r все нули (начальный шаг), cross-attention = identity
|
|
203
|
+
attn_out, _ = self.cross_attn(query=s, key=r, value=r)
|
|
204
|
+
s = self.norm1(s + attn_out)
|
|
205
|
+
|
|
206
|
+
# ── Step 2: Gated evidence fusion ─────────────────────────────────────
|
|
207
|
+
# Ворота: насколько новый evidence должен изменить состояние?
|
|
208
|
+
gate = self.evidence_gate(torch.cat([s, e], dim=-1)) # (B, M, d)
|
|
209
|
+
e_proj = self.evidence_proj(e) # (B, M, d)
|
|
210
|
+
s = self.norm2(s + gate * e_proj)
|
|
211
|
+
|
|
212
|
+
# ── Step 3: FFN + residual ────────────────────────────────────────────
|
|
213
|
+
s = self.ffn_norm(s + self.ffn(s))
|
|
214
|
+
|
|
215
|
+
# ── Step 4: Belief update ─────────────────────────────────────────────
|
|
216
|
+
g_out = self.g(s)
|
|
217
|
+
|
|
218
|
+
if self.multi_label:
|
|
219
|
+
# Multi-label: Independent sigmoids
|
|
220
|
+
b_next = torch.sigmoid(g_out) # (B, M, K)
|
|
221
|
+
# v4: Разделение неопределённости
|
|
222
|
+
u_alea = (4.0 * b_next * (1.0 - b_next)).mean(dim=-1, keepdim=True) # шум данных
|
|
223
|
+
u_epi = u_alea * 0.5 # приближение для multi-label
|
|
224
|
+
u_next = u_alea + u_epi
|
|
225
|
+
alpha = b_next # Dummy for AgentState
|
|
226
|
+
else:
|
|
227
|
+
# Single-label: Dirichlet
|
|
228
|
+
alpha = F.softplus(g_out) + 1.0 # (B, M, K), всегда > 1
|
|
229
|
+
alpha_sum = alpha.sum(dim=-1, keepdim=True)
|
|
230
|
+
b_next = alpha / alpha_sum
|
|
231
|
+
# v4: Разделение неопределённости (Josang 2002 + Meinert 2024)
|
|
232
|
+
u_epi = self.K ** 2 / alpha_sum ** 2 # Vacuity: неполнота знаний
|
|
233
|
+
u_alea = 1.0 - b_next.max(dim=-1, keepdim=True).values # Шум данных
|
|
234
|
+
u_next = u_epi + u_alea # Совместимость с u (total)
|
|
235
|
+
|
|
236
|
+
return AgentState(
|
|
237
|
+
s=s,
|
|
238
|
+
b=b_next,
|
|
239
|
+
u=u_next,
|
|
240
|
+
alpha=alpha,
|
|
241
|
+
p=state.p,
|
|
242
|
+
e=e,
|
|
243
|
+
u_epi=u_epi,
|
|
244
|
+
u_alea=u_alea,
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# ─── v3: Structural Specialization ───────────────────────────────────────────
|
|
249
|
+
|
|
250
|
+
class PerspectiveProjector(nn.Module):
|
|
251
|
+
"""
|
|
252
|
+
[v3+] Каждый агент имеет свой собственный "взгляд" на входные данные.
|
|
253
|
+
Вместо клонирования h_cls, каждый агент i проецирует pooled_output
|
|
254
|
+
через свою собственную матрицу.
|
|
255
|
+
|
|
256
|
+
[v4] Добавлен orthogonality_loss(): штраф за коллинеарность матриц
|
|
257
|
+
проекций. Заставляет агентов смотреть на РАЗНЫЕ аспекты входа.
|
|
258
|
+
Без этого штрафа все M проекций коллапсируют к одной — агенты
|
|
259
|
+
становятся копиями друг друга и дебаты теряют смысл.
|
|
260
|
+
"""
|
|
261
|
+
|
|
262
|
+
def __init__(self, num_agents: int, d_input: int, d_hidden: int):
|
|
263
|
+
super().__init__()
|
|
264
|
+
self.num_agents = num_agents
|
|
265
|
+
# Каждый агент имеет свою проекцию
|
|
266
|
+
self.agent_projections = nn.ModuleList([
|
|
267
|
+
nn.Linear(d_input, d_hidden) for _ in range(num_agents)
|
|
268
|
+
])
|
|
269
|
+
# Опционально: позиционное смещение для севенциальных данных
|
|
270
|
+
# Предполагаем макс. длину 512 для инициализации
|
|
271
|
+
self.position_bias = nn.Parameter(
|
|
272
|
+
torch.randn(num_agents, 512) * 0.01
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
def forward(self, pooled: torch.Tensor) -> torch.Tensor:
|
|
276
|
+
"""
|
|
277
|
+
pooled: (B, d_input)
|
|
278
|
+
Returns: (B, M, d_hidden)
|
|
279
|
+
"""
|
|
280
|
+
out = [proj(pooled) for proj in self.agent_projections]
|
|
281
|
+
return torch.stack(out, dim=1)
|
|
282
|
+
|
|
283
|
+
def orthogonality_loss(self) -> torch.Tensor:
|
|
284
|
+
"""
|
|
285
|
+
[v4] Штраф за коллинеарность проекций агентов.
|
|
286
|
+
|
|
287
|
+
Gram(W) = W_norm @ W_norm^T, где W_norm — нормализованные строки
|
|
288
|
+
матриц весов каждого агента. Идеал: Gram = I (единичная матрица).
|
|
289
|
+
Loss = mean((Gram - I)^2) → 0 при полной ортогональности.
|
|
290
|
+
|
|
291
|
+
Математическое обоснование:
|
|
292
|
+
cos(θ_ij) = <w_i, w_j> / (||w_i|| ||w_j||)
|
|
293
|
+
При L_orth → 0: cos(θ_ij) → 0 для i≠j ⟹ θ_ij → 90°
|
|
294
|
+
Агенты гарантированно смотрят на линейно независимые подпространства.
|
|
295
|
+
"""
|
|
296
|
+
# Собираем матрицы весов: (M, d_hidden, d_input)
|
|
297
|
+
W = torch.stack([proj.weight for proj in self.agent_projections])
|
|
298
|
+
M = W.size(0)
|
|
299
|
+
|
|
300
|
+
# Flatten: (M, d_hidden * d_input)
|
|
301
|
+
W_flat = W.view(M, -1)
|
|
302
|
+
|
|
303
|
+
# L2-нормализация строк
|
|
304
|
+
W_norm = F.normalize(W_flat, p=2, dim=-1)
|
|
305
|
+
|
|
306
|
+
# Gram matrix: (M, M) — косинусные сходства между агентами
|
|
307
|
+
gram = W_norm @ W_norm.T
|
|
308
|
+
|
|
309
|
+
# Цель: единичная матрица (каждый агент ортогонален остальным)
|
|
310
|
+
eye = torch.eye(M, device=gram.device)
|
|
311
|
+
|
|
312
|
+
return ((gram - eye) ** 2).mean()
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class AdaptiveDSchedule(nn.Module):
|
|
316
|
+
"""
|
|
317
|
+
[v3+] Адаптивное расписание разногласий с PI-регулятором.
|
|
318
|
+
|
|
319
|
+
Исправление: current_disagreement теперь ИСПОЛЬЗУЕТСЯ для обратной связи.
|
|
320
|
+
Прежняя версия принимала параметр, но полностью его игнорировала.
|
|
321
|
+
|
|
322
|
+
Математическое обоснование:
|
|
323
|
+
d_target = base(h_cls) + Kp * e_t + Ki * Σe_k
|
|
324
|
+
где e_t = base - actual_disagreement — ошибка регулирования.
|
|
325
|
+
|
|
326
|
+
PI-регулятор гарантирует экспоненциальное стремление фактического
|
|
327
|
+
разногласия к целевому (доказано через анализ устойчивости замкнутой системы):
|
|
328
|
+
|λ| = |1 - η(1 - Kp)| < 1 при 0 < η < 2, 0 < Kp < 1.
|
|
329
|
+
|
|
330
|
+
Возвращает (d_target, error) — ошибку нужно накапливать в integral_acc
|
|
331
|
+
в вызывающем коде (core.py), а не внутри буфера модуля,
|
|
332
|
+
чтобы интегральный член был локальным для каждого forward-прохода.
|
|
333
|
+
"""
|
|
334
|
+
|
|
335
|
+
def __init__(self, d_hidden: int, Kp: float = 0.5, Ki: float = 0.1):
|
|
336
|
+
super().__init__()
|
|
337
|
+
self.complexity_estimator = nn.Sequential(
|
|
338
|
+
nn.Linear(d_hidden, 64),
|
|
339
|
+
nn.SiLU(),
|
|
340
|
+
nn.Linear(64, 1),
|
|
341
|
+
nn.Sigmoid(),
|
|
342
|
+
)
|
|
343
|
+
self.Kp = Kp # Пропорциональный коэффициент
|
|
344
|
+
self.Ki = Ki # Интегральный коэффициент
|
|
345
|
+
|
|
346
|
+
def forward(
|
|
347
|
+
self,
|
|
348
|
+
h_cls: torch.Tensor,
|
|
349
|
+
current_disagreement: torch.Tensor,
|
|
350
|
+
integral_acc: torch.Tensor = None,
|
|
351
|
+
):
|
|
352
|
+
"""
|
|
353
|
+
h_cls : (B, d_hidden) — начальное представление примера
|
|
354
|
+
current_disagreement: (B, 1, 1) или (B, 1) — фактическое разногласие
|
|
355
|
+
integral_acc : (B, 1) — накопленная ошибка из предыдущих раундов
|
|
356
|
+
|
|
357
|
+
Возвращает:
|
|
358
|
+
d_target : (B, 1) — скорректированная цель разногласия ∈ [0, 1.5]
|
|
359
|
+
error : (B, 1) — ошибка e_t для обновления integral_acc снаружи
|
|
360
|
+
"""
|
|
361
|
+
complexity = self.complexity_estimator(h_cls) # (B, 1)
|
|
362
|
+
base_target = complexity * 0.8 + (1 - complexity) * 0.1 # ∈ [0.1, 0.9]
|
|
363
|
+
|
|
364
|
+
# Приводим фактическое разногласие к форме (B, 1)
|
|
365
|
+
disag = current_disagreement.view(h_cls.size(0), 1)
|
|
366
|
+
|
|
367
|
+
# Пропорциональная ошибка (stop_grad — чистый управляющий сигнал)
|
|
368
|
+
error = (base_target - disag).detach()
|
|
369
|
+
|
|
370
|
+
# Интегральный член (накоплен вызывающим кодом за предыдущие раунды)
|
|
371
|
+
integral = integral_acc if integral_acc is not None else torch.zeros_like(error)
|
|
372
|
+
|
|
373
|
+
# PI-цель: base + proportional + integral
|
|
374
|
+
d_target = base_target + self.Kp * error + self.Ki * integral
|
|
375
|
+
|
|
376
|
+
return torch.clamp(d_target, 0.0, 1.5), error
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
class DebateDiagnostics:
|
|
4
|
+
"""
|
|
5
|
+
Validation mechanics to ensure explicit multi-agent deliberation avoids simple
|
|
6
|
+
over-confidence and actually provides semantic divergence over standard training.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
@staticmethod
|
|
10
|
+
def expected_calibration_error(p_final: torch.Tensor, y: torch.Tensor, n_bins: int = 10):
|
|
11
|
+
"""Computes Expected Calibration Error (ECE) for final probability vector."""
|
|
12
|
+
if y.dim() == 2: # Multi-label
|
|
13
|
+
# If p_final contains logits (values > 1 or < 0 possible), apply sigmoid
|
|
14
|
+
if p_final.max() > 1.0 or p_final.min() < 0.0:
|
|
15
|
+
probs = torch.sigmoid(p_final)
|
|
16
|
+
else:
|
|
17
|
+
probs = p_final
|
|
18
|
+
|
|
19
|
+
K = probs.size(1)
|
|
20
|
+
total_ece = 0.0
|
|
21
|
+
for k in range(K):
|
|
22
|
+
total_ece += DebateDiagnostics._binary_ece(probs[:, k], y[:, k], n_bins)
|
|
23
|
+
return total_ece / K
|
|
24
|
+
|
|
25
|
+
# Single-label (original logic)
|
|
26
|
+
confidences, predictions = torch.max(p_final, dim=1)
|
|
27
|
+
accuracies = (predictions == y)
|
|
28
|
+
return DebateDiagnostics._calculate_ece_from_bins(confidences, accuracies, n_bins, p_final.device)
|
|
29
|
+
|
|
30
|
+
@staticmethod
|
|
31
|
+
def _binary_ece(probs: torch.Tensor, targets: torch.Tensor, n_bins: int):
|
|
32
|
+
"""Binary ECE for a single class."""
|
|
33
|
+
accuracies = (probs > 0.5) == targets
|
|
34
|
+
return DebateDiagnostics._calculate_ece_from_bins(probs, accuracies, n_bins, probs.device)
|
|
35
|
+
|
|
36
|
+
@staticmethod
|
|
37
|
+
def _calculate_ece_from_bins(confidences: torch.Tensor, accuracies: torch.Tensor, n_bins: int, device: torch.device):
|
|
38
|
+
ece = torch.zeros(1, device=device)
|
|
39
|
+
bin_boundaries = torch.linspace(0, 1, n_bins + 1, device=device)
|
|
40
|
+
|
|
41
|
+
for bin_idx in range(n_bins):
|
|
42
|
+
in_bin = (confidences > bin_boundaries[bin_idx]) & (confidences <= bin_boundaries[bin_idx + 1])
|
|
43
|
+
prop_in_bin = in_bin.float().mean()
|
|
44
|
+
|
|
45
|
+
if prop_in_bin > 0:
|
|
46
|
+
accuracy_in_bin = accuracies[in_bin].float().mean()
|
|
47
|
+
avg_confidence_in_bin = confidences[in_bin].mean()
|
|
48
|
+
ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
|
|
49
|
+
|
|
50
|
+
return ece.item()
|
|
51
|
+
|
|
52
|
+
@staticmethod
|
|
53
|
+
def persuasion_matrix(b_all: list):
|
|
54
|
+
"""
|
|
55
|
+
Calculates how much agents changed their beliefs over time.
|
|
56
|
+
Allows us to track if persuasion is actually occurring.
|
|
57
|
+
Returns tensor of shape (Rounds, Agents) tracking average shifts.
|
|
58
|
+
"""
|
|
59
|
+
T = len(b_all)
|
|
60
|
+
if T < 2:
|
|
61
|
+
return None
|
|
62
|
+
|
|
63
|
+
shifts = []
|
|
64
|
+
for t in range(1, T):
|
|
65
|
+
shift = torch.norm(b_all[t] - b_all[t-1], p=1, dim=-1).mean(dim=0) # (M,)
|
|
66
|
+
shifts.append(shift)
|
|
67
|
+
|
|
68
|
+
return torch.stack(shifts)
|
|
69
|
+
|
|
70
|
+
@staticmethod
|
|
71
|
+
def belief_diversity_curve(b_all: list):
|
|
72
|
+
"""
|
|
73
|
+
Measures disagreement level over time. Useful to plot against d_t schedule.
|
|
74
|
+
"""
|
|
75
|
+
divergence = []
|
|
76
|
+
for b_t in b_all:
|
|
77
|
+
b_i = b_t.unsqueeze(2)
|
|
78
|
+
b_j = b_t.unsqueeze(1)
|
|
79
|
+
dist = torch.norm(b_i - b_j, p=1, dim=-1).mean()
|
|
80
|
+
divergence.append(dist.item())
|
|
81
|
+
return divergence
|
|
82
|
+
|
|
83
|
+
# ─── [v5] TTT Diagnostics ────────────────────────────────────────────────
|
|
84
|
+
|
|
85
|
+
@staticmethod
|
|
86
|
+
def ttt_adaptation_magnitude(ttt_info: dict) -> float:
|
|
87
|
+
"""
|
|
88
|
+
Насколько сильно TTT изменил веса агентов (L2 norm of param shift).
|
|
89
|
+
|
|
90
|
+
Полезно для мониторинга: слишком маленький сдвиг = TTT бесполезен,
|
|
91
|
+
слишком большой = риск catastrophic forgetting.
|
|
92
|
+
Хороший диапазон: 0.01 — 1.0.
|
|
93
|
+
"""
|
|
94
|
+
if ttt_info is None:
|
|
95
|
+
return 0.0
|
|
96
|
+
return ttt_info.get("param_shift", 0.0)
|
|
97
|
+
|
|
98
|
+
@staticmethod
|
|
99
|
+
def ttt_loss_curve(ttt_info: dict) -> list:
|
|
100
|
+
"""
|
|
101
|
+
Кривая self-supervised loss по шагам TTT.
|
|
102
|
+
|
|
103
|
+
Должна убывать: loss[0] > loss[1] > ... > loss[K-1].
|
|
104
|
+
Если не убывает — lr слишком большой или K слишком маленький.
|
|
105
|
+
"""
|
|
106
|
+
if ttt_info is None:
|
|
107
|
+
return []
|
|
108
|
+
return ttt_info.get("losses", [])
|
|
109
|
+
|
|
110
|
+
@staticmethod
|
|
111
|
+
def ttt_convergence_ratio(ttt_info: dict) -> float:
|
|
112
|
+
"""
|
|
113
|
+
Отношение финального loss к начальному: loss[-1] / loss[0].
|
|
114
|
+
|
|
115
|
+
< 0.5 = хорошая адаптация (loss уменьшился вдвое+).
|
|
116
|
+
≈ 1.0 = TTT не помогает, стоит увеличить K или lr.
|
|
117
|
+
> 1.0 = расходимость, уменьшить lr.
|
|
118
|
+
"""
|
|
119
|
+
if ttt_info is None:
|
|
120
|
+
return 1.0
|
|
121
|
+
losses = ttt_info.get("losses", [])
|
|
122
|
+
if len(losses) < 2:
|
|
123
|
+
return 1.0
|
|
124
|
+
return losses[-1] / max(losses[0], 1e-9)
|