cida-plugin 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cida_plugin/__init__.py +62 -0
- cida_plugin/agent.py +198 -0
- cida_plugin/config.py +228 -0
- cida_plugin/consensus.py +178 -0
- cida_plugin/core.py +451 -0
- cida_plugin/deliberation.py +376 -0
- cida_plugin/diagnostics.py +124 -0
- cida_plugin/hf.py +194 -0
- cida_plugin/liquid_dynamics.py +291 -0
- cida_plugin/losses.py +174 -0
- cida_plugin/translator.py +74 -0
- cida_plugin/ttt.py +348 -0
- cida_plugin/vision_backbone.py +43 -0
- cida_plugin-1.0.0.dist-info/METADATA +167 -0
- cida_plugin-1.0.0.dist-info/RECORD +18 -0
- cida_plugin-1.0.0.dist-info/WHEEL +5 -0
- cida_plugin-1.0.0.dist-info/licenses/LICENSE +201 -0
- cida_plugin-1.0.0.dist-info/top_level.txt +1 -0
cida_plugin/__init__.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CIDA-Plugin: Universal Evidence-Grounded Multi-Agent Deliberation Layer (v3).
|
|
3
|
+
|
|
4
|
+
Использование (минимальное):
|
|
5
|
+
from cida_plugin import CIDAPlugin, CIDAPluginConfig
|
|
6
|
+
|
|
7
|
+
cfg = CIDAPluginConfig(d_input=768, num_classes=2)
|
|
8
|
+
plugin = CIDAPlugin(cfg)
|
|
9
|
+
out = plugin(pooled_output) # pooled_output: (B, 768)
|
|
10
|
+
logits = out["p_final"] # (B, 2)
|
|
11
|
+
|
|
12
|
+
Использование (с seq_output для evidence pointers):
|
|
13
|
+
out = plugin(pooled_output, seq_output=hidden_states, mask=attention_mask)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from .config import CIDAPluginConfig
|
|
17
|
+
from .core import CIDAPlugin
|
|
18
|
+
from .agent import (
|
|
19
|
+
AgentState,
|
|
20
|
+
apply_role_prior,
|
|
21
|
+
compute_role_orthogonality_loss,
|
|
22
|
+
RoleNames,
|
|
23
|
+
)
|
|
24
|
+
from .deliberation import (
|
|
25
|
+
AgentEvidenceExtractor,
|
|
26
|
+
MessageFormulator,
|
|
27
|
+
CounterargumentCommunication,
|
|
28
|
+
AgentUpdater,
|
|
29
|
+
)
|
|
30
|
+
from .consensus import ConsensusAggregator, HaltingPredictor
|
|
31
|
+
from .losses import CIDALoss
|
|
32
|
+
from .diagnostics import DebateDiagnostics
|
|
33
|
+
from .ttt import TestTimeTrainer
|
|
34
|
+
from .liquid_dynamics import LiquidDeliberationSolver
|
|
35
|
+
from .hf import wrap_hf_model, HFModelWithCIDA
|
|
36
|
+
|
|
37
|
+
__all__ = [
|
|
38
|
+
# Главный интерфейс
|
|
39
|
+
"CIDAPlugin",
|
|
40
|
+
"CIDAPluginConfig",
|
|
41
|
+
"wrap_hf_model",
|
|
42
|
+
"HFModelWithCIDA",
|
|
43
|
+
# Потери и диагностика
|
|
44
|
+
"CIDALoss",
|
|
45
|
+
"DebateDiagnostics",
|
|
46
|
+
# [v5] Test-Time Training
|
|
47
|
+
"TestTimeTrainer",
|
|
48
|
+
# [v5] Liquid Neural ODE Dynamics
|
|
49
|
+
"LiquidDeliberationSolver",
|
|
50
|
+
# Компоненты (для кастомных архитектур)
|
|
51
|
+
"AgentState",
|
|
52
|
+
"apply_role_prior",
|
|
53
|
+
"compute_role_orthogonality_loss",
|
|
54
|
+
"RoleNames",
|
|
55
|
+
"AgentEvidenceExtractor",
|
|
56
|
+
"MessageFormulator",
|
|
57
|
+
"CounterargumentCommunication",
|
|
58
|
+
"AgentUpdater",
|
|
59
|
+
"ConsensusAggregator",
|
|
60
|
+
"HaltingPredictor",
|
|
61
|
+
]
|
|
62
|
+
|
cida_plugin/agent.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""
|
|
2
|
+
agent.py — CIDA-Plugin Agent (упрощённая версия)
|
|
3
|
+
|
|
4
|
+
Ключевое изменение:
|
|
5
|
+
БЫЛО: RoleEmbeddings (аддитивный bias) + debate_loss + role_spec_loss + orth_loss
|
|
6
|
+
СТАЛО: ROLE_PRIORS — структурные априорные убеждения агентов
|
|
7
|
+
|
|
8
|
+
Почему это лучше:
|
|
9
|
+
RoleEmbeddings создаёт СТАТИСТИЧЕСКОЕ различие (агенты начинают чуть
|
|
10
|
+
по-разному и потом схожи). Role priors создают СТРУКТУРНОЕ различие:
|
|
11
|
+
|
|
12
|
+
Прокурор → P(y=1) = 0.85 до просмотра данных
|
|
13
|
+
Защитник → P(y=1) = 0.15 до просмотра данных
|
|
14
|
+
Скептик → P(y=1) = 0.50 (максимальная неопределённость)
|
|
15
|
+
Интегратор → P(y=1) = данные (без prior, объективен)
|
|
16
|
+
|
|
17
|
+
Это байесовски обоснованная альтернатива: агенты представляют разные
|
|
18
|
+
prior beliefs, а консенсус — это posterior после объединения.
|
|
19
|
+
|
|
20
|
+
Математически гарантировано:
|
|
21
|
+
|b_0 - b_1| ≥ BLEND * |P_0 - P_1| = 0.3 * 0.7 = 0.21 > 0
|
|
22
|
+
То есть несогласие является СВОЙСТВОМ архитектуры, а не результатом loss.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.nn as nn
|
|
27
|
+
import torch.nn.functional as F
|
|
28
|
+
from dataclasses import dataclass
|
|
29
|
+
from typing import Optional
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# ── Роли агентов: структурные prior beliefs ───────────────────────────────────
|
|
33
|
+
#
|
|
34
|
+
# Для binary/multi-label: prior = вероятность positive класса
|
|
35
|
+
# Для multi-class: prior = вектор вероятностей классов
|
|
36
|
+
#
|
|
37
|
+
# BLEND = насколько сильно prior смешивается с данными
|
|
38
|
+
# 0.0 → prior игнорируется (все агенты одинаковы)
|
|
39
|
+
# 0.3 → 30% prior, 70% данные (рекомендуется)
|
|
40
|
+
# 1.0 → агент игнорирует данные (только prior)
|
|
41
|
+
|
|
42
|
+
ROLE_PRIOR_POSITIVE = [
|
|
43
|
+
0.85, # Прокурор: склонен считать что патология есть
|
|
44
|
+
0.15, # Защитник: склонен считать что патологии нет
|
|
45
|
+
0.50, # Скептик: максимальная неопределённость
|
|
46
|
+
None, # Интегратор: без prior, смотрит объективно
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
ROLE_PRIOR_BLEND = 0.30 # Сила prior относительно данных
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
|
|
53
|
+
class AgentState:
|
|
54
|
+
"""Явное представление состояния агента."""
|
|
55
|
+
s: torch.Tensor # (B, d) - Hidden state
|
|
56
|
+
b: torch.Tensor # (B, K) - Belief (posterior после prior blending)
|
|
57
|
+
u: torch.Tensor # (B, 1) - Total uncertainty (backward compat)
|
|
58
|
+
p: torch.Tensor # (B, n) - Evidence pointer
|
|
59
|
+
e: torch.Tensor # (B, d) - Evidence vector
|
|
60
|
+
alpha: torch.Tensor # (B, K) - Dirichlet parameters
|
|
61
|
+
# v4: Разделение неопределённости (Josang 2002 + Meinert 2024)
|
|
62
|
+
u_epi: torch.Tensor = None # (B, 1) - Epistemic: K²/alpha_sum² (vacuity)
|
|
63
|
+
u_alea: torch.Tensor = None # (B, 1) - Aleatoric: 1 - max(b) (data noise)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def apply_role_prior(
|
|
67
|
+
b_data: torch.Tensor,
|
|
68
|
+
role_idx: int,
|
|
69
|
+
multi_label: bool = False,
|
|
70
|
+
positive_class_idx: int = 1,
|
|
71
|
+
) -> torch.Tensor:
|
|
72
|
+
"""
|
|
73
|
+
Смешивает data-driven belief с role-specific prior (обобщённая версия для любого K).
|
|
74
|
+
"""
|
|
75
|
+
role = role_idx % 4
|
|
76
|
+
if role == 3: # Интегратор: без prior
|
|
77
|
+
return b_data
|
|
78
|
+
|
|
79
|
+
K = b_data.size(-1)
|
|
80
|
+
device = b_data.device
|
|
81
|
+
dtype = b_data.dtype
|
|
82
|
+
|
|
83
|
+
if multi_label or K == 1:
|
|
84
|
+
prior_val = ROLE_PRIOR_POSITIVE[role]
|
|
85
|
+
prior = torch.full_like(b_data, prior_val)
|
|
86
|
+
elif K == 2:
|
|
87
|
+
prior_val = ROLE_PRIOR_POSITIVE[role]
|
|
88
|
+
prior = torch.tensor(
|
|
89
|
+
[1.0 - prior_val, prior_val],
|
|
90
|
+
device=device,
|
|
91
|
+
dtype=dtype,
|
|
92
|
+
).expand_as(b_data)
|
|
93
|
+
else:
|
|
94
|
+
# Multi-class
|
|
95
|
+
if role == 0: # Prosecutor: biased to positive class
|
|
96
|
+
prior_val = 0.85
|
|
97
|
+
uniform_val = (1.0 - prior_val) / max(K - 1, 1)
|
|
98
|
+
prior = torch.full_like(b_data, uniform_val)
|
|
99
|
+
prior = prior.clone()
|
|
100
|
+
prior[:, positive_class_idx] = prior_val
|
|
101
|
+
elif role == 1: # Defender: biased against positive class (pro-alternative classes)
|
|
102
|
+
prior_val = 0.15
|
|
103
|
+
uniform_val = (1.0 - prior_val) / max(K - 1, 1)
|
|
104
|
+
prior = torch.full_like(b_data, uniform_val)
|
|
105
|
+
prior = prior.clone()
|
|
106
|
+
prior[:, positive_class_idx] = prior_val
|
|
107
|
+
elif role == 2: # Skeptic: neutral / high uncertainty (uniform)
|
|
108
|
+
prior = torch.full_like(b_data, 1.0 / K)
|
|
109
|
+
|
|
110
|
+
prior = prior / prior.sum(dim=-1, keepdim=True)
|
|
111
|
+
|
|
112
|
+
return (1.0 - ROLE_PRIOR_BLEND) * b_data + ROLE_PRIOR_BLEND * prior
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def apply_role_priors_batched(
|
|
116
|
+
b_data: torch.Tensor,
|
|
117
|
+
M: int,
|
|
118
|
+
multi_label: bool = False,
|
|
119
|
+
positive_class_idx: int = 1,
|
|
120
|
+
) -> torch.Tensor:
|
|
121
|
+
"""
|
|
122
|
+
Векторизованная версия apply_role_prior для батча агентов сразу (B, M, K).
|
|
123
|
+
"""
|
|
124
|
+
B, _, K = b_data.shape
|
|
125
|
+
device = b_data.device
|
|
126
|
+
dtype = b_data.dtype
|
|
127
|
+
|
|
128
|
+
priors = torch.zeros((M, K), device=device, dtype=dtype)
|
|
129
|
+
blend_mask = torch.full((M, 1), ROLE_PRIOR_BLEND, device=device, dtype=dtype)
|
|
130
|
+
|
|
131
|
+
for i in range(M):
|
|
132
|
+
role = i % 4
|
|
133
|
+
if role == 3: # Интегратор: без prior
|
|
134
|
+
blend_mask[i, 0] = 0.0
|
|
135
|
+
continue
|
|
136
|
+
|
|
137
|
+
if multi_label or K == 1:
|
|
138
|
+
priors[i, :] = ROLE_PRIOR_POSITIVE[role]
|
|
139
|
+
elif K == 2:
|
|
140
|
+
priors[i, 0] = 1.0 - ROLE_PRIOR_POSITIVE[role]
|
|
141
|
+
priors[i, 1] = ROLE_PRIOR_POSITIVE[role]
|
|
142
|
+
else:
|
|
143
|
+
# Multi-class
|
|
144
|
+
if role == 0: # Prosecutor: biased to positive class
|
|
145
|
+
prior_val = 0.85
|
|
146
|
+
uniform_val = (1.0 - prior_val) / max(K - 1, 1)
|
|
147
|
+
priors[i, :] = uniform_val
|
|
148
|
+
priors[i, positive_class_idx] = prior_val
|
|
149
|
+
elif role == 1: # Defender: biased against positive class
|
|
150
|
+
prior_val = 0.15
|
|
151
|
+
uniform_val = (1.0 - prior_val) / max(K - 1, 1)
|
|
152
|
+
priors[i, :] = uniform_val
|
|
153
|
+
priors[i, positive_class_idx] = prior_val
|
|
154
|
+
elif role == 2: # Skeptic: neutral / uniform
|
|
155
|
+
priors[i, :] = 1.0 / K
|
|
156
|
+
|
|
157
|
+
priors[i] = priors[i] / priors[i].sum()
|
|
158
|
+
|
|
159
|
+
priors = priors.unsqueeze(0) # (1, M, K)
|
|
160
|
+
blend_mask = blend_mask.unsqueeze(0) # (1, M, 1)
|
|
161
|
+
|
|
162
|
+
return (1.0 - blend_mask) * b_data + blend_mask * priors
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def compute_role_orthogonality_loss(agent_states: torch.Tensor) -> torch.Tensor:
|
|
167
|
+
"""
|
|
168
|
+
Мягкая ортогональность через ПРЕДСТАВЛЕНИЯ, а не через веса.
|
|
169
|
+
|
|
170
|
+
БЫЛО: orth_loss через Gram(weights) — веса могут быть ортогональны,
|
|
171
|
+
но representations при этом коллапсируют (если вход низкоранговый).
|
|
172
|
+
|
|
173
|
+
СТАЛО: orth_loss через Gram(representations) — напрямую измеряем
|
|
174
|
+
то, что нас реально интересует.
|
|
175
|
+
|
|
176
|
+
agent_states: (B, M, d) — hidden states агентов
|
|
177
|
+
Returns: scalar — потери (0 при полной ортогональности)
|
|
178
|
+
|
|
179
|
+
Математически:
|
|
180
|
+
G_ij = <s_i, s_j> / (||s_i|| · ||s_j||) — косинусное сходство
|
|
181
|
+
Loss = mean((G - I)²) → 0 при G = I
|
|
182
|
+
"""
|
|
183
|
+
B, M, d = agent_states.shape
|
|
184
|
+
# Усредняем по батчу, нормализуем
|
|
185
|
+
s_mean = agent_states.mean(0) # (M, d)
|
|
186
|
+
s_norm = F.normalize(s_mean, p=2, dim=-1) # (M, d)
|
|
187
|
+
gram = s_norm @ s_norm.T # (M, M)
|
|
188
|
+
eye = torch.eye(M, device=gram.device)
|
|
189
|
+
return ((gram - eye) ** 2).mean()
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class RoleNames:
|
|
193
|
+
"""Имена агентов для диагностики."""
|
|
194
|
+
NAMES = ["Prosecutor", "Defender", "Skeptic", "Integrator"]
|
|
195
|
+
|
|
196
|
+
@classmethod
|
|
197
|
+
def get(cls, idx: int) -> str:
|
|
198
|
+
return cls.NAMES[idx % 4]
|
cida_plugin/config.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CIDAPluginConfig — единая точка конфигурации для CIDA-Plugin v3.
|
|
3
|
+
"""
|
|
4
|
+
from dataclasses import dataclass, asdict
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class CIDAPluginConfig:
|
|
11
|
+
"""
|
|
12
|
+
Конфигурация универсального CIDA-Plugin слоя (v3 — упрощённая архитектура).
|
|
13
|
+
|
|
14
|
+
Параметры ядра
|
|
15
|
+
--------------
|
|
16
|
+
d_input : int
|
|
17
|
+
Размерность pooled_output от внешнего энкодера.
|
|
18
|
+
Примеры: 128 (bert-tiny), 768 (bert-base / distilbert), 1024 (DenseNet121).
|
|
19
|
+
d_hidden : int
|
|
20
|
+
Внутренняя размерность агентов плагина.
|
|
21
|
+
d_message : int
|
|
22
|
+
Размерность сообщений при коммуникации агентов.
|
|
23
|
+
num_agents : int
|
|
24
|
+
Количество агентов (рекомендуется 4: Prosecutor, Defender, Skeptic, Integrator).
|
|
25
|
+
num_classes : int
|
|
26
|
+
Число целевых классов.
|
|
27
|
+
max_rounds : int
|
|
28
|
+
Максимальное число раундов deliberation.
|
|
29
|
+
multi_label : bool
|
|
30
|
+
True для multi-label задач (BCE), False для single-label (CE).
|
|
31
|
+
|
|
32
|
+
Архитектурные параметры
|
|
33
|
+
-----------------------
|
|
34
|
+
num_attn_heads : int
|
|
35
|
+
Число голов в TransformerAgentUpdater (cross-attention).
|
|
36
|
+
Должен делить d_hidden.
|
|
37
|
+
|
|
38
|
+
Role Priors
|
|
39
|
+
-----------
|
|
40
|
+
role_prior_blend : float
|
|
41
|
+
Сила смешивания роли с данными (0.0 = без priors, 1.0 = только priors).
|
|
42
|
+
Рекомендуется 0.30 (30% prior, 70% данные).
|
|
43
|
+
positive_class_idx : int
|
|
44
|
+
Индекс "положительного" класса для multi-class задач (K>2).
|
|
45
|
+
По умолчанию 1, но может быть 0, 2 и т.д. в зависимости от датасета.
|
|
46
|
+
|
|
47
|
+
Loss гиперпараметры
|
|
48
|
+
-------------------
|
|
49
|
+
lambda_cal : float
|
|
50
|
+
Вес calibration loss (Brier score). Рекомендуется 0.3–0.5.
|
|
51
|
+
lambda_ac : float
|
|
52
|
+
Вес anti-collapse loss. Рекомендуется 0.1–0.3.
|
|
53
|
+
min_disagreement : float
|
|
54
|
+
Минимальный порог несогласия для anti-collapse.
|
|
55
|
+
|
|
56
|
+
Регуляризация
|
|
57
|
+
-------------
|
|
58
|
+
comm_dropout : float
|
|
59
|
+
Dropout на канале коммуникации между агентами.
|
|
60
|
+
early_stop_threshold : float or None
|
|
61
|
+
Порог уверенности для досрочной остановки deliberation (inference only).
|
|
62
|
+
Если None — всегда идёт max_rounds раундов.
|
|
63
|
+
freeze_input_proj : bool
|
|
64
|
+
Если True — входной проекционный слой не обучается.
|
|
65
|
+
|
|
66
|
+
Ablation Flags
|
|
67
|
+
--------------
|
|
68
|
+
abl_no_pointers : bool
|
|
69
|
+
Отключить evidence pointers.
|
|
70
|
+
abl_no_messages : bool
|
|
71
|
+
Отключить формулировку сообщений.
|
|
72
|
+
abl_no_communication : bool
|
|
73
|
+
Полностью отключить коммуникацию.
|
|
74
|
+
|
|
75
|
+
Perspective Projector
|
|
76
|
+
---------------------
|
|
77
|
+
use_perspective_projector : bool
|
|
78
|
+
Если True — каждый агент имеет свою проекцию входа (per-agent view).
|
|
79
|
+
Если False — все агенты начинают с одного представления (expand).
|
|
80
|
+
По умолчанию False для упрощённой архитектуры v3.
|
|
81
|
+
|
|
82
|
+
Test-Time Training (TTT)
|
|
83
|
+
-------------------------
|
|
84
|
+
use_ttt : bool
|
|
85
|
+
Агенты адаптируют свои веса к каждому входу перед deliberation.
|
|
86
|
+
ttt_steps : int
|
|
87
|
+
Число шагов внутренней оптимизации (K).
|
|
88
|
+
ttt_lr : float
|
|
89
|
+
Learning rate для inner Adam.
|
|
90
|
+
ttt_mask_ratio : float
|
|
91
|
+
Доля маскируемых компонент скрытого состояния.
|
|
92
|
+
|
|
93
|
+
Liquid Neural ODE Dynamics
|
|
94
|
+
--------------------------
|
|
95
|
+
use_liquid_dynamics : bool
|
|
96
|
+
Заменяет дискретные раунды на непрерывное ODE: ds/dt = -s/τ + F(s,r,e).
|
|
97
|
+
liquid_solver : str
|
|
98
|
+
Метод интегрирования: 'euler', 'dopri5' и др.
|
|
99
|
+
liquid_atol : float
|
|
100
|
+
Абсолютная толерантность для адаптивных решателей.
|
|
101
|
+
liquid_rtol : float
|
|
102
|
+
Относительная толерантность для адаптивных решателей.
|
|
103
|
+
trajectory_save_every : int
|
|
104
|
+
Сохранять каждое N-е состояние в trajectory для ODE решателя.
|
|
105
|
+
Для адаптивных решателей (dopri5) рекомендуется 10-50 для экономии памяти.
|
|
106
|
+
Для euler можно использовать 1.
|
|
107
|
+
"""
|
|
108
|
+
# ─── Размерности ────────────────────────────────────────────────────────────
|
|
109
|
+
d_input: int = 128
|
|
110
|
+
d_hidden: int = 128
|
|
111
|
+
d_message: int = 128
|
|
112
|
+
|
|
113
|
+
# ─── Агенты и deliberation ───────────────────────────────────────────────────
|
|
114
|
+
num_agents: int = 4
|
|
115
|
+
num_classes: int = 2
|
|
116
|
+
max_rounds: int = 3
|
|
117
|
+
multi_label: bool = False
|
|
118
|
+
|
|
119
|
+
# ─── Архитектурные параметры ─────────────────────────────────────────────────
|
|
120
|
+
num_attn_heads: int = 4
|
|
121
|
+
|
|
122
|
+
# ─── Role Priors ─────────────────────────────────────────────────────────────
|
|
123
|
+
role_prior_blend: float = 0.30
|
|
124
|
+
positive_class_idx: int = 1
|
|
125
|
+
|
|
126
|
+
# ─── Loss гиперпараметры ─────────────────────────────────────────────────────
|
|
127
|
+
lambda_cal: float = 0.4
|
|
128
|
+
lambda_ac: float = 0.2
|
|
129
|
+
min_disagreement: float = 0.08
|
|
130
|
+
|
|
131
|
+
# ─── Регуляризация ──────────────────────────────────────────────────────────
|
|
132
|
+
comm_dropout: float = 0.2
|
|
133
|
+
early_stop_threshold: float = 0.90
|
|
134
|
+
freeze_input_proj: bool = False
|
|
135
|
+
|
|
136
|
+
# ─── Ablation flags ─────────────────────────────────────────────────────────
|
|
137
|
+
abl_no_pointers: bool = False
|
|
138
|
+
abl_no_messages: bool = False
|
|
139
|
+
abl_no_communication: bool = False
|
|
140
|
+
|
|
141
|
+
# ─── Perspective Projector ─────────────────────────────────────────────────
|
|
142
|
+
use_perspective_projector: bool = True
|
|
143
|
+
|
|
144
|
+
# ─── Test-Time Training ──────────────────────────────────────────────────────
|
|
145
|
+
use_ttt: bool = False
|
|
146
|
+
ttt_steps: int = 3
|
|
147
|
+
ttt_lr: float = 1e-3
|
|
148
|
+
ttt_mask_ratio: float = 0.15
|
|
149
|
+
|
|
150
|
+
# ─── Liquid Neural ODE Dynamics ──────────────────────────────────────────────
|
|
151
|
+
use_liquid_dynamics: bool = False
|
|
152
|
+
liquid_solver: str = "euler"
|
|
153
|
+
liquid_atol: float = 1e-3
|
|
154
|
+
liquid_rtol: float = 1e-3
|
|
155
|
+
trajectory_save_every: int = 10
|
|
156
|
+
|
|
157
|
+
# ─── CIDA v4: Architectural Improvements ─────────────────────────────────────
|
|
158
|
+
# Disagreement-Routed Sparse Communication (Li et al. EMNLP 2024)
|
|
159
|
+
# Каждый агент слушает только top-K по disagreement, не всех.
|
|
160
|
+
sparse_comm_k: int = 2
|
|
161
|
+
|
|
162
|
+
# Anonymous Message Passing (Choi et al. ACL 2025)
|
|
163
|
+
# Убирает s (identity) из сообщений → агент слышит аргумент, а не авторитет.
|
|
164
|
+
anonymous_messages: bool = True
|
|
165
|
+
|
|
166
|
+
# Dynamic Trust Weights (CortexDebate MDM, ACL 2025)
|
|
167
|
+
# EMA trust вместо статических ROLE_WEIGHTS.
|
|
168
|
+
use_dynamic_trust: bool = False
|
|
169
|
+
trust_ema_gamma: float = 0.9
|
|
170
|
+
|
|
171
|
+
# ─── Kairos Dynamic Deliberation (KDD) ───────────────────────────────────────
|
|
172
|
+
# Time-Travel (Rollback) mechanics
|
|
173
|
+
enable_rollback: bool = False
|
|
174
|
+
rollback_noise_std: float = 0.05
|
|
175
|
+
|
|
176
|
+
# Latent Superposition
|
|
177
|
+
# If final disagreement is > threshold, output Superposition State
|
|
178
|
+
superposition_threshold: float = 0.7
|
|
179
|
+
|
|
180
|
+
def __post_init__(self):
|
|
181
|
+
assert self.d_input > 0, "d_input must be positive"
|
|
182
|
+
assert self.num_agents >= 2, "Need at least 2 agents"
|
|
183
|
+
assert self.num_classes >= 2, "Need at least 2 classes"
|
|
184
|
+
assert self.max_rounds >= 1, "Need at least 1 deliberation round"
|
|
185
|
+
assert self.d_hidden % self.num_attn_heads == 0, (
|
|
186
|
+
f"d_hidden ({self.d_hidden}) must be divisible by "
|
|
187
|
+
f"num_attn_heads ({self.num_attn_heads})"
|
|
188
|
+
)
|
|
189
|
+
assert 0.0 <= self.role_prior_blend <= 1.0, (
|
|
190
|
+
"role_prior_blend must be in [0.0, 1.0]"
|
|
191
|
+
)
|
|
192
|
+
if self.early_stop_threshold is not None:
|
|
193
|
+
assert 0.5 < self.early_stop_threshold <= 1.0, (
|
|
194
|
+
"early_stop_threshold must be in (0.5, 1.0]"
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
def to_dict(self):
|
|
198
|
+
return asdict(self)
|
|
199
|
+
|
|
200
|
+
@classmethod
|
|
201
|
+
def from_dict(cls, config_dict):
|
|
202
|
+
# Filter out unknown keys for forward compatibility
|
|
203
|
+
valid_keys = {f.name for f in cls.__dataclass_fields__.values()}
|
|
204
|
+
filtered = {k: v for k, v in config_dict.items() if k in valid_keys}
|
|
205
|
+
return cls(**filtered)
|
|
206
|
+
|
|
207
|
+
def save_pretrained(self, save_directory: str):
|
|
208
|
+
os.makedirs(save_directory, exist_ok=True)
|
|
209
|
+
config_file = os.path.join(save_directory, "config.json")
|
|
210
|
+
with open(config_file, "w", encoding="utf-8") as f:
|
|
211
|
+
json.dump(self.to_dict(), f, indent=2, sort_keys=True)
|
|
212
|
+
|
|
213
|
+
@classmethod
|
|
214
|
+
def from_pretrained(cls, pretrained_model_name_or_path: str):
|
|
215
|
+
if os.path.isdir(pretrained_model_name_or_path):
|
|
216
|
+
config_file = os.path.join(
|
|
217
|
+
pretrained_model_name_or_path, "config.json"
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
from huggingface_hub import hf_hub_download
|
|
221
|
+
config_file = hf_hub_download(
|
|
222
|
+
repo_id=pretrained_model_name_or_path,
|
|
223
|
+
filename="config.json",
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
with open(config_file, "r", encoding="utf-8") as f:
|
|
227
|
+
config_dict = json.load(f)
|
|
228
|
+
return cls.from_dict(config_dict)
|
cida_plugin/consensus.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""
|
|
2
|
+
consensus.py — Упрощённый ConsensusAggregator
|
|
3
|
+
|
|
4
|
+
Ключевые исправления:
|
|
5
|
+
|
|
6
|
+
1. УБРАН ReliabilityTracker
|
|
7
|
+
Причина: круговая зависимость rho ← accuracy ← predictions ← weights(rho).
|
|
8
|
+
На малых датасетах вырождается. Без теоретической гарантии сходимости.
|
|
9
|
+
Замена: статические веса ролей (обоснованные a priori).
|
|
10
|
+
|
|
11
|
+
2. ИСПРАВЛЕН ConsensusAggregator
|
|
12
|
+
БЫЛО: PoE = log p ∝ Σ w_i · log(b_i)
|
|
13
|
+
ПРОБЛЕМА: PoE предполагает НЕЗАВИСИМЫХ экспертов.
|
|
14
|
+
Агенты CIDA КОРРЕЛИРОВАНЫ — общий энкодер + communication между агентами.
|
|
15
|
+
PoE на коррелированных агентах → усиление shared bias → overconfidence.
|
|
16
|
+
|
|
17
|
+
СТАЛО: Weighted Mean + Disagreement-as-Uncertainty
|
|
18
|
+
p_final = Σ w_i · b_i (взвешенное среднее)
|
|
19
|
+
uncertainty = std(b) по агентам
|
|
20
|
+
|
|
21
|
+
Математическое обоснование:
|
|
22
|
+
Для коррелированных экспертов с ковариационной матрицей Σ:
|
|
23
|
+
Var(Σ w_i · b_i) = w^T · Σ · w (правильно учитывает корреляции)
|
|
24
|
+
Var(PoE(b_i)) ≠ w^T · Σ · w (PoE неверно предполагает Σ = σ²I)
|
|
25
|
+
При положительной корреляции: PoE под-оценивает variance → overconfidence.
|
|
26
|
+
|
|
27
|
+
3. УПРОЩЁН интерфейс
|
|
28
|
+
Убраны: epoch_fraction, temperature scheduling, normalize_components флаг.
|
|
29
|
+
Всё это компенсировало нестабильность PoE на коррелированных агентах.
|
|
30
|
+
При правильной агрегации это не нужно.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import torch
|
|
34
|
+
import torch.nn as nn
|
|
35
|
+
import torch.nn.functional as F
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Статические веса ролей (a priori обоснованы ролями)
|
|
39
|
+
# Интегратор = наибольший вес (его роль — синтезировать)
|
|
40
|
+
# Скептик = наименьший вес (его роль — поднимать вопросы, не решать)
|
|
41
|
+
ROLE_WEIGHTS = torch.tensor([0.28, 0.28, 0.12, 0.32]) # P, D, S, I (сумма = 1.0)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ConsensusAggregator(nn.Module):
|
|
45
|
+
"""
|
|
46
|
+
Weighted Mean консенсус для КОРРЕЛИРОВАННЫХ агентов.
|
|
47
|
+
|
|
48
|
+
v4: Dynamic Trust Weights (CortexDebate MDM, ACL 2025).
|
|
49
|
+
Статические веса заменяются на EMA trust — вес агента растёт
|
|
50
|
+
когда его предсказания совпадают с финальным консенсусом.
|
|
51
|
+
|
|
52
|
+
trust_i += γ·(1[b_i_correct] - trust_i) per batch, detached from main graph
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
num_agents: int = 4,
|
|
58
|
+
multi_label: bool = False,
|
|
59
|
+
use_dynamic_trust: bool = False,
|
|
60
|
+
trust_ema_gamma: float = 0.9,
|
|
61
|
+
):
|
|
62
|
+
super().__init__()
|
|
63
|
+
self.multi_label = multi_label
|
|
64
|
+
self.use_dynamic_trust = use_dynamic_trust
|
|
65
|
+
self.trust_ema_gamma = trust_ema_gamma
|
|
66
|
+
|
|
67
|
+
# Статические веса ролей — обучаемые для адаптации к задаче
|
|
68
|
+
init_weights = ROLE_WEIGHTS[:num_agents].clone()
|
|
69
|
+
init_weights = init_weights / init_weights.sum()
|
|
70
|
+
self.role_weights = nn.Parameter(init_weights)
|
|
71
|
+
|
|
72
|
+
# v4: Dynamic trust (EMA buffer, не участвует в градиентах)
|
|
73
|
+
if use_dynamic_trust:
|
|
74
|
+
self.register_buffer(
|
|
75
|
+
'trust_scores',
|
|
76
|
+
torch.ones(num_agents) / num_agents
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
@property
|
|
80
|
+
def weights(self) -> torch.Tensor:
|
|
81
|
+
"""Нормализованные веса ролей (в simplex)."""
|
|
82
|
+
if self.use_dynamic_trust and hasattr(self, 'trust_scores'):
|
|
83
|
+
return self.trust_scores
|
|
84
|
+
return F.softmax(self.role_weights, dim=0) # (M,)
|
|
85
|
+
|
|
86
|
+
def update_trust(self, b: torch.Tensor, p_final: torch.Tensor):
|
|
87
|
+
"""
|
|
88
|
+
Обновляет trust scores на основе согласованности агента с консенсусом.
|
|
89
|
+
Вызывается ПОСЛЕ forward(), detached от основного графа.
|
|
90
|
+
|
|
91
|
+
b: (B, M, K) — убеждения агентов
|
|
92
|
+
p_final: (B, K) — финальный консенсус
|
|
93
|
+
"""
|
|
94
|
+
if not self.use_dynamic_trust:
|
|
95
|
+
return
|
|
96
|
+
with torch.no_grad():
|
|
97
|
+
# Per-agent correctness = 1 - L1(b_i, p_final)
|
|
98
|
+
correctness = 1.0 - (b - p_final.unsqueeze(1)).abs().mean(dim=(0, 2)) # (M,)
|
|
99
|
+
gamma = self.trust_ema_gamma
|
|
100
|
+
self.trust_scores = gamma * self.trust_scores + (1 - gamma) * correctness
|
|
101
|
+
self.trust_scores = self.trust_scores / self.trust_scores.sum()
|
|
102
|
+
|
|
103
|
+
def forward(
|
|
104
|
+
self,
|
|
105
|
+
b: torch.Tensor, # (B, M, K)
|
|
106
|
+
u: torch.Tensor = None, # (B, M, 1) — опционально
|
|
107
|
+
u_epi: torch.Tensor = None, # (B, M, 1) — v4: epistemic
|
|
108
|
+
u_alea: torch.Tensor = None, # (B, M, 1) — v4: aleatoric
|
|
109
|
+
) -> tuple:
|
|
110
|
+
"""
|
|
111
|
+
Returns:
|
|
112
|
+
p_final: (B, K) — консенсус
|
|
113
|
+
uncertainty: (B,) — мера неопределённости
|
|
114
|
+
disagreement:(B,) — мера несогласия агентов
|
|
115
|
+
"""
|
|
116
|
+
B, M, K = b.shape
|
|
117
|
+
w = self.weights # (M,) — нормализованные role weights
|
|
118
|
+
|
|
119
|
+
# ── Взвешенное среднее (корректно для коррелированных агентов) ────────
|
|
120
|
+
p_final = (w.view(1, M, 1) * b).sum(dim=1) # (B, K)
|
|
121
|
+
|
|
122
|
+
if not self.multi_label:
|
|
123
|
+
# FIX v4: Нормализация через log-softmax — численно стабильна
|
|
124
|
+
p_final = F.softmax(p_final.log().clamp(-10, 10), dim=-1)
|
|
125
|
+
|
|
126
|
+
# ── Несогласие = std убеждений по агентам (со стабильным градиентом) ──
|
|
127
|
+
var = b.var(dim=1, unbiased=True)
|
|
128
|
+
disagreement = torch.sqrt(var + 1e-8).mean(dim=-1) # (B,)
|
|
129
|
+
|
|
130
|
+
# ── Неопределённость = несогласие + собственная uncertainty агентов ──
|
|
131
|
+
if u is not None:
|
|
132
|
+
agent_u = (w.view(1, M, 1) * u).sum(dim=1).squeeze(-1) # (B,)
|
|
133
|
+
uncertainty = 0.5 * disagreement + 0.5 * agent_u
|
|
134
|
+
else:
|
|
135
|
+
uncertainty = disagreement
|
|
136
|
+
|
|
137
|
+
# v4: Обновляем dynamic trust (если включён)
|
|
138
|
+
if self.training and self.use_dynamic_trust:
|
|
139
|
+
self.update_trust(b, p_final)
|
|
140
|
+
|
|
141
|
+
return p_final, uncertainty, disagreement
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class HaltingPredictor(nn.Module):
|
|
145
|
+
"""
|
|
146
|
+
Предсказатель остановки (упрощённый).
|
|
147
|
+
Используется для ACT: когда агенты достаточно согласны — останавливаемся.
|
|
148
|
+
|
|
149
|
+
В отличие от оригинала, решение основано на НЕСОГЛАСИИ (наблюдаемо),
|
|
150
|
+
а не на выученном MLP поверх всех состояний (трудно интерпретировать).
|
|
151
|
+
"""
|
|
152
|
+
|
|
153
|
+
def __init__(self, agreement_threshold: float = 0.15):
|
|
154
|
+
"""
|
|
155
|
+
agreement_threshold: при несогласии ниже этого порога — останавливаемся.
|
|
156
|
+
"""
|
|
157
|
+
super().__init__()
|
|
158
|
+
self.threshold = agreement_threshold
|
|
159
|
+
# Небольшой выученный bias для подстройки порога под задачу
|
|
160
|
+
self.bias = nn.Parameter(torch.tensor(0.0))
|
|
161
|
+
|
|
162
|
+
def should_halt(self, disagreement: torch.Tensor) -> torch.Tensor:
|
|
163
|
+
"""
|
|
164
|
+
disagreement: (B,)
|
|
165
|
+
Returns: (B,) bool — True если стоит остановиться
|
|
166
|
+
"""
|
|
167
|
+
effective_threshold = self.threshold + self.bias.sigmoid() * 0.1
|
|
168
|
+
return disagreement < effective_threshold
|
|
169
|
+
|
|
170
|
+
def halting_probability(self, disagreement: torch.Tensor) -> torch.Tensor:
|
|
171
|
+
"""
|
|
172
|
+
Мягкая версия: вероятность остановки ∈ [0, 1].
|
|
173
|
+
disagreement: (B,)
|
|
174
|
+
Returns: (B,)
|
|
175
|
+
"""
|
|
176
|
+
effective_threshold = self.threshold + self.bias.sigmoid() * 0.1
|
|
177
|
+
# Sigmoid(-scale * (d - threshold)): высокая вероятность при d < threshold
|
|
178
|
+
return torch.sigmoid(-10.0 * (disagreement - effective_threshold))
|