cida-plugin 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,376 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .agent import AgentState
5
+
6
+
7
+ class MessageFormulator(nn.Module):
8
+ """
9
+ Creates a structured message out of the explicit agent components.
10
+
11
+ v4: Anonymous mode (Choi et al. ACL 2025).
12
+ Убирает s_i из сообщения → агент слышит аргумент, не авторитет.
13
+ Доказано устраняет 100% identity bias в debate.
14
+
15
+ Было: m = φ([s; e; b; u]) — агент знает источник → sycophancy/self-bias
16
+ Стало: m = φ([e; b; u]) — role prior только в инициализации b₀
17
+ """
18
+ def __init__(self, d_hidden: int, num_classes: int, d_message: int, anonymous: bool = True):
19
+ super().__init__()
20
+ self.anonymous = anonymous
21
+ if anonymous:
22
+ in_dim = d_hidden + num_classes + 1 # [e; b; u] — без s
23
+ else:
24
+ in_dim = d_hidden + d_hidden + num_classes + 1 # [s; e; b; u]
25
+ self.proj = nn.Sequential(
26
+ nn.Linear(in_dim, d_message),
27
+ nn.LayerNorm(d_message),
28
+ nn.SiLU(),
29
+ nn.Linear(d_message, d_message)
30
+ )
31
+
32
+ def forward(self, s, e, b, u):
33
+ if self.anonymous:
34
+ concat = torch.cat([e, b, u], dim=-1)
35
+ else:
36
+ concat = torch.cat([s, e, b, u], dim=-1)
37
+ return self.proj(concat)
38
+
39
+
40
+ class AgentEvidenceExtractor(nn.Module):
41
+ """
42
+ Computes evidence pointers over tokens, giving each agent a distinct focus.
43
+ p_i^t = softmax((H W_p)(W_q s_i^t)^T / sqrt(d))
44
+ e_i^t = sum_k p_{i,k}^t H_k
45
+ """
46
+ def __init__(self, d_hidden: int):
47
+ super().__init__()
48
+ self.W_q = nn.Linear(d_hidden, d_hidden)
49
+ self.W_p = nn.Linear(d_hidden, d_hidden)
50
+ self.scale = d_hidden ** -0.5
51
+
52
+ def forward(self, s, H, mask=None):
53
+ K = self.W_p(H)
54
+ Q = self.W_q(s)
55
+ V = H
56
+
57
+ attn_mask = None
58
+ if mask is not None:
59
+ # mask: (B, L) -> (B, 1, L) для бродкаста на M агентов
60
+ attn_mask = mask.bool().unsqueeze(1)
61
+
62
+ # F.scaled_dot_product_attention автоматически использует FlashAttention на H100
63
+ # Масштабирование (scale) также применяется автоматически внутри.
64
+ e = F.scaled_dot_product_attention(
65
+ Q, K, V,
66
+ attn_mask=attn_mask,
67
+ dropout_p=0.0,
68
+ is_causal=False,
69
+ )
70
+
71
+ # Возвращаем пустой тензор для p, так как FlashAttention не материализует матрицу внимания.
72
+ # Это экономит O(B*M*L) памяти.
73
+ p_dummy = torch.empty((Q.size(0), Q.size(1), K.size(1)), device=Q.device)
74
+ return p_dummy, e
75
+
76
+
77
+ class CounterargumentCommunication(nn.Module):
78
+ """
79
+ Agents collect messages from others, weighted by an attention mechanism
80
+ that explicitly values disagreement while punishing simple similarity.
81
+
82
+ v4: Disagreement-Routed Sparse Communication (Li et al. EMNLP 2024).
83
+ Каждый агент слушает только top-K по несогласию — противоположные
84
+ убеждения дают сильную связь, похожие — разрыв.
85
+ Это защищает minority voice и сокращает вычисления на 75%.
86
+ """
87
+ def __init__(self, d_hidden: int, d_message: int, lambda_d: float = 1.0, lambda_s: float = 1.0):
88
+ super().__init__()
89
+ self.W_q = nn.Linear(d_hidden, d_hidden)
90
+ self.W_k = nn.Linear(d_message, d_hidden)
91
+ self.W_v = nn.Linear(d_message, d_hidden)
92
+ self.lambda_d = lambda_d
93
+ self.lambda_s = lambda_s
94
+ self.scale = d_hidden ** -0.5
95
+
96
+ def forward(self, s, m, b, e, sparse_k: int = 0):
97
+ """
98
+ s: (B, M, d), m: (B, M, d_msg), b: (B, M, K), e: (B, M, d)
99
+ sparse_k: если > 0 — включает sparse communication (top-K по disagreement)
100
+ """
101
+ Q = self.W_q(s)
102
+ K = self.W_k(m)
103
+ V = self.W_v(m)
104
+
105
+ attn = torch.bmm(Q, K.transpose(1, 2)) * self.scale
106
+
107
+ b_i = b.unsqueeze(2)
108
+ b_j = b.unsqueeze(1)
109
+ D = (b_i - b_j).abs().sum(dim=-1)
110
+
111
+ e_norm = F.normalize(e, p=2, dim=-1)
112
+ S = torch.bmm(e_norm, e_norm.transpose(1, 2))
113
+
114
+ scores = attn + self.lambda_d * D - self.lambda_s * S
115
+
116
+ M_agents = s.size(1)
117
+ # Само-маскировка (агент не слушает себя)
118
+ self_mask = torch.eye(M_agents, device=s.device, dtype=torch.bool).unsqueeze(0)
119
+ scores = scores.masked_fill(self_mask, float('-inf'))
120
+
121
+ # v4: Sparse communication — top-K по disagreement
122
+ if sparse_k > 0 and sparse_k < M_agents - 1:
123
+ # Находим top-K наиболее несогласных агентов для каждого
124
+ b_diff = D.clone() # (B, M, M)
125
+ b_diff = b_diff.masked_fill(self_mask, 0.0)
126
+ _, top_idx = b_diff.topk(min(sparse_k, M_agents - 1), dim=-1) # (B, M, k)
127
+ sparse_mask = torch.zeros(scores.shape, device=s.device, dtype=torch.bool)
128
+ sparse_mask.scatter_(-1, top_idx, True)
129
+ # Маскируем всех кроме top-K
130
+ scores = scores.masked_fill(~sparse_mask, float('-inf'))
131
+
132
+ a = F.softmax(scores, dim=-1)
133
+ r = torch.bmm(a, V)
134
+ return r
135
+
136
+
137
+ # ─── v2: Transformer Agent Updater ───────────────────────────────────────────
138
+
139
+ class AgentUpdater(nn.Module):
140
+ """
141
+ [v2] Обновляет состояние агентов через Transformer Decoder Step.
142
+
143
+ Замена GRUCell на cross-attention даёт агентам возможность
144
+ самостоятельно решать, НА ЧТО обращать внимание в сообщениях других,
145
+ а не слепо конкатенировать всё подряд.
146
+
147
+ Шаги:
148
+ 1. Cross-Attention: каждый агент (query) смотрит на сигналы
149
+ counterargument (key/value) от других агентов.
150
+ 2. Gated Evidence Fusion: ворота решают, сколько evidence добавить.
151
+ 3. FFN + Residual + LayerNorm.
152
+ 4. Dirichlet Belief Update: alpha = softplus(g(s)) + 1.
153
+ """
154
+
155
+ def __init__(self, d_hidden: int, num_classes: int, num_heads: int = 4, multi_label: bool = False):
156
+ super().__init__()
157
+ self.multi_label = multi_label
158
+ # ── Step 1: Cross-attention over counterarguments ─────────────────────
159
+ self.cross_attn = nn.MultiheadAttention(
160
+ embed_dim=d_hidden,
161
+ num_heads=num_heads,
162
+ batch_first=True,
163
+ dropout=0.1,
164
+ )
165
+ self.norm1 = nn.LayerNorm(d_hidden)
166
+
167
+ # ── Step 2: Gated evidence fusion ─────────────────────────────────────
168
+ self.evidence_gate = nn.Sequential(
169
+ nn.Linear(d_hidden * 2, d_hidden),
170
+ nn.Sigmoid(),
171
+ )
172
+ self.evidence_proj = nn.Linear(d_hidden, d_hidden)
173
+ self.norm2 = nn.LayerNorm(d_hidden)
174
+
175
+ # ── Step 3: FFN ───────────────────────────────────────────────────────
176
+ self.ffn = nn.Sequential(
177
+ nn.Linear(d_hidden, d_hidden * 2),
178
+ nn.SiLU(),
179
+ nn.Linear(d_hidden * 2, d_hidden),
180
+ )
181
+ self.ffn_norm = nn.LayerNorm(d_hidden) # Replaced norm3 for clarity
182
+
183
+ # ── Step 4: Belief generator ──────────────────────────────────────────
184
+ self.g = nn.Sequential(
185
+ nn.Linear(d_hidden, d_hidden),
186
+ nn.SiLU(),
187
+ nn.Linear(d_hidden, num_classes),
188
+ )
189
+ self.K = float(num_classes)
190
+
191
+ def forward(self, state: AgentState, r: torch.Tensor, e: torch.Tensor) -> AgentState:
192
+ """
193
+ state : AgentState — текущее состояние агентов
194
+ r : (B, M, d_hidden) — aggregated counterargument signals
195
+ e : (B, M, d_hidden) — evidence vectors из AgentEvidenceExtractor
196
+ """
197
+ s = state.s # (B, M, d_hidden)
198
+
199
+ # ── Step 1: Cross-attention ───────────────────────────────────────────
200
+ # Query: текущие состояния агентов
201
+ # Key/Value: counterargument signals — что говорят другие агенты
202
+ # NOTE: если r все нули (начальный шаг), cross-attention = identity
203
+ attn_out, _ = self.cross_attn(query=s, key=r, value=r)
204
+ s = self.norm1(s + attn_out)
205
+
206
+ # ── Step 2: Gated evidence fusion ─────────────────────────────────────
207
+ # Ворота: насколько новый evidence должен изменить состояние?
208
+ gate = self.evidence_gate(torch.cat([s, e], dim=-1)) # (B, M, d)
209
+ e_proj = self.evidence_proj(e) # (B, M, d)
210
+ s = self.norm2(s + gate * e_proj)
211
+
212
+ # ── Step 3: FFN + residual ────────────────────────────────────────────
213
+ s = self.ffn_norm(s + self.ffn(s))
214
+
215
+ # ── Step 4: Belief update ─────────────────────────────────────────────
216
+ g_out = self.g(s)
217
+
218
+ if self.multi_label:
219
+ # Multi-label: Independent sigmoids
220
+ b_next = torch.sigmoid(g_out) # (B, M, K)
221
+ # v4: Разделение неопределённости
222
+ u_alea = (4.0 * b_next * (1.0 - b_next)).mean(dim=-1, keepdim=True) # шум данных
223
+ u_epi = u_alea * 0.5 # приближение для multi-label
224
+ u_next = u_alea + u_epi
225
+ alpha = b_next # Dummy for AgentState
226
+ else:
227
+ # Single-label: Dirichlet
228
+ alpha = F.softplus(g_out) + 1.0 # (B, M, K), всегда > 1
229
+ alpha_sum = alpha.sum(dim=-1, keepdim=True)
230
+ b_next = alpha / alpha_sum
231
+ # v4: Разделение неопределённости (Josang 2002 + Meinert 2024)
232
+ u_epi = self.K ** 2 / alpha_sum ** 2 # Vacuity: неполнота знаний
233
+ u_alea = 1.0 - b_next.max(dim=-1, keepdim=True).values # Шум данных
234
+ u_next = u_epi + u_alea # Совместимость с u (total)
235
+
236
+ return AgentState(
237
+ s=s,
238
+ b=b_next,
239
+ u=u_next,
240
+ alpha=alpha,
241
+ p=state.p,
242
+ e=e,
243
+ u_epi=u_epi,
244
+ u_alea=u_alea,
245
+ )
246
+
247
+
248
+ # ─── v3: Structural Specialization ───────────────────────────────────────────
249
+
250
+ class PerspectiveProjector(nn.Module):
251
+ """
252
+ [v3+] Каждый агент имеет свой собственный "взгляд" на входные данные.
253
+ Вместо клонирования h_cls, каждый агент i проецирует pooled_output
254
+ через свою собственную матрицу.
255
+
256
+ [v4] Добавлен orthogonality_loss(): штраф за коллинеарность матриц
257
+ проекций. Заставляет агентов смотреть на РАЗНЫЕ аспекты входа.
258
+ Без этого штрафа все M проекций коллапсируют к одной — агенты
259
+ становятся копиями друг друга и дебаты теряют смысл.
260
+ """
261
+
262
+ def __init__(self, num_agents: int, d_input: int, d_hidden: int):
263
+ super().__init__()
264
+ self.num_agents = num_agents
265
+ # Каждый агент имеет свою проекцию
266
+ self.agent_projections = nn.ModuleList([
267
+ nn.Linear(d_input, d_hidden) for _ in range(num_agents)
268
+ ])
269
+ # Опционально: позиционное смещение для севенциальных данных
270
+ # Предполагаем макс. длину 512 для инициализации
271
+ self.position_bias = nn.Parameter(
272
+ torch.randn(num_agents, 512) * 0.01
273
+ )
274
+
275
+ def forward(self, pooled: torch.Tensor) -> torch.Tensor:
276
+ """
277
+ pooled: (B, d_input)
278
+ Returns: (B, M, d_hidden)
279
+ """
280
+ out = [proj(pooled) for proj in self.agent_projections]
281
+ return torch.stack(out, dim=1)
282
+
283
+ def orthogonality_loss(self) -> torch.Tensor:
284
+ """
285
+ [v4] Штраф за коллинеарность проекций агентов.
286
+
287
+ Gram(W) = W_norm @ W_norm^T, где W_norm — нормализованные строки
288
+ матриц весов каждого агента. Идеал: Gram = I (единичная матрица).
289
+ Loss = mean((Gram - I)^2) → 0 при полной ортогональности.
290
+
291
+ Математическое обоснование:
292
+ cos(θ_ij) = <w_i, w_j> / (||w_i|| ||w_j||)
293
+ При L_orth → 0: cos(θ_ij) → 0 для i≠j ⟹ θ_ij → 90°
294
+ Агенты гарантированно смотрят на линейно независимые подпространства.
295
+ """
296
+ # Собираем матрицы весов: (M, d_hidden, d_input)
297
+ W = torch.stack([proj.weight for proj in self.agent_projections])
298
+ M = W.size(0)
299
+
300
+ # Flatten: (M, d_hidden * d_input)
301
+ W_flat = W.view(M, -1)
302
+
303
+ # L2-нормализация строк
304
+ W_norm = F.normalize(W_flat, p=2, dim=-1)
305
+
306
+ # Gram matrix: (M, M) — косинусные сходства между агентами
307
+ gram = W_norm @ W_norm.T
308
+
309
+ # Цель: единичная матрица (каждый агент ортогонален остальным)
310
+ eye = torch.eye(M, device=gram.device)
311
+
312
+ return ((gram - eye) ** 2).mean()
313
+
314
+
315
+ class AdaptiveDSchedule(nn.Module):
316
+ """
317
+ [v3+] Адаптивное расписание разногласий с PI-регулятором.
318
+
319
+ Исправление: current_disagreement теперь ИСПОЛЬЗУЕТСЯ для обратной связи.
320
+ Прежняя версия принимала параметр, но полностью его игнорировала.
321
+
322
+ Математическое обоснование:
323
+ d_target = base(h_cls) + Kp * e_t + Ki * Σe_k
324
+ где e_t = base - actual_disagreement — ошибка регулирования.
325
+
326
+ PI-регулятор гарантирует экспоненциальное стремление фактического
327
+ разногласия к целевому (доказано через анализ устойчивости замкнутой системы):
328
+ |λ| = |1 - η(1 - Kp)| < 1 при 0 < η < 2, 0 < Kp < 1.
329
+
330
+ Возвращает (d_target, error) — ошибку нужно накапливать в integral_acc
331
+ в вызывающем коде (core.py), а не внутри буфера модуля,
332
+ чтобы интегральный член был локальным для каждого forward-прохода.
333
+ """
334
+
335
+ def __init__(self, d_hidden: int, Kp: float = 0.5, Ki: float = 0.1):
336
+ super().__init__()
337
+ self.complexity_estimator = nn.Sequential(
338
+ nn.Linear(d_hidden, 64),
339
+ nn.SiLU(),
340
+ nn.Linear(64, 1),
341
+ nn.Sigmoid(),
342
+ )
343
+ self.Kp = Kp # Пропорциональный коэффициент
344
+ self.Ki = Ki # Интегральный коэффициент
345
+
346
+ def forward(
347
+ self,
348
+ h_cls: torch.Tensor,
349
+ current_disagreement: torch.Tensor,
350
+ integral_acc: torch.Tensor = None,
351
+ ):
352
+ """
353
+ h_cls : (B, d_hidden) — начальное представление примера
354
+ current_disagreement: (B, 1, 1) или (B, 1) — фактическое разногласие
355
+ integral_acc : (B, 1) — накопленная ошибка из предыдущих раундов
356
+
357
+ Возвращает:
358
+ d_target : (B, 1) — скорректированная цель разногласия ∈ [0, 1.5]
359
+ error : (B, 1) — ошибка e_t для обновления integral_acc снаружи
360
+ """
361
+ complexity = self.complexity_estimator(h_cls) # (B, 1)
362
+ base_target = complexity * 0.8 + (1 - complexity) * 0.1 # ∈ [0.1, 0.9]
363
+
364
+ # Приводим фактическое разногласие к форме (B, 1)
365
+ disag = current_disagreement.view(h_cls.size(0), 1)
366
+
367
+ # Пропорциональная ошибка (stop_grad — чистый управляющий сигнал)
368
+ error = (base_target - disag).detach()
369
+
370
+ # Интегральный член (накоплен вызывающим кодом за предыдущие раунды)
371
+ integral = integral_acc if integral_acc is not None else torch.zeros_like(error)
372
+
373
+ # PI-цель: base + proportional + integral
374
+ d_target = base_target + self.Kp * error + self.Ki * integral
375
+
376
+ return torch.clamp(d_target, 0.0, 1.5), error
@@ -0,0 +1,124 @@
1
+ import torch
2
+
3
+ class DebateDiagnostics:
4
+ """
5
+ Validation mechanics to ensure explicit multi-agent deliberation avoids simple
6
+ over-confidence and actually provides semantic divergence over standard training.
7
+ """
8
+
9
+ @staticmethod
10
+ def expected_calibration_error(p_final: torch.Tensor, y: torch.Tensor, n_bins: int = 10):
11
+ """Computes Expected Calibration Error (ECE) for final probability vector."""
12
+ if y.dim() == 2: # Multi-label
13
+ # If p_final contains logits (values > 1 or < 0 possible), apply sigmoid
14
+ if p_final.max() > 1.0 or p_final.min() < 0.0:
15
+ probs = torch.sigmoid(p_final)
16
+ else:
17
+ probs = p_final
18
+
19
+ K = probs.size(1)
20
+ total_ece = 0.0
21
+ for k in range(K):
22
+ total_ece += DebateDiagnostics._binary_ece(probs[:, k], y[:, k], n_bins)
23
+ return total_ece / K
24
+
25
+ # Single-label (original logic)
26
+ confidences, predictions = torch.max(p_final, dim=1)
27
+ accuracies = (predictions == y)
28
+ return DebateDiagnostics._calculate_ece_from_bins(confidences, accuracies, n_bins, p_final.device)
29
+
30
+ @staticmethod
31
+ def _binary_ece(probs: torch.Tensor, targets: torch.Tensor, n_bins: int):
32
+ """Binary ECE for a single class."""
33
+ accuracies = (probs > 0.5) == targets
34
+ return DebateDiagnostics._calculate_ece_from_bins(probs, accuracies, n_bins, probs.device)
35
+
36
+ @staticmethod
37
+ def _calculate_ece_from_bins(confidences: torch.Tensor, accuracies: torch.Tensor, n_bins: int, device: torch.device):
38
+ ece = torch.zeros(1, device=device)
39
+ bin_boundaries = torch.linspace(0, 1, n_bins + 1, device=device)
40
+
41
+ for bin_idx in range(n_bins):
42
+ in_bin = (confidences > bin_boundaries[bin_idx]) & (confidences <= bin_boundaries[bin_idx + 1])
43
+ prop_in_bin = in_bin.float().mean()
44
+
45
+ if prop_in_bin > 0:
46
+ accuracy_in_bin = accuracies[in_bin].float().mean()
47
+ avg_confidence_in_bin = confidences[in_bin].mean()
48
+ ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
49
+
50
+ return ece.item()
51
+
52
+ @staticmethod
53
+ def persuasion_matrix(b_all: list):
54
+ """
55
+ Calculates how much agents changed their beliefs over time.
56
+ Allows us to track if persuasion is actually occurring.
57
+ Returns tensor of shape (Rounds, Agents) tracking average shifts.
58
+ """
59
+ T = len(b_all)
60
+ if T < 2:
61
+ return None
62
+
63
+ shifts = []
64
+ for t in range(1, T):
65
+ shift = torch.norm(b_all[t] - b_all[t-1], p=1, dim=-1).mean(dim=0) # (M,)
66
+ shifts.append(shift)
67
+
68
+ return torch.stack(shifts)
69
+
70
+ @staticmethod
71
+ def belief_diversity_curve(b_all: list):
72
+ """
73
+ Measures disagreement level over time. Useful to plot against d_t schedule.
74
+ """
75
+ divergence = []
76
+ for b_t in b_all:
77
+ b_i = b_t.unsqueeze(2)
78
+ b_j = b_t.unsqueeze(1)
79
+ dist = torch.norm(b_i - b_j, p=1, dim=-1).mean()
80
+ divergence.append(dist.item())
81
+ return divergence
82
+
83
+ # ─── [v5] TTT Diagnostics ────────────────────────────────────────────────
84
+
85
+ @staticmethod
86
+ def ttt_adaptation_magnitude(ttt_info: dict) -> float:
87
+ """
88
+ Насколько сильно TTT изменил веса агентов (L2 norm of param shift).
89
+
90
+ Полезно для мониторинга: слишком маленький сдвиг = TTT бесполезен,
91
+ слишком большой = риск catastrophic forgetting.
92
+ Хороший диапазон: 0.01 — 1.0.
93
+ """
94
+ if ttt_info is None:
95
+ return 0.0
96
+ return ttt_info.get("param_shift", 0.0)
97
+
98
+ @staticmethod
99
+ def ttt_loss_curve(ttt_info: dict) -> list:
100
+ """
101
+ Кривая self-supervised loss по шагам TTT.
102
+
103
+ Должна убывать: loss[0] > loss[1] > ... > loss[K-1].
104
+ Если не убывает — lr слишком большой или K слишком маленький.
105
+ """
106
+ if ttt_info is None:
107
+ return []
108
+ return ttt_info.get("losses", [])
109
+
110
+ @staticmethod
111
+ def ttt_convergence_ratio(ttt_info: dict) -> float:
112
+ """
113
+ Отношение финального loss к начальному: loss[-1] / loss[0].
114
+
115
+ < 0.5 = хорошая адаптация (loss уменьшился вдвое+).
116
+ ≈ 1.0 = TTT не помогает, стоит увеличить K или lr.
117
+ > 1.0 = расходимость, уменьшить lr.
118
+ """
119
+ if ttt_info is None:
120
+ return 1.0
121
+ losses = ttt_info.get("losses", [])
122
+ if len(losses) < 2:
123
+ return 1.0
124
+ return losses[-1] / max(losses[0], 1e-9)