cida-plugin 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,62 @@
1
+ """
2
+ CIDA-Plugin: Universal Evidence-Grounded Multi-Agent Deliberation Layer (v3).
3
+
4
+ Использование (минимальное):
5
+ from cida_plugin import CIDAPlugin, CIDAPluginConfig
6
+
7
+ cfg = CIDAPluginConfig(d_input=768, num_classes=2)
8
+ plugin = CIDAPlugin(cfg)
9
+ out = plugin(pooled_output) # pooled_output: (B, 768)
10
+ logits = out["p_final"] # (B, 2)
11
+
12
+ Использование (с seq_output для evidence pointers):
13
+ out = plugin(pooled_output, seq_output=hidden_states, mask=attention_mask)
14
+ """
15
+
16
+ from .config import CIDAPluginConfig
17
+ from .core import CIDAPlugin
18
+ from .agent import (
19
+ AgentState,
20
+ apply_role_prior,
21
+ compute_role_orthogonality_loss,
22
+ RoleNames,
23
+ )
24
+ from .deliberation import (
25
+ AgentEvidenceExtractor,
26
+ MessageFormulator,
27
+ CounterargumentCommunication,
28
+ AgentUpdater,
29
+ )
30
+ from .consensus import ConsensusAggregator, HaltingPredictor
31
+ from .losses import CIDALoss
32
+ from .diagnostics import DebateDiagnostics
33
+ from .ttt import TestTimeTrainer
34
+ from .liquid_dynamics import LiquidDeliberationSolver
35
+ from .hf import wrap_hf_model, HFModelWithCIDA
36
+
37
+ __all__ = [
38
+ # Главный интерфейс
39
+ "CIDAPlugin",
40
+ "CIDAPluginConfig",
41
+ "wrap_hf_model",
42
+ "HFModelWithCIDA",
43
+ # Потери и диагностика
44
+ "CIDALoss",
45
+ "DebateDiagnostics",
46
+ # [v5] Test-Time Training
47
+ "TestTimeTrainer",
48
+ # [v5] Liquid Neural ODE Dynamics
49
+ "LiquidDeliberationSolver",
50
+ # Компоненты (для кастомных архитектур)
51
+ "AgentState",
52
+ "apply_role_prior",
53
+ "compute_role_orthogonality_loss",
54
+ "RoleNames",
55
+ "AgentEvidenceExtractor",
56
+ "MessageFormulator",
57
+ "CounterargumentCommunication",
58
+ "AgentUpdater",
59
+ "ConsensusAggregator",
60
+ "HaltingPredictor",
61
+ ]
62
+
cida_plugin/agent.py ADDED
@@ -0,0 +1,198 @@
1
+ """
2
+ agent.py — CIDA-Plugin Agent (упрощённая версия)
3
+
4
+ Ключевое изменение:
5
+ БЫЛО: RoleEmbeddings (аддитивный bias) + debate_loss + role_spec_loss + orth_loss
6
+ СТАЛО: ROLE_PRIORS — структурные априорные убеждения агентов
7
+
8
+ Почему это лучше:
9
+ RoleEmbeddings создаёт СТАТИСТИЧЕСКОЕ различие (агенты начинают чуть
10
+ по-разному и потом схожи). Role priors создают СТРУКТУРНОЕ различие:
11
+
12
+ Прокурор → P(y=1) = 0.85 до просмотра данных
13
+ Защитник → P(y=1) = 0.15 до просмотра данных
14
+ Скептик → P(y=1) = 0.50 (максимальная неопределённость)
15
+ Интегратор → P(y=1) = данные (без prior, объективен)
16
+
17
+ Это байесовски обоснованная альтернатива: агенты представляют разные
18
+ prior beliefs, а консенсус — это posterior после объединения.
19
+
20
+ Математически гарантировано:
21
+ |b_0 - b_1| ≥ BLEND * |P_0 - P_1| = 0.3 * 0.7 = 0.21 > 0
22
+ То есть несогласие является СВОЙСТВОМ архитектуры, а не результатом loss.
23
+ """
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from dataclasses import dataclass
29
+ from typing import Optional
30
+
31
+
32
+ # ── Роли агентов: структурные prior beliefs ───────────────────────────────────
33
+ #
34
+ # Для binary/multi-label: prior = вероятность positive класса
35
+ # Для multi-class: prior = вектор вероятностей классов
36
+ #
37
+ # BLEND = насколько сильно prior смешивается с данными
38
+ # 0.0 → prior игнорируется (все агенты одинаковы)
39
+ # 0.3 → 30% prior, 70% данные (рекомендуется)
40
+ # 1.0 → агент игнорирует данные (только prior)
41
+
42
+ ROLE_PRIOR_POSITIVE = [
43
+ 0.85, # Прокурор: склонен считать что патология есть
44
+ 0.15, # Защитник: склонен считать что патологии нет
45
+ 0.50, # Скептик: максимальная неопределённость
46
+ None, # Интегратор: без prior, смотрит объективно
47
+ ]
48
+
49
+ ROLE_PRIOR_BLEND = 0.30 # Сила prior относительно данных
50
+
51
+
52
+ @dataclass
53
+ class AgentState:
54
+ """Явное представление состояния агента."""
55
+ s: torch.Tensor # (B, d) - Hidden state
56
+ b: torch.Tensor # (B, K) - Belief (posterior после prior blending)
57
+ u: torch.Tensor # (B, 1) - Total uncertainty (backward compat)
58
+ p: torch.Tensor # (B, n) - Evidence pointer
59
+ e: torch.Tensor # (B, d) - Evidence vector
60
+ alpha: torch.Tensor # (B, K) - Dirichlet parameters
61
+ # v4: Разделение неопределённости (Josang 2002 + Meinert 2024)
62
+ u_epi: torch.Tensor = None # (B, 1) - Epistemic: K²/alpha_sum² (vacuity)
63
+ u_alea: torch.Tensor = None # (B, 1) - Aleatoric: 1 - max(b) (data noise)
64
+
65
+
66
+ def apply_role_prior(
67
+ b_data: torch.Tensor,
68
+ role_idx: int,
69
+ multi_label: bool = False,
70
+ positive_class_idx: int = 1,
71
+ ) -> torch.Tensor:
72
+ """
73
+ Смешивает data-driven belief с role-specific prior (обобщённая версия для любого K).
74
+ """
75
+ role = role_idx % 4
76
+ if role == 3: # Интегратор: без prior
77
+ return b_data
78
+
79
+ K = b_data.size(-1)
80
+ device = b_data.device
81
+ dtype = b_data.dtype
82
+
83
+ if multi_label or K == 1:
84
+ prior_val = ROLE_PRIOR_POSITIVE[role]
85
+ prior = torch.full_like(b_data, prior_val)
86
+ elif K == 2:
87
+ prior_val = ROLE_PRIOR_POSITIVE[role]
88
+ prior = torch.tensor(
89
+ [1.0 - prior_val, prior_val],
90
+ device=device,
91
+ dtype=dtype,
92
+ ).expand_as(b_data)
93
+ else:
94
+ # Multi-class
95
+ if role == 0: # Prosecutor: biased to positive class
96
+ prior_val = 0.85
97
+ uniform_val = (1.0 - prior_val) / max(K - 1, 1)
98
+ prior = torch.full_like(b_data, uniform_val)
99
+ prior = prior.clone()
100
+ prior[:, positive_class_idx] = prior_val
101
+ elif role == 1: # Defender: biased against positive class (pro-alternative classes)
102
+ prior_val = 0.15
103
+ uniform_val = (1.0 - prior_val) / max(K - 1, 1)
104
+ prior = torch.full_like(b_data, uniform_val)
105
+ prior = prior.clone()
106
+ prior[:, positive_class_idx] = prior_val
107
+ elif role == 2: # Skeptic: neutral / high uncertainty (uniform)
108
+ prior = torch.full_like(b_data, 1.0 / K)
109
+
110
+ prior = prior / prior.sum(dim=-1, keepdim=True)
111
+
112
+ return (1.0 - ROLE_PRIOR_BLEND) * b_data + ROLE_PRIOR_BLEND * prior
113
+
114
+
115
+ def apply_role_priors_batched(
116
+ b_data: torch.Tensor,
117
+ M: int,
118
+ multi_label: bool = False,
119
+ positive_class_idx: int = 1,
120
+ ) -> torch.Tensor:
121
+ """
122
+ Векторизованная версия apply_role_prior для батча агентов сразу (B, M, K).
123
+ """
124
+ B, _, K = b_data.shape
125
+ device = b_data.device
126
+ dtype = b_data.dtype
127
+
128
+ priors = torch.zeros((M, K), device=device, dtype=dtype)
129
+ blend_mask = torch.full((M, 1), ROLE_PRIOR_BLEND, device=device, dtype=dtype)
130
+
131
+ for i in range(M):
132
+ role = i % 4
133
+ if role == 3: # Интегратор: без prior
134
+ blend_mask[i, 0] = 0.0
135
+ continue
136
+
137
+ if multi_label or K == 1:
138
+ priors[i, :] = ROLE_PRIOR_POSITIVE[role]
139
+ elif K == 2:
140
+ priors[i, 0] = 1.0 - ROLE_PRIOR_POSITIVE[role]
141
+ priors[i, 1] = ROLE_PRIOR_POSITIVE[role]
142
+ else:
143
+ # Multi-class
144
+ if role == 0: # Prosecutor: biased to positive class
145
+ prior_val = 0.85
146
+ uniform_val = (1.0 - prior_val) / max(K - 1, 1)
147
+ priors[i, :] = uniform_val
148
+ priors[i, positive_class_idx] = prior_val
149
+ elif role == 1: # Defender: biased against positive class
150
+ prior_val = 0.15
151
+ uniform_val = (1.0 - prior_val) / max(K - 1, 1)
152
+ priors[i, :] = uniform_val
153
+ priors[i, positive_class_idx] = prior_val
154
+ elif role == 2: # Skeptic: neutral / uniform
155
+ priors[i, :] = 1.0 / K
156
+
157
+ priors[i] = priors[i] / priors[i].sum()
158
+
159
+ priors = priors.unsqueeze(0) # (1, M, K)
160
+ blend_mask = blend_mask.unsqueeze(0) # (1, M, 1)
161
+
162
+ return (1.0 - blend_mask) * b_data + blend_mask * priors
163
+
164
+
165
+
166
+ def compute_role_orthogonality_loss(agent_states: torch.Tensor) -> torch.Tensor:
167
+ """
168
+ Мягкая ортогональность через ПРЕДСТАВЛЕНИЯ, а не через веса.
169
+
170
+ БЫЛО: orth_loss через Gram(weights) — веса могут быть ортогональны,
171
+ но representations при этом коллапсируют (если вход низкоранговый).
172
+
173
+ СТАЛО: orth_loss через Gram(representations) — напрямую измеряем
174
+ то, что нас реально интересует.
175
+
176
+ agent_states: (B, M, d) — hidden states агентов
177
+ Returns: scalar — потери (0 при полной ортогональности)
178
+
179
+ Математически:
180
+ G_ij = <s_i, s_j> / (||s_i|| · ||s_j||) — косинусное сходство
181
+ Loss = mean((G - I)²) → 0 при G = I
182
+ """
183
+ B, M, d = agent_states.shape
184
+ # Усредняем по батчу, нормализуем
185
+ s_mean = agent_states.mean(0) # (M, d)
186
+ s_norm = F.normalize(s_mean, p=2, dim=-1) # (M, d)
187
+ gram = s_norm @ s_norm.T # (M, M)
188
+ eye = torch.eye(M, device=gram.device)
189
+ return ((gram - eye) ** 2).mean()
190
+
191
+
192
+ class RoleNames:
193
+ """Имена агентов для диагностики."""
194
+ NAMES = ["Prosecutor", "Defender", "Skeptic", "Integrator"]
195
+
196
+ @classmethod
197
+ def get(cls, idx: int) -> str:
198
+ return cls.NAMES[idx % 4]
cida_plugin/config.py ADDED
@@ -0,0 +1,228 @@
1
+ """
2
+ CIDAPluginConfig — единая точка конфигурации для CIDA-Plugin v3.
3
+ """
4
+ from dataclasses import dataclass, asdict
5
+ import json
6
+ import os
7
+
8
+
9
+ @dataclass
10
+ class CIDAPluginConfig:
11
+ """
12
+ Конфигурация универсального CIDA-Plugin слоя (v3 — упрощённая архитектура).
13
+
14
+ Параметры ядра
15
+ --------------
16
+ d_input : int
17
+ Размерность pooled_output от внешнего энкодера.
18
+ Примеры: 128 (bert-tiny), 768 (bert-base / distilbert), 1024 (DenseNet121).
19
+ d_hidden : int
20
+ Внутренняя размерность агентов плагина.
21
+ d_message : int
22
+ Размерность сообщений при коммуникации агентов.
23
+ num_agents : int
24
+ Количество агентов (рекомендуется 4: Prosecutor, Defender, Skeptic, Integrator).
25
+ num_classes : int
26
+ Число целевых классов.
27
+ max_rounds : int
28
+ Максимальное число раундов deliberation.
29
+ multi_label : bool
30
+ True для multi-label задач (BCE), False для single-label (CE).
31
+
32
+ Архитектурные параметры
33
+ -----------------------
34
+ num_attn_heads : int
35
+ Число голов в TransformerAgentUpdater (cross-attention).
36
+ Должен делить d_hidden.
37
+
38
+ Role Priors
39
+ -----------
40
+ role_prior_blend : float
41
+ Сила смешивания роли с данными (0.0 = без priors, 1.0 = только priors).
42
+ Рекомендуется 0.30 (30% prior, 70% данные).
43
+ positive_class_idx : int
44
+ Индекс "положительного" класса для multi-class задач (K>2).
45
+ По умолчанию 1, но может быть 0, 2 и т.д. в зависимости от датасета.
46
+
47
+ Loss гиперпараметры
48
+ -------------------
49
+ lambda_cal : float
50
+ Вес calibration loss (Brier score). Рекомендуется 0.3–0.5.
51
+ lambda_ac : float
52
+ Вес anti-collapse loss. Рекомендуется 0.1–0.3.
53
+ min_disagreement : float
54
+ Минимальный порог несогласия для anti-collapse.
55
+
56
+ Регуляризация
57
+ -------------
58
+ comm_dropout : float
59
+ Dropout на канале коммуникации между агентами.
60
+ early_stop_threshold : float or None
61
+ Порог уверенности для досрочной остановки deliberation (inference only).
62
+ Если None — всегда идёт max_rounds раундов.
63
+ freeze_input_proj : bool
64
+ Если True — входной проекционный слой не обучается.
65
+
66
+ Ablation Flags
67
+ --------------
68
+ abl_no_pointers : bool
69
+ Отключить evidence pointers.
70
+ abl_no_messages : bool
71
+ Отключить формулировку сообщений.
72
+ abl_no_communication : bool
73
+ Полностью отключить коммуникацию.
74
+
75
+ Perspective Projector
76
+ ---------------------
77
+ use_perspective_projector : bool
78
+ Если True — каждый агент имеет свою проекцию входа (per-agent view).
79
+ Если False — все агенты начинают с одного представления (expand).
80
+ По умолчанию False для упрощённой архитектуры v3.
81
+
82
+ Test-Time Training (TTT)
83
+ -------------------------
84
+ use_ttt : bool
85
+ Агенты адаптируют свои веса к каждому входу перед deliberation.
86
+ ttt_steps : int
87
+ Число шагов внутренней оптимизации (K).
88
+ ttt_lr : float
89
+ Learning rate для inner Adam.
90
+ ttt_mask_ratio : float
91
+ Доля маскируемых компонент скрытого состояния.
92
+
93
+ Liquid Neural ODE Dynamics
94
+ --------------------------
95
+ use_liquid_dynamics : bool
96
+ Заменяет дискретные раунды на непрерывное ODE: ds/dt = -s/τ + F(s,r,e).
97
+ liquid_solver : str
98
+ Метод интегрирования: 'euler', 'dopri5' и др.
99
+ liquid_atol : float
100
+ Абсолютная толерантность для адаптивных решателей.
101
+ liquid_rtol : float
102
+ Относительная толерантность для адаптивных решателей.
103
+ trajectory_save_every : int
104
+ Сохранять каждое N-е состояние в trajectory для ODE решателя.
105
+ Для адаптивных решателей (dopri5) рекомендуется 10-50 для экономии памяти.
106
+ Для euler можно использовать 1.
107
+ """
108
+ # ─── Размерности ────────────────────────────────────────────────────────────
109
+ d_input: int = 128
110
+ d_hidden: int = 128
111
+ d_message: int = 128
112
+
113
+ # ─── Агенты и deliberation ───────────────────────────────────────────────────
114
+ num_agents: int = 4
115
+ num_classes: int = 2
116
+ max_rounds: int = 3
117
+ multi_label: bool = False
118
+
119
+ # ─── Архитектурные параметры ─────────────────────────────────────────────────
120
+ num_attn_heads: int = 4
121
+
122
+ # ─── Role Priors ─────────────────────────────────────────────────────────────
123
+ role_prior_blend: float = 0.30
124
+ positive_class_idx: int = 1
125
+
126
+ # ─── Loss гиперпараметры ─────────────────────────────────────────────────────
127
+ lambda_cal: float = 0.4
128
+ lambda_ac: float = 0.2
129
+ min_disagreement: float = 0.08
130
+
131
+ # ─── Регуляризация ──────────────────────────────────────────────────────────
132
+ comm_dropout: float = 0.2
133
+ early_stop_threshold: float = 0.90
134
+ freeze_input_proj: bool = False
135
+
136
+ # ─── Ablation flags ─────────────────────────────────────────────────────────
137
+ abl_no_pointers: bool = False
138
+ abl_no_messages: bool = False
139
+ abl_no_communication: bool = False
140
+
141
+ # ─── Perspective Projector ─────────────────────────────────────────────────
142
+ use_perspective_projector: bool = True
143
+
144
+ # ─── Test-Time Training ──────────────────────────────────────────────────────
145
+ use_ttt: bool = False
146
+ ttt_steps: int = 3
147
+ ttt_lr: float = 1e-3
148
+ ttt_mask_ratio: float = 0.15
149
+
150
+ # ─── Liquid Neural ODE Dynamics ──────────────────────────────────────────────
151
+ use_liquid_dynamics: bool = False
152
+ liquid_solver: str = "euler"
153
+ liquid_atol: float = 1e-3
154
+ liquid_rtol: float = 1e-3
155
+ trajectory_save_every: int = 10
156
+
157
+ # ─── CIDA v4: Architectural Improvements ─────────────────────────────────────
158
+ # Disagreement-Routed Sparse Communication (Li et al. EMNLP 2024)
159
+ # Каждый агент слушает только top-K по disagreement, не всех.
160
+ sparse_comm_k: int = 2
161
+
162
+ # Anonymous Message Passing (Choi et al. ACL 2025)
163
+ # Убирает s (identity) из сообщений → агент слышит аргумент, а не авторитет.
164
+ anonymous_messages: bool = True
165
+
166
+ # Dynamic Trust Weights (CortexDebate MDM, ACL 2025)
167
+ # EMA trust вместо статических ROLE_WEIGHTS.
168
+ use_dynamic_trust: bool = False
169
+ trust_ema_gamma: float = 0.9
170
+
171
+ # ─── Kairos Dynamic Deliberation (KDD) ───────────────────────────────────────
172
+ # Time-Travel (Rollback) mechanics
173
+ enable_rollback: bool = False
174
+ rollback_noise_std: float = 0.05
175
+
176
+ # Latent Superposition
177
+ # If final disagreement is > threshold, output Superposition State
178
+ superposition_threshold: float = 0.7
179
+
180
+ def __post_init__(self):
181
+ assert self.d_input > 0, "d_input must be positive"
182
+ assert self.num_agents >= 2, "Need at least 2 agents"
183
+ assert self.num_classes >= 2, "Need at least 2 classes"
184
+ assert self.max_rounds >= 1, "Need at least 1 deliberation round"
185
+ assert self.d_hidden % self.num_attn_heads == 0, (
186
+ f"d_hidden ({self.d_hidden}) must be divisible by "
187
+ f"num_attn_heads ({self.num_attn_heads})"
188
+ )
189
+ assert 0.0 <= self.role_prior_blend <= 1.0, (
190
+ "role_prior_blend must be in [0.0, 1.0]"
191
+ )
192
+ if self.early_stop_threshold is not None:
193
+ assert 0.5 < self.early_stop_threshold <= 1.0, (
194
+ "early_stop_threshold must be in (0.5, 1.0]"
195
+ )
196
+
197
+ def to_dict(self):
198
+ return asdict(self)
199
+
200
+ @classmethod
201
+ def from_dict(cls, config_dict):
202
+ # Filter out unknown keys for forward compatibility
203
+ valid_keys = {f.name for f in cls.__dataclass_fields__.values()}
204
+ filtered = {k: v for k, v in config_dict.items() if k in valid_keys}
205
+ return cls(**filtered)
206
+
207
+ def save_pretrained(self, save_directory: str):
208
+ os.makedirs(save_directory, exist_ok=True)
209
+ config_file = os.path.join(save_directory, "config.json")
210
+ with open(config_file, "w", encoding="utf-8") as f:
211
+ json.dump(self.to_dict(), f, indent=2, sort_keys=True)
212
+
213
+ @classmethod
214
+ def from_pretrained(cls, pretrained_model_name_or_path: str):
215
+ if os.path.isdir(pretrained_model_name_or_path):
216
+ config_file = os.path.join(
217
+ pretrained_model_name_or_path, "config.json"
218
+ )
219
+ else:
220
+ from huggingface_hub import hf_hub_download
221
+ config_file = hf_hub_download(
222
+ repo_id=pretrained_model_name_or_path,
223
+ filename="config.json",
224
+ )
225
+
226
+ with open(config_file, "r", encoding="utf-8") as f:
227
+ config_dict = json.load(f)
228
+ return cls.from_dict(config_dict)
@@ -0,0 +1,178 @@
1
+ """
2
+ consensus.py — Упрощённый ConsensusAggregator
3
+
4
+ Ключевые исправления:
5
+
6
+ 1. УБРАН ReliabilityTracker
7
+ Причина: круговая зависимость rho ← accuracy ← predictions ← weights(rho).
8
+ На малых датасетах вырождается. Без теоретической гарантии сходимости.
9
+ Замена: статические веса ролей (обоснованные a priori).
10
+
11
+ 2. ИСПРАВЛЕН ConsensusAggregator
12
+ БЫЛО: PoE = log p ∝ Σ w_i · log(b_i)
13
+ ПРОБЛЕМА: PoE предполагает НЕЗАВИСИМЫХ экспертов.
14
+ Агенты CIDA КОРРЕЛИРОВАНЫ — общий энкодер + communication между агентами.
15
+ PoE на коррелированных агентах → усиление shared bias → overconfidence.
16
+
17
+ СТАЛО: Weighted Mean + Disagreement-as-Uncertainty
18
+ p_final = Σ w_i · b_i (взвешенное среднее)
19
+ uncertainty = std(b) по агентам
20
+
21
+ Математическое обоснование:
22
+ Для коррелированных экспертов с ковариационной матрицей Σ:
23
+ Var(Σ w_i · b_i) = w^T · Σ · w (правильно учитывает корреляции)
24
+ Var(PoE(b_i)) ≠ w^T · Σ · w (PoE неверно предполагает Σ = σ²I)
25
+ При положительной корреляции: PoE под-оценивает variance → overconfidence.
26
+
27
+ 3. УПРОЩЁН интерфейс
28
+ Убраны: epoch_fraction, temperature scheduling, normalize_components флаг.
29
+ Всё это компенсировало нестабильность PoE на коррелированных агентах.
30
+ При правильной агрегации это не нужно.
31
+ """
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+
37
+
38
+ # Статические веса ролей (a priori обоснованы ролями)
39
+ # Интегратор = наибольший вес (его роль — синтезировать)
40
+ # Скептик = наименьший вес (его роль — поднимать вопросы, не решать)
41
+ ROLE_WEIGHTS = torch.tensor([0.28, 0.28, 0.12, 0.32]) # P, D, S, I (сумма = 1.0)
42
+
43
+
44
+ class ConsensusAggregator(nn.Module):
45
+ """
46
+ Weighted Mean консенсус для КОРРЕЛИРОВАННЫХ агентов.
47
+
48
+ v4: Dynamic Trust Weights (CortexDebate MDM, ACL 2025).
49
+ Статические веса заменяются на EMA trust — вес агента растёт
50
+ когда его предсказания совпадают с финальным консенсусом.
51
+
52
+ trust_i += γ·(1[b_i_correct] - trust_i) per batch, detached from main graph
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ num_agents: int = 4,
58
+ multi_label: bool = False,
59
+ use_dynamic_trust: bool = False,
60
+ trust_ema_gamma: float = 0.9,
61
+ ):
62
+ super().__init__()
63
+ self.multi_label = multi_label
64
+ self.use_dynamic_trust = use_dynamic_trust
65
+ self.trust_ema_gamma = trust_ema_gamma
66
+
67
+ # Статические веса ролей — обучаемые для адаптации к задаче
68
+ init_weights = ROLE_WEIGHTS[:num_agents].clone()
69
+ init_weights = init_weights / init_weights.sum()
70
+ self.role_weights = nn.Parameter(init_weights)
71
+
72
+ # v4: Dynamic trust (EMA buffer, не участвует в градиентах)
73
+ if use_dynamic_trust:
74
+ self.register_buffer(
75
+ 'trust_scores',
76
+ torch.ones(num_agents) / num_agents
77
+ )
78
+
79
+ @property
80
+ def weights(self) -> torch.Tensor:
81
+ """Нормализованные веса ролей (в simplex)."""
82
+ if self.use_dynamic_trust and hasattr(self, 'trust_scores'):
83
+ return self.trust_scores
84
+ return F.softmax(self.role_weights, dim=0) # (M,)
85
+
86
+ def update_trust(self, b: torch.Tensor, p_final: torch.Tensor):
87
+ """
88
+ Обновляет trust scores на основе согласованности агента с консенсусом.
89
+ Вызывается ПОСЛЕ forward(), detached от основного графа.
90
+
91
+ b: (B, M, K) — убеждения агентов
92
+ p_final: (B, K) — финальный консенсус
93
+ """
94
+ if not self.use_dynamic_trust:
95
+ return
96
+ with torch.no_grad():
97
+ # Per-agent correctness = 1 - L1(b_i, p_final)
98
+ correctness = 1.0 - (b - p_final.unsqueeze(1)).abs().mean(dim=(0, 2)) # (M,)
99
+ gamma = self.trust_ema_gamma
100
+ self.trust_scores = gamma * self.trust_scores + (1 - gamma) * correctness
101
+ self.trust_scores = self.trust_scores / self.trust_scores.sum()
102
+
103
+ def forward(
104
+ self,
105
+ b: torch.Tensor, # (B, M, K)
106
+ u: torch.Tensor = None, # (B, M, 1) — опционально
107
+ u_epi: torch.Tensor = None, # (B, M, 1) — v4: epistemic
108
+ u_alea: torch.Tensor = None, # (B, M, 1) — v4: aleatoric
109
+ ) -> tuple:
110
+ """
111
+ Returns:
112
+ p_final: (B, K) — консенсус
113
+ uncertainty: (B,) — мера неопределённости
114
+ disagreement:(B,) — мера несогласия агентов
115
+ """
116
+ B, M, K = b.shape
117
+ w = self.weights # (M,) — нормализованные role weights
118
+
119
+ # ── Взвешенное среднее (корректно для коррелированных агентов) ────────
120
+ p_final = (w.view(1, M, 1) * b).sum(dim=1) # (B, K)
121
+
122
+ if not self.multi_label:
123
+ # FIX v4: Нормализация через log-softmax — численно стабильна
124
+ p_final = F.softmax(p_final.log().clamp(-10, 10), dim=-1)
125
+
126
+ # ── Несогласие = std убеждений по агентам (со стабильным градиентом) ──
127
+ var = b.var(dim=1, unbiased=True)
128
+ disagreement = torch.sqrt(var + 1e-8).mean(dim=-1) # (B,)
129
+
130
+ # ── Неопределённость = несогласие + собственная uncertainty агентов ──
131
+ if u is not None:
132
+ agent_u = (w.view(1, M, 1) * u).sum(dim=1).squeeze(-1) # (B,)
133
+ uncertainty = 0.5 * disagreement + 0.5 * agent_u
134
+ else:
135
+ uncertainty = disagreement
136
+
137
+ # v4: Обновляем dynamic trust (если включён)
138
+ if self.training and self.use_dynamic_trust:
139
+ self.update_trust(b, p_final)
140
+
141
+ return p_final, uncertainty, disagreement
142
+
143
+
144
+ class HaltingPredictor(nn.Module):
145
+ """
146
+ Предсказатель остановки (упрощённый).
147
+ Используется для ACT: когда агенты достаточно согласны — останавливаемся.
148
+
149
+ В отличие от оригинала, решение основано на НЕСОГЛАСИИ (наблюдаемо),
150
+ а не на выученном MLP поверх всех состояний (трудно интерпретировать).
151
+ """
152
+
153
+ def __init__(self, agreement_threshold: float = 0.15):
154
+ """
155
+ agreement_threshold: при несогласии ниже этого порога — останавливаемся.
156
+ """
157
+ super().__init__()
158
+ self.threshold = agreement_threshold
159
+ # Небольшой выученный bias для подстройки порога под задачу
160
+ self.bias = nn.Parameter(torch.tensor(0.0))
161
+
162
+ def should_halt(self, disagreement: torch.Tensor) -> torch.Tensor:
163
+ """
164
+ disagreement: (B,)
165
+ Returns: (B,) bool — True если стоит остановиться
166
+ """
167
+ effective_threshold = self.threshold + self.bias.sigmoid() * 0.1
168
+ return disagreement < effective_threshold
169
+
170
+ def halting_probability(self, disagreement: torch.Tensor) -> torch.Tensor:
171
+ """
172
+ Мягкая версия: вероятность остановки ∈ [0, 1].
173
+ disagreement: (B,)
174
+ Returns: (B,)
175
+ """
176
+ effective_threshold = self.threshold + self.bias.sigmoid() * 0.1
177
+ # Sigmoid(-scale * (d - threshold)): высокая вероятность при d < threshold
178
+ return torch.sigmoid(-10.0 * (disagreement - effective_threshold))