cida-plugin 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cida_plugin/hf.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ hf.py — Интеграция CIDA-Plugin с Hugging Face Models.
3
+
4
+ Позволяет обернуть любую модель классификации последовательностей (Sequence Classification)
5
+ в CIDA-Plugin в одну строчку кода, заменяя стандартную линейную голову.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from typing import Optional, Union, Dict, Any
11
+
12
+ try:
13
+ from transformers.modeling_outputs import SequenceClassifierOutput
14
+ _HAS_TRANSFORMERS = True
15
+ except ImportError:
16
+ # Заглушка на случай отсутствия transformers (например, при тестировании голого PyTorch)
17
+ class SequenceClassifierOutput:
18
+ def __init__(self, loss=None, logits=None, hidden_states=None, attentions=None):
19
+ self.loss = loss
20
+ self.logits = logits
21
+ self.hidden_states = hidden_states
22
+ self.attentions = attentions
23
+ _HAS_TRANSFORMERS = False
24
+
25
+ from .core import CIDAPlugin
26
+ from .config import CIDAPluginConfig
27
+ from .losses import CIDALoss
28
+
29
+
30
+ class HFModelWithCIDA(nn.Module):
31
+ """
32
+ Обертка над предобученной Hugging Face моделью классификации текстов.
33
+ Перехватывает скрытые состояния базового энкодера и пропускает их через CIDA-Plugin.
34
+ """
35
+ def __init__(
36
+ self,
37
+ hf_model: nn.Module,
38
+ cida_config: Optional[CIDAPluginConfig] = None,
39
+ **cida_kwargs: Any
40
+ ):
41
+ super().__init__()
42
+
43
+ # Сохраняем исходную конфигурацию
44
+ self.config = getattr(hf_model, "config", None)
45
+
46
+ # 1. Извлекаем базовый энкодер (transformer backbone)
47
+ if hasattr(hf_model, "base_model"):
48
+ self.encoder = hf_model.base_model
49
+ else:
50
+ # Fallback для кастомных моделей
51
+ self.encoder = hf_model
52
+
53
+ # 2. Определяем размерности
54
+ if self.config is not None:
55
+ d_model = getattr(self.config, "hidden_size", None) or getattr(self.config, "d_model", None)
56
+ num_labels = getattr(self.config, "num_labels", 2)
57
+ else:
58
+ d_model = cida_kwargs.get("d_input")
59
+ num_labels = cida_kwargs.get("num_classes", 2)
60
+
61
+ if d_model is None:
62
+ raise ValueError(
63
+ "Не удалось автоматически определить hidden_size из конфигурации модели HF. "
64
+ "Пожалуйста, передайте параметр d_input явно в cida_kwargs."
65
+ )
66
+
67
+ # 3. Инициализируем CIDA-Plugin
68
+ if cida_config is None:
69
+ cida_config = CIDAPluginConfig(
70
+ d_input=d_model,
71
+ num_classes=num_labels,
72
+ **cida_kwargs
73
+ )
74
+ else:
75
+ # Обновляем поля, если они переданы пустые
76
+ if cida_config.d_input is None:
77
+ cida_config.d_input = d_model
78
+ if cida_config.num_classes is None:
79
+ cida_config.num_classes = num_labels
80
+
81
+ self.plugin = CIDAPlugin(cida_config)
82
+ self.num_labels = num_labels
83
+
84
+ def forward(
85
+ self,
86
+ input_ids: Optional[torch.Tensor] = None,
87
+ attention_mask: Optional[torch.Tensor] = None,
88
+ token_type_ids: Optional[torch.Tensor] = None,
89
+ position_ids: Optional[torch.Tensor] = None,
90
+ head_mask: Optional[torch.Tensor] = None,
91
+ inputs_embeds: Optional[torch.Tensor] = None,
92
+ labels: Optional[torch.Tensor] = None,
93
+ output_attentions: Optional[bool] = None,
94
+ output_hidden_states: Optional[bool] = None,
95
+ return_dict: Optional[bool] = None,
96
+ **kwargs: Any
97
+ ) -> Union[SequenceClassifierOutput, tuple]:
98
+
99
+ use_return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
100
+
101
+ # 1. Прогоняем через базовый энкодер
102
+ encoder_outputs = self.encoder(
103
+ input_ids=input_ids,
104
+ attention_mask=attention_mask,
105
+ token_type_ids=token_type_ids,
106
+ position_ids=position_ids,
107
+ head_mask=head_mask,
108
+ inputs_embeds=inputs_embeds,
109
+ output_attentions=output_attentions,
110
+ output_hidden_states=output_hidden_states,
111
+ return_dict=use_return_dict,
112
+ **kwargs
113
+ )
114
+
115
+ # 2. Извлекаем sequence outputs
116
+ if isinstance(encoder_outputs, tuple):
117
+ last_hidden_state = encoder_outputs[0]
118
+ else:
119
+ last_hidden_state = encoder_outputs.last_hidden_state
120
+
121
+ # Извлекаем CLS токен (первый токен последовательности) как pooled representation
122
+ pooled_output = last_hidden_state[:, 0, :]
123
+
124
+ # 3. Вызываем CIDA-Plugin
125
+ plugin_outputs = self.plugin(
126
+ pooled=pooled_output,
127
+ seq_output=last_hidden_state,
128
+ mask=attention_mask
129
+ )
130
+
131
+ logits = plugin_outputs["p_final"]
132
+
133
+ # 4. Вычисляем лосс при наличии меток
134
+ loss = None
135
+ if labels is not None:
136
+ loss_fn = CIDALoss(
137
+ multi_label=self.plugin.config.multi_label,
138
+ lambda_cal=self.plugin.config.lambda_cal,
139
+ lambda_ac=self.plugin.config.lambda_ac,
140
+ min_disagreement=self.plugin.config.min_disagreement
141
+ ).to(logits.device)
142
+
143
+ orth_loss = plugin_outputs.get("L_orth", None)
144
+
145
+ loss, _ = loss_fn(
146
+ p_final=logits,
147
+ y=labels,
148
+ b_all=plugin_outputs["b_all"],
149
+ orth_loss=orth_loss
150
+ )
151
+
152
+ # 5. Возвращаем результат
153
+ if not use_return_dict:
154
+ output = (logits,)
155
+ if isinstance(encoder_outputs, tuple):
156
+ output = output + encoder_outputs[1:]
157
+ return ((loss,) + output) if loss is not None else output
158
+
159
+ # Собираем SequenceClassifierOutput для HF-совместимости
160
+ res = SequenceClassifierOutput(
161
+ loss=loss,
162
+ logits=logits,
163
+ hidden_states=getattr(encoder_outputs, "hidden_states", None),
164
+ attentions=getattr(encoder_outputs, "attentions", None)
165
+ )
166
+
167
+ # Добавляем CIDA метаданные как атрибуты для диагностики
168
+ res.b_all = plugin_outputs["b_all"]
169
+ res.uncertainty = plugin_outputs["uncertainty"]
170
+ res.disagreement = plugin_outputs["disagreement"]
171
+ res.rounds_used = plugin_outputs["rounds_used"]
172
+ res.plugin_outputs = plugin_outputs
173
+
174
+ return res
175
+
176
+
177
+ def wrap_hf_model(
178
+ hf_model: nn.Module,
179
+ cida_config: Optional[CIDAPluginConfig] = None,
180
+ **kwargs: Any
181
+ ) -> HFModelWithCIDA:
182
+ """
183
+ Быстрая обертка для Hugging Face sequence classification моделей.
184
+
185
+ Пример использования:
186
+ from transformers import AutoModelForSequenceClassification
187
+ from cida_plugin import wrap_hf_model
188
+
189
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
190
+ model = wrap_hf_model(model, max_rounds=3)
191
+
192
+ # Теперь модель готова к обучению с CIDA-Plugin в Hugging Face Trainer!
193
+ """
194
+ return HFModelWithCIDA(hf_model, cida_config=cida_config, **kwargs)
@@ -0,0 +1,291 @@
1
+ """
2
+ Liquid Neural ODE Dynamics для CIDA-Plugin.
3
+
4
+ Заменяет дискретный цикл deliberation loop на непрерывное
5
+ дифференциальное уравнение:
6
+ ds/dt = -s/τ(x) + F(s, r, e)
7
+
8
+ Где:
9
+ s - скрытое состояние агентов
10
+ r - контр-аргументы от других агентов
11
+ e - извлечённые доказательства
12
+ τ(x) - зависящая от входа константа времени (Liquid Time Constant)
13
+ F - функция перехода (основанная на AgentUpdater)
14
+
15
+ Использование `torchdiffeq` позволяет использовать адаптивные решатели
16
+ (например, dopri5), которые автоматически регулируют количество
17
+ вычислений (NFE) в зависимости от сложности входа.
18
+ """
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ try:
25
+ from torchdiffeq import odeint_adjoint as odeint_adj
26
+ from torchdiffeq import odeint
27
+ _HAS_ADJOINT = True
28
+ except ImportError:
29
+ from torchdiffeq import odeint
30
+ odeint_adj = odeint
31
+ _HAS_ADJOINT = False
32
+
33
+ from .agent import AgentState
34
+ from .deliberation import (
35
+ AgentEvidenceExtractor,
36
+ CounterargumentCommunication,
37
+ AgentUpdater,
38
+ )
39
+
40
+
41
+ class LiquidTimeConstant(nn.Module):
42
+ """
43
+ Вычисляет τ(x) — константу времени, зависящую от начального состояния.
44
+ """
45
+
46
+ def __init__(self, d_hidden: int):
47
+ super().__init__()
48
+ self.net = nn.Sequential(
49
+ nn.Linear(d_hidden, d_hidden // 2),
50
+ nn.SiLU(),
51
+ nn.Linear(d_hidden // 2, 1),
52
+ nn.Sigmoid(),
53
+ )
54
+
55
+ def forward(self, s_init: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ s_init: (B, M, d)
58
+ Returns: (B, M, 1) — значения τ ∈ (0.1, 1.1)
59
+ """
60
+ return self.net(s_init) + 0.1
61
+
62
+
63
+ class LiquidDeliberationDynamics(nn.Module):
64
+ """
65
+ Обёртка над компонентами deliberation (extractor, communication, updater),
66
+ которая вычисляет производную ds/dt для ODE решателя.
67
+
68
+ Ключевое отличие от предыдущей версии: beliefs (b) и uncertainty (u)
69
+ вычисляются ДИНАМИЧЕСКИ из текущего скрытого состояния s через
70
+ updater.g(s), а не подставляются как нулевые заглушки.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ extractor: AgentEvidenceExtractor,
76
+ message_formulator,
77
+ communication: CounterargumentCommunication,
78
+ updater: AgentUpdater,
79
+ d_hidden: int,
80
+ trajectory_save_every: int = 10,
81
+ T: float = 4.0,
82
+ ):
83
+ super().__init__()
84
+ self.extractor = extractor
85
+ self.message_formulator = message_formulator
86
+ self.communication = communication
87
+ self.updater = updater
88
+ self.tau_net = LiquidTimeConstant(d_hidden)
89
+ self.trajectory_save_every = trajectory_save_every
90
+ self.T = T
91
+ # Временные точки для сохранения (корректно для adaptive solvers)
92
+ self._save_timepoints = {0.25, 0.50, 0.75, 1.0}
93
+
94
+ # Кэшируемый контекст (чтобы не прокидывать через ODE)
95
+ self.memory = None
96
+ self.mask = None
97
+ self.tau = None
98
+ self.nfe = 0 # Number of Function Evaluations
99
+
100
+ # Для сохранения промежуточных состояний (trajectory)
101
+ self.b_trajectory = []
102
+ self.u_trajectory = []
103
+ self.p_ptr_trajectory = []
104
+ self.m_trajectory = []
105
+
106
+ def set_context(
107
+ self,
108
+ memory: torch.Tensor,
109
+ mask: torch.Tensor,
110
+ s_init: torch.Tensor,
111
+ ):
112
+ """
113
+ Устанавливает контекст перед запуском ODE.
114
+ memory: (B, L, d) — память документов
115
+ mask: (B, L) — маска внимания
116
+ s_init: (B, M, d) — начальное состояние агентов
117
+ """
118
+ self.memory = memory
119
+ self.mask = mask
120
+ self.nfe = 0
121
+ # FIX v4: del [:] вместо clear() для явного освобождения памяти
122
+ del self.b_trajectory[:]
123
+ del self.u_trajectory[:]
124
+ del self.p_ptr_trajectory[:]
125
+ del self.m_trajectory[:]
126
+
127
+ def clear_context(self):
128
+ self.memory = None
129
+ self.mask = None
130
+
131
+ def _compute_beliefs(self, s: torch.Tensor):
132
+ """
133
+ Динамически вычисляет текущие beliefs и uncertainty из состояния s.
134
+
135
+ Вместо передачи dummy_b=0, dummy_u=1 (что лишало коммуникацию
136
+ информации о текущих мнениях агентов), мы проецируем s через
137
+ ту же голову, что использует AgentUpdater.
138
+
139
+ Returns: (current_b, current_u)
140
+ """
141
+ g_out = self.updater.g(s)
142
+
143
+ if self.updater.multi_label:
144
+ current_b = torch.sigmoid(g_out)
145
+ current_u = (4.0 * current_b * (1.0 - current_b)).mean(
146
+ dim=-1, keepdim=True
147
+ )
148
+ else:
149
+ alpha = F.softplus(g_out) + 1.0
150
+ alpha_sum = alpha.sum(dim=-1, keepdim=True)
151
+ current_b = alpha / alpha_sum
152
+ current_u = float(g_out.size(-1)) / alpha_sum
153
+
154
+ return current_b, current_u
155
+
156
+ def forward(self, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
157
+ """
158
+ Вычисляет ds/dt.
159
+ t: (1,) — текущее время ODE
160
+ s: (B, M, d) — текущее состояние агентов
161
+ """
162
+ self.nfe += 1
163
+
164
+ B, M = s.shape[0], s.shape[1]
165
+ num_classes = self.updater.g[-1].out_features
166
+
167
+ # Динамическое вычисление текущих beliefs и uncertainty
168
+ current_b, current_u = self._compute_beliefs(s)
169
+
170
+ # Evidence extraction
171
+ p_ptr_t, e_t = self.extractor(s, self.memory, self.mask)
172
+
173
+ # Communication с актуальными beliefs (не с нулями!)
174
+ m_t = self.message_formulator(s, e_t, current_b, current_u)
175
+ r_t = self.communication(s, m_t, current_b, e_t)
176
+
177
+ # AgentUpdater step
178
+ dummy_state = AgentState(
179
+ s=s,
180
+ b=current_b,
181
+ u=current_u,
182
+ p=torch.zeros(B, M, 1, device=s.device),
183
+ e=e_t,
184
+ alpha=torch.ones(B, M, num_classes, device=s.device),
185
+ )
186
+ next_state = self.updater(dummy_state, r_t, e_t)
187
+
188
+ # Сохраняем промежуточные beliefs для лоссов и консенсуса
189
+ # FIX v4: 1) detach() — без него ODE граф держится в памяти (сотни МБ)
190
+ # 2) temporal checkpoints вместо NFE-based (корректно для adaptive solvers)
191
+ if torch.is_grad_enabled() or not self.training:
192
+ t_normalized = t.item() / self.T if self.T > 0 else 0
193
+ should_save = any(abs(t_normalized - sp) < 0.02 for sp in self._save_timepoints)
194
+ if should_save or self.nfe % self.trajectory_save_every == 0:
195
+ self.b_trajectory.append(next_state.b.detach())
196
+ self.u_trajectory.append(next_state.u.detach())
197
+ self.p_ptr_trajectory.append(p_ptr_t.detach())
198
+ self.m_trajectory.append(s.detach())
199
+
200
+ # TRUE Liquid ODE v4: τ(s,t) вычисляется на КАЖДОМ шаге из текущего s
201
+ # (раньше τ фиксировался из s_init в set_context — это не настоящий Liquid NN)
202
+ tau = self.tau_net(s) # (B, M, 1) — state-dependent time constant
203
+ f_s = next_state.s - s
204
+ ds_dt = -s / tau + f_s
205
+
206
+ return ds_dt
207
+
208
+
209
+ class LiquidDeliberationSolver(nn.Module):
210
+ """
211
+ Запускает ODE решатель для deliberation.
212
+
213
+ При use_adjoint=True использует adjoint method для экономии VRAM:
214
+ градиенты считаются без хранения промежуточных состояний.
215
+ """
216
+
217
+ def __init__(
218
+ self,
219
+ extractor: AgentEvidenceExtractor,
220
+ message_formulator,
221
+ communication: CounterargumentCommunication,
222
+ updater: AgentUpdater,
223
+ d_hidden: int,
224
+ solver: str = "euler",
225
+ atol: float = 1e-3,
226
+ rtol: float = 1e-3,
227
+ T: float = 4.0,
228
+ trajectory_save_every: int = 10,
229
+ use_adjoint: bool = False,
230
+ ):
231
+ super().__init__()
232
+ self.dynamics = LiquidDeliberationDynamics(
233
+ extractor=extractor,
234
+ message_formulator=message_formulator,
235
+ communication=communication,
236
+ updater=updater,
237
+ d_hidden=d_hidden,
238
+ trajectory_save_every=trajectory_save_every,
239
+ T=T,
240
+ )
241
+ self.solver = solver
242
+ self.atol = atol
243
+ self.rtol = rtol
244
+ self.T = T
245
+ self.use_adjoint = use_adjoint and _HAS_ADJOINT
246
+ self._integrate = odeint_adj if self.use_adjoint else odeint
247
+
248
+ def forward(
249
+ self,
250
+ s_init: torch.Tensor,
251
+ memory: torch.Tensor,
252
+ mask: torch.Tensor,
253
+ ) -> tuple:
254
+ """
255
+ s_init: (B, M, d)
256
+ memory: (B, L, d)
257
+ mask: (B, L)
258
+
259
+ Returns:
260
+ s_final (B, M, d)
261
+ trajectory (dict)
262
+ """
263
+ device = s_init.device
264
+ self.dynamics.set_context(memory, mask, s_init)
265
+
266
+ t = torch.tensor([0.0, self.T], device=device)
267
+
268
+ states = self._integrate(
269
+ self.dynamics,
270
+ s_init,
271
+ t,
272
+ method=self.solver,
273
+ atol=self.atol,
274
+ rtol=self.rtol,
275
+ )
276
+
277
+ s_final = states[-1]
278
+ nfe = self.dynamics.nfe
279
+
280
+ trajectory = {
281
+ "b_all": self.dynamics.b_trajectory,
282
+ "u_all": self.dynamics.u_trajectory,
283
+ "p_ptr_all": self.dynamics.p_ptr_trajectory,
284
+ "m_all": self.dynamics.m_trajectory,
285
+ "nfe": nfe,
286
+ }
287
+
288
+ self.dynamics.clear_context()
289
+
290
+ return s_final, trajectory
291
+
cida_plugin/losses.py ADDED
@@ -0,0 +1,174 @@
1
+ """
2
+ losses.py — Упрощённая система потерь CIDA (3 компонента вместо 11)
3
+
4
+ Было (OmegaLossSystem):
5
+ task + calibration + debate + progress + budget + dominance
6
+ + role + role_spec + orthogonality + ponder_cost + ttt_reg
7
+ = 11 компонентов, 11 гиперпараметров
8
+
9
+ Стало (CIDALoss):
10
+ task + calibration + anti_collapse
11
+ = 3 компонента, 2 гиперпараметра
12
+
13
+ Почему можно удалить остальные 8:
14
+
15
+ debate_loss → удалён: несогласие теперь структурное (role priors)
16
+ progress_loss → удалён: избыточен с task_loss при правильных priors
17
+ budget_loss → удалён: ACT не нужен если число раундов мало (≤3)
18
+ dominance_loss → удалён: симптом проблемы с PoE, которую мы исправили
19
+ role_loss → удалён: нет RoleEmbeddings → нет и orth-лосса для них
20
+ role_spec_loss → удалён: роли теперь structural, не statistical
21
+ orth_loss → заменён одной строкой в compute_role_orthogonality_loss()
22
+ ttt_reg → удалён: TTT убран из базовой архитектуры
23
+
24
+ anti_collapse_loss:
25
+ Единственный "новый" компонент, но он ПРОЩЕ и ПРАВИЛЬНЕЕ debate_loss.
26
+ Вместо "попади в целевой уровень несогласия на каждом раунде" →
27
+ "просто не позволяй агентам стать идентичными".
28
+ Это МИНИМАЛЬНОЕ требование, а не жёсткое целевое расписание.
29
+ """
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+
36
+ class CIDALoss(nn.Module):
37
+ """
38
+ Упрощённая система потерь для CIDA-Plugin.
39
+
40
+ Параметры
41
+ ---------
42
+ lambda_cal : float
43
+ Вес calibration loss (Brier score). Рекомендуется 0.3–0.5.
44
+ lambda_ac : float
45
+ Вес anti-collapse loss. Рекомендуется 0.1–0.3.
46
+ Меньше → агенты могут сближаться (риск коллапса).
47
+ Больше → агенты разведены сильнее (риск потери точности).
48
+ min_disagreement : float
49
+ Минимальный порог несогласия. Рекомендуется 0.05–0.15.
50
+ При role priors реальное несогласие ≥ 0.21, поэтому loss
51
+ почти всегда = 0 (действует как safety net).
52
+ multi_label : bool
53
+ Если True — используется BCE вместо CE.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ lambda_cal: float = 0.4,
59
+ lambda_ac: float = 0.2,
60
+ min_disagreement: float = 0.08,
61
+ multi_label: bool = False,
62
+ ):
63
+ super().__init__()
64
+ self.lambda_cal = lambda_cal
65
+ self.lambda_ac = lambda_ac
66
+ self.min_disagreement = min_disagreement
67
+ self.multi_label = multi_label
68
+
69
+ # ── Компонент 1: Task Loss ────────────────────────────────────────────────
70
+
71
+ def task_loss(self, p_final: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
72
+ """
73
+ Основная задача: классификация.
74
+ """
75
+ eps = 1e-7
76
+ if self.multi_label:
77
+ return F.binary_cross_entropy(p_final.clamp(eps, 1.0 - eps), y)
78
+ # FIX v4: log-softmax стабильнее чем clamp().log() — согласовано с consensus.py
79
+ log_probs = F.log_softmax(p_final.log().clamp(-10, 10), dim=-1)
80
+ return F.nll_loss(log_probs, y)
81
+
82
+ # ── Компонент 2: Calibration Loss ─────────────────────────────────────────
83
+
84
+ def calibration_loss(self, p_final: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
85
+ """
86
+ Brier score: MSE между вероятностями и метками.
87
+ """
88
+ if self.multi_label:
89
+ probs = p_final
90
+ return ((probs - y) ** 2).mean()
91
+
92
+ # Single-label: one-hot targets
93
+ K = p_final.size(-1)
94
+ y_oh = F.one_hot(y, num_classes=K).float()
95
+ return ((p_final - y_oh) ** 2).mean()
96
+
97
+ # ── Компонент 3: Anti-Collapse Loss ──────────────────────────────────────
98
+
99
+ def anti_collapse_loss(self, b_final: torch.Tensor) -> torch.Tensor:
100
+ """
101
+ Предотвращает коллапс агентов к идентичным убеждениям.
102
+
103
+ НЕ является аналогом debate_loss. Ключевые отличия:
104
+
105
+ debate_loss:
106
+ - Целевое расписание d_t на каждый раунд
107
+ - Штрафует если несогласие != цели
108
+ - Создаёт противоречие с consensus
109
+
110
+ anti_collapse_loss:
111
+ - Только минимальная граница min_disagreement
112
+ - Штрафует только при несогласии НИЖЕ порога
113
+ - Нулевой градиент при нормальном обучении (safety net)
114
+
115
+ b_final: (B, M, K) — финальные убеждения агентов
116
+ Returns: scalar (обычно ≈ 0 если role priors работают)
117
+
118
+ Математика:
119
+ d_ij = ||b_i - b_j||_1 — несогласие между агентами i и j
120
+ L_ac = mean(relu(δ - d_ij)) где δ = min_disagreement
121
+ Градиент ненулевой только когда d_ij < δ
122
+ """
123
+ b_i = b_final.unsqueeze(2) # (B, M, 1, K)
124
+ b_j = b_final.unsqueeze(1) # (B, 1, M, K)
125
+ pairwise_dist = (b_i - b_j).abs().sum(dim=-1) # (B, M, M)
126
+
127
+ # Только нижняя треугольная часть (без диагонали)
128
+ M = b_final.size(1)
129
+ mask = torch.tril(torch.ones(M, M, device=b_final.device), diagonal=-1).bool()
130
+ off_diag = pairwise_dist[:, mask] # (B, n_pairs)
131
+
132
+ return F.relu(self.min_disagreement - off_diag).mean()
133
+
134
+ # ── Основной forward ──────────────────────────────────────────────────────
135
+
136
+ def forward(
137
+ self,
138
+ p_final: torch.Tensor,
139
+ y: torch.Tensor,
140
+ b_all: list,
141
+ orth_loss: torch.Tensor = None,
142
+ ) -> tuple[torch.Tensor, dict]:
143
+ """
144
+ p_final : (B, K) — финальное предсказание
145
+ y : (B,) или (B,K) — метки
146
+ b_all : list[(B,M,K)] — история убеждений по раундам
147
+ orth_loss: scalar или None — orth loss из representations (опционально)
148
+
149
+ Returns: (total_loss, components_dict)
150
+ """
151
+ l_t = self.task_loss(p_final, y)
152
+ l_c = self.calibration_loss(p_final, y)
153
+
154
+ # Anti-collapse на финальных убеждениях
155
+ b_final = b_all[-1] if b_all else None
156
+ l_ac = self.anti_collapse_loss(b_final) if b_final is not None else torch.tensor(0.0)
157
+
158
+ total = l_t + self.lambda_cal * l_c + self.lambda_ac * l_ac
159
+
160
+ # Orth loss опционально (из representations, не из весов)
161
+ l_orth = torch.tensor(0.0)
162
+ if orth_loss is not None:
163
+ l_orth = orth_loss
164
+ total = total + 0.05 * l_orth # малый вес, это regularizer
165
+
166
+ components = {
167
+ "task": l_t.item(),
168
+ "calibration": l_c.item(),
169
+ "anti_collapse": l_ac.item(),
170
+ "orth": l_orth.item() if isinstance(l_orth, torch.Tensor) else l_orth,
171
+ "total": total.item(),
172
+ }
173
+
174
+ return total, components