cida-plugin 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cida_plugin/__init__.py +62 -0
- cida_plugin/agent.py +198 -0
- cida_plugin/config.py +228 -0
- cida_plugin/consensus.py +178 -0
- cida_plugin/core.py +451 -0
- cida_plugin/deliberation.py +376 -0
- cida_plugin/diagnostics.py +124 -0
- cida_plugin/hf.py +194 -0
- cida_plugin/liquid_dynamics.py +291 -0
- cida_plugin/losses.py +174 -0
- cida_plugin/translator.py +74 -0
- cida_plugin/ttt.py +348 -0
- cida_plugin/vision_backbone.py +43 -0
- cida_plugin-1.0.0.dist-info/METADATA +167 -0
- cida_plugin-1.0.0.dist-info/RECORD +18 -0
- cida_plugin-1.0.0.dist-info/WHEEL +5 -0
- cida_plugin-1.0.0.dist-info/licenses/LICENSE +201 -0
- cida_plugin-1.0.0.dist-info/top_level.txt +1 -0
cida_plugin/hf.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""
|
|
2
|
+
hf.py — Интеграция CIDA-Plugin с Hugging Face Models.
|
|
3
|
+
|
|
4
|
+
Позволяет обернуть любую модель классификации последовательностей (Sequence Classification)
|
|
5
|
+
в CIDA-Plugin в одну строчку кода, заменяя стандартную линейную голову.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from typing import Optional, Union, Dict, Any
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
|
14
|
+
_HAS_TRANSFORMERS = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
# Заглушка на случай отсутствия transformers (например, при тестировании голого PyTorch)
|
|
17
|
+
class SequenceClassifierOutput:
|
|
18
|
+
def __init__(self, loss=None, logits=None, hidden_states=None, attentions=None):
|
|
19
|
+
self.loss = loss
|
|
20
|
+
self.logits = logits
|
|
21
|
+
self.hidden_states = hidden_states
|
|
22
|
+
self.attentions = attentions
|
|
23
|
+
_HAS_TRANSFORMERS = False
|
|
24
|
+
|
|
25
|
+
from .core import CIDAPlugin
|
|
26
|
+
from .config import CIDAPluginConfig
|
|
27
|
+
from .losses import CIDALoss
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class HFModelWithCIDA(nn.Module):
|
|
31
|
+
"""
|
|
32
|
+
Обертка над предобученной Hugging Face моделью классификации текстов.
|
|
33
|
+
Перехватывает скрытые состояния базового энкодера и пропускает их через CIDA-Plugin.
|
|
34
|
+
"""
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
hf_model: nn.Module,
|
|
38
|
+
cida_config: Optional[CIDAPluginConfig] = None,
|
|
39
|
+
**cida_kwargs: Any
|
|
40
|
+
):
|
|
41
|
+
super().__init__()
|
|
42
|
+
|
|
43
|
+
# Сохраняем исходную конфигурацию
|
|
44
|
+
self.config = getattr(hf_model, "config", None)
|
|
45
|
+
|
|
46
|
+
# 1. Извлекаем базовый энкодер (transformer backbone)
|
|
47
|
+
if hasattr(hf_model, "base_model"):
|
|
48
|
+
self.encoder = hf_model.base_model
|
|
49
|
+
else:
|
|
50
|
+
# Fallback для кастомных моделей
|
|
51
|
+
self.encoder = hf_model
|
|
52
|
+
|
|
53
|
+
# 2. Определяем размерности
|
|
54
|
+
if self.config is not None:
|
|
55
|
+
d_model = getattr(self.config, "hidden_size", None) or getattr(self.config, "d_model", None)
|
|
56
|
+
num_labels = getattr(self.config, "num_labels", 2)
|
|
57
|
+
else:
|
|
58
|
+
d_model = cida_kwargs.get("d_input")
|
|
59
|
+
num_labels = cida_kwargs.get("num_classes", 2)
|
|
60
|
+
|
|
61
|
+
if d_model is None:
|
|
62
|
+
raise ValueError(
|
|
63
|
+
"Не удалось автоматически определить hidden_size из конфигурации модели HF. "
|
|
64
|
+
"Пожалуйста, передайте параметр d_input явно в cida_kwargs."
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
# 3. Инициализируем CIDA-Plugin
|
|
68
|
+
if cida_config is None:
|
|
69
|
+
cida_config = CIDAPluginConfig(
|
|
70
|
+
d_input=d_model,
|
|
71
|
+
num_classes=num_labels,
|
|
72
|
+
**cida_kwargs
|
|
73
|
+
)
|
|
74
|
+
else:
|
|
75
|
+
# Обновляем поля, если они переданы пустые
|
|
76
|
+
if cida_config.d_input is None:
|
|
77
|
+
cida_config.d_input = d_model
|
|
78
|
+
if cida_config.num_classes is None:
|
|
79
|
+
cida_config.num_classes = num_labels
|
|
80
|
+
|
|
81
|
+
self.plugin = CIDAPlugin(cida_config)
|
|
82
|
+
self.num_labels = num_labels
|
|
83
|
+
|
|
84
|
+
def forward(
|
|
85
|
+
self,
|
|
86
|
+
input_ids: Optional[torch.Tensor] = None,
|
|
87
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
88
|
+
token_type_ids: Optional[torch.Tensor] = None,
|
|
89
|
+
position_ids: Optional[torch.Tensor] = None,
|
|
90
|
+
head_mask: Optional[torch.Tensor] = None,
|
|
91
|
+
inputs_embeds: Optional[torch.Tensor] = None,
|
|
92
|
+
labels: Optional[torch.Tensor] = None,
|
|
93
|
+
output_attentions: Optional[bool] = None,
|
|
94
|
+
output_hidden_states: Optional[bool] = None,
|
|
95
|
+
return_dict: Optional[bool] = None,
|
|
96
|
+
**kwargs: Any
|
|
97
|
+
) -> Union[SequenceClassifierOutput, tuple]:
|
|
98
|
+
|
|
99
|
+
use_return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
|
|
100
|
+
|
|
101
|
+
# 1. Прогоняем через базовый энкодер
|
|
102
|
+
encoder_outputs = self.encoder(
|
|
103
|
+
input_ids=input_ids,
|
|
104
|
+
attention_mask=attention_mask,
|
|
105
|
+
token_type_ids=token_type_ids,
|
|
106
|
+
position_ids=position_ids,
|
|
107
|
+
head_mask=head_mask,
|
|
108
|
+
inputs_embeds=inputs_embeds,
|
|
109
|
+
output_attentions=output_attentions,
|
|
110
|
+
output_hidden_states=output_hidden_states,
|
|
111
|
+
return_dict=use_return_dict,
|
|
112
|
+
**kwargs
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
# 2. Извлекаем sequence outputs
|
|
116
|
+
if isinstance(encoder_outputs, tuple):
|
|
117
|
+
last_hidden_state = encoder_outputs[0]
|
|
118
|
+
else:
|
|
119
|
+
last_hidden_state = encoder_outputs.last_hidden_state
|
|
120
|
+
|
|
121
|
+
# Извлекаем CLS токен (первый токен последовательности) как pooled representation
|
|
122
|
+
pooled_output = last_hidden_state[:, 0, :]
|
|
123
|
+
|
|
124
|
+
# 3. Вызываем CIDA-Plugin
|
|
125
|
+
plugin_outputs = self.plugin(
|
|
126
|
+
pooled=pooled_output,
|
|
127
|
+
seq_output=last_hidden_state,
|
|
128
|
+
mask=attention_mask
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
logits = plugin_outputs["p_final"]
|
|
132
|
+
|
|
133
|
+
# 4. Вычисляем лосс при наличии меток
|
|
134
|
+
loss = None
|
|
135
|
+
if labels is not None:
|
|
136
|
+
loss_fn = CIDALoss(
|
|
137
|
+
multi_label=self.plugin.config.multi_label,
|
|
138
|
+
lambda_cal=self.plugin.config.lambda_cal,
|
|
139
|
+
lambda_ac=self.plugin.config.lambda_ac,
|
|
140
|
+
min_disagreement=self.plugin.config.min_disagreement
|
|
141
|
+
).to(logits.device)
|
|
142
|
+
|
|
143
|
+
orth_loss = plugin_outputs.get("L_orth", None)
|
|
144
|
+
|
|
145
|
+
loss, _ = loss_fn(
|
|
146
|
+
p_final=logits,
|
|
147
|
+
y=labels,
|
|
148
|
+
b_all=plugin_outputs["b_all"],
|
|
149
|
+
orth_loss=orth_loss
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
# 5. Возвращаем результат
|
|
153
|
+
if not use_return_dict:
|
|
154
|
+
output = (logits,)
|
|
155
|
+
if isinstance(encoder_outputs, tuple):
|
|
156
|
+
output = output + encoder_outputs[1:]
|
|
157
|
+
return ((loss,) + output) if loss is not None else output
|
|
158
|
+
|
|
159
|
+
# Собираем SequenceClassifierOutput для HF-совместимости
|
|
160
|
+
res = SequenceClassifierOutput(
|
|
161
|
+
loss=loss,
|
|
162
|
+
logits=logits,
|
|
163
|
+
hidden_states=getattr(encoder_outputs, "hidden_states", None),
|
|
164
|
+
attentions=getattr(encoder_outputs, "attentions", None)
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
# Добавляем CIDA метаданные как атрибуты для диагностики
|
|
168
|
+
res.b_all = plugin_outputs["b_all"]
|
|
169
|
+
res.uncertainty = plugin_outputs["uncertainty"]
|
|
170
|
+
res.disagreement = plugin_outputs["disagreement"]
|
|
171
|
+
res.rounds_used = plugin_outputs["rounds_used"]
|
|
172
|
+
res.plugin_outputs = plugin_outputs
|
|
173
|
+
|
|
174
|
+
return res
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def wrap_hf_model(
|
|
178
|
+
hf_model: nn.Module,
|
|
179
|
+
cida_config: Optional[CIDAPluginConfig] = None,
|
|
180
|
+
**kwargs: Any
|
|
181
|
+
) -> HFModelWithCIDA:
|
|
182
|
+
"""
|
|
183
|
+
Быстрая обертка для Hugging Face sequence classification моделей.
|
|
184
|
+
|
|
185
|
+
Пример использования:
|
|
186
|
+
from transformers import AutoModelForSequenceClassification
|
|
187
|
+
from cida_plugin import wrap_hf_model
|
|
188
|
+
|
|
189
|
+
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
|
|
190
|
+
model = wrap_hf_model(model, max_rounds=3)
|
|
191
|
+
|
|
192
|
+
# Теперь модель готова к обучению с CIDA-Plugin в Hugging Face Trainer!
|
|
193
|
+
"""
|
|
194
|
+
return HFModelWithCIDA(hf_model, cida_config=cida_config, **kwargs)
|
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Liquid Neural ODE Dynamics для CIDA-Plugin.
|
|
3
|
+
|
|
4
|
+
Заменяет дискретный цикл deliberation loop на непрерывное
|
|
5
|
+
дифференциальное уравнение:
|
|
6
|
+
ds/dt = -s/τ(x) + F(s, r, e)
|
|
7
|
+
|
|
8
|
+
Где:
|
|
9
|
+
s - скрытое состояние агентов
|
|
10
|
+
r - контр-аргументы от других агентов
|
|
11
|
+
e - извлечённые доказательства
|
|
12
|
+
τ(x) - зависящая от входа константа времени (Liquid Time Constant)
|
|
13
|
+
F - функция перехода (основанная на AgentUpdater)
|
|
14
|
+
|
|
15
|
+
Использование `torchdiffeq` позволяет использовать адаптивные решатели
|
|
16
|
+
(например, dopri5), которые автоматически регулируют количество
|
|
17
|
+
вычислений (NFE) в зависимости от сложности входа.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import torch.nn as nn
|
|
22
|
+
import torch.nn.functional as F
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from torchdiffeq import odeint_adjoint as odeint_adj
|
|
26
|
+
from torchdiffeq import odeint
|
|
27
|
+
_HAS_ADJOINT = True
|
|
28
|
+
except ImportError:
|
|
29
|
+
from torchdiffeq import odeint
|
|
30
|
+
odeint_adj = odeint
|
|
31
|
+
_HAS_ADJOINT = False
|
|
32
|
+
|
|
33
|
+
from .agent import AgentState
|
|
34
|
+
from .deliberation import (
|
|
35
|
+
AgentEvidenceExtractor,
|
|
36
|
+
CounterargumentCommunication,
|
|
37
|
+
AgentUpdater,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LiquidTimeConstant(nn.Module):
|
|
42
|
+
"""
|
|
43
|
+
Вычисляет τ(x) — константу времени, зависящую от начального состояния.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self, d_hidden: int):
|
|
47
|
+
super().__init__()
|
|
48
|
+
self.net = nn.Sequential(
|
|
49
|
+
nn.Linear(d_hidden, d_hidden // 2),
|
|
50
|
+
nn.SiLU(),
|
|
51
|
+
nn.Linear(d_hidden // 2, 1),
|
|
52
|
+
nn.Sigmoid(),
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
def forward(self, s_init: torch.Tensor) -> torch.Tensor:
|
|
56
|
+
"""
|
|
57
|
+
s_init: (B, M, d)
|
|
58
|
+
Returns: (B, M, 1) — значения τ ∈ (0.1, 1.1)
|
|
59
|
+
"""
|
|
60
|
+
return self.net(s_init) + 0.1
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class LiquidDeliberationDynamics(nn.Module):
|
|
64
|
+
"""
|
|
65
|
+
Обёртка над компонентами deliberation (extractor, communication, updater),
|
|
66
|
+
которая вычисляет производную ds/dt для ODE решателя.
|
|
67
|
+
|
|
68
|
+
Ключевое отличие от предыдущей версии: beliefs (b) и uncertainty (u)
|
|
69
|
+
вычисляются ДИНАМИЧЕСКИ из текущего скрытого состояния s через
|
|
70
|
+
updater.g(s), а не подставляются как нулевые заглушки.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(
|
|
74
|
+
self,
|
|
75
|
+
extractor: AgentEvidenceExtractor,
|
|
76
|
+
message_formulator,
|
|
77
|
+
communication: CounterargumentCommunication,
|
|
78
|
+
updater: AgentUpdater,
|
|
79
|
+
d_hidden: int,
|
|
80
|
+
trajectory_save_every: int = 10,
|
|
81
|
+
T: float = 4.0,
|
|
82
|
+
):
|
|
83
|
+
super().__init__()
|
|
84
|
+
self.extractor = extractor
|
|
85
|
+
self.message_formulator = message_formulator
|
|
86
|
+
self.communication = communication
|
|
87
|
+
self.updater = updater
|
|
88
|
+
self.tau_net = LiquidTimeConstant(d_hidden)
|
|
89
|
+
self.trajectory_save_every = trajectory_save_every
|
|
90
|
+
self.T = T
|
|
91
|
+
# Временные точки для сохранения (корректно для adaptive solvers)
|
|
92
|
+
self._save_timepoints = {0.25, 0.50, 0.75, 1.0}
|
|
93
|
+
|
|
94
|
+
# Кэшируемый контекст (чтобы не прокидывать через ODE)
|
|
95
|
+
self.memory = None
|
|
96
|
+
self.mask = None
|
|
97
|
+
self.tau = None
|
|
98
|
+
self.nfe = 0 # Number of Function Evaluations
|
|
99
|
+
|
|
100
|
+
# Для сохранения промежуточных состояний (trajectory)
|
|
101
|
+
self.b_trajectory = []
|
|
102
|
+
self.u_trajectory = []
|
|
103
|
+
self.p_ptr_trajectory = []
|
|
104
|
+
self.m_trajectory = []
|
|
105
|
+
|
|
106
|
+
def set_context(
|
|
107
|
+
self,
|
|
108
|
+
memory: torch.Tensor,
|
|
109
|
+
mask: torch.Tensor,
|
|
110
|
+
s_init: torch.Tensor,
|
|
111
|
+
):
|
|
112
|
+
"""
|
|
113
|
+
Устанавливает контекст перед запуском ODE.
|
|
114
|
+
memory: (B, L, d) — память документов
|
|
115
|
+
mask: (B, L) — маска внимания
|
|
116
|
+
s_init: (B, M, d) — начальное состояние агентов
|
|
117
|
+
"""
|
|
118
|
+
self.memory = memory
|
|
119
|
+
self.mask = mask
|
|
120
|
+
self.nfe = 0
|
|
121
|
+
# FIX v4: del [:] вместо clear() для явного освобождения памяти
|
|
122
|
+
del self.b_trajectory[:]
|
|
123
|
+
del self.u_trajectory[:]
|
|
124
|
+
del self.p_ptr_trajectory[:]
|
|
125
|
+
del self.m_trajectory[:]
|
|
126
|
+
|
|
127
|
+
def clear_context(self):
|
|
128
|
+
self.memory = None
|
|
129
|
+
self.mask = None
|
|
130
|
+
|
|
131
|
+
def _compute_beliefs(self, s: torch.Tensor):
|
|
132
|
+
"""
|
|
133
|
+
Динамически вычисляет текущие beliefs и uncertainty из состояния s.
|
|
134
|
+
|
|
135
|
+
Вместо передачи dummy_b=0, dummy_u=1 (что лишало коммуникацию
|
|
136
|
+
информации о текущих мнениях агентов), мы проецируем s через
|
|
137
|
+
ту же голову, что использует AgentUpdater.
|
|
138
|
+
|
|
139
|
+
Returns: (current_b, current_u)
|
|
140
|
+
"""
|
|
141
|
+
g_out = self.updater.g(s)
|
|
142
|
+
|
|
143
|
+
if self.updater.multi_label:
|
|
144
|
+
current_b = torch.sigmoid(g_out)
|
|
145
|
+
current_u = (4.0 * current_b * (1.0 - current_b)).mean(
|
|
146
|
+
dim=-1, keepdim=True
|
|
147
|
+
)
|
|
148
|
+
else:
|
|
149
|
+
alpha = F.softplus(g_out) + 1.0
|
|
150
|
+
alpha_sum = alpha.sum(dim=-1, keepdim=True)
|
|
151
|
+
current_b = alpha / alpha_sum
|
|
152
|
+
current_u = float(g_out.size(-1)) / alpha_sum
|
|
153
|
+
|
|
154
|
+
return current_b, current_u
|
|
155
|
+
|
|
156
|
+
def forward(self, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
|
157
|
+
"""
|
|
158
|
+
Вычисляет ds/dt.
|
|
159
|
+
t: (1,) — текущее время ODE
|
|
160
|
+
s: (B, M, d) — текущее состояние агентов
|
|
161
|
+
"""
|
|
162
|
+
self.nfe += 1
|
|
163
|
+
|
|
164
|
+
B, M = s.shape[0], s.shape[1]
|
|
165
|
+
num_classes = self.updater.g[-1].out_features
|
|
166
|
+
|
|
167
|
+
# Динамическое вычисление текущих beliefs и uncertainty
|
|
168
|
+
current_b, current_u = self._compute_beliefs(s)
|
|
169
|
+
|
|
170
|
+
# Evidence extraction
|
|
171
|
+
p_ptr_t, e_t = self.extractor(s, self.memory, self.mask)
|
|
172
|
+
|
|
173
|
+
# Communication с актуальными beliefs (не с нулями!)
|
|
174
|
+
m_t = self.message_formulator(s, e_t, current_b, current_u)
|
|
175
|
+
r_t = self.communication(s, m_t, current_b, e_t)
|
|
176
|
+
|
|
177
|
+
# AgentUpdater step
|
|
178
|
+
dummy_state = AgentState(
|
|
179
|
+
s=s,
|
|
180
|
+
b=current_b,
|
|
181
|
+
u=current_u,
|
|
182
|
+
p=torch.zeros(B, M, 1, device=s.device),
|
|
183
|
+
e=e_t,
|
|
184
|
+
alpha=torch.ones(B, M, num_classes, device=s.device),
|
|
185
|
+
)
|
|
186
|
+
next_state = self.updater(dummy_state, r_t, e_t)
|
|
187
|
+
|
|
188
|
+
# Сохраняем промежуточные beliefs для лоссов и консенсуса
|
|
189
|
+
# FIX v4: 1) detach() — без него ODE граф держится в памяти (сотни МБ)
|
|
190
|
+
# 2) temporal checkpoints вместо NFE-based (корректно для adaptive solvers)
|
|
191
|
+
if torch.is_grad_enabled() or not self.training:
|
|
192
|
+
t_normalized = t.item() / self.T if self.T > 0 else 0
|
|
193
|
+
should_save = any(abs(t_normalized - sp) < 0.02 for sp in self._save_timepoints)
|
|
194
|
+
if should_save or self.nfe % self.trajectory_save_every == 0:
|
|
195
|
+
self.b_trajectory.append(next_state.b.detach())
|
|
196
|
+
self.u_trajectory.append(next_state.u.detach())
|
|
197
|
+
self.p_ptr_trajectory.append(p_ptr_t.detach())
|
|
198
|
+
self.m_trajectory.append(s.detach())
|
|
199
|
+
|
|
200
|
+
# TRUE Liquid ODE v4: τ(s,t) вычисляется на КАЖДОМ шаге из текущего s
|
|
201
|
+
# (раньше τ фиксировался из s_init в set_context — это не настоящий Liquid NN)
|
|
202
|
+
tau = self.tau_net(s) # (B, M, 1) — state-dependent time constant
|
|
203
|
+
f_s = next_state.s - s
|
|
204
|
+
ds_dt = -s / tau + f_s
|
|
205
|
+
|
|
206
|
+
return ds_dt
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class LiquidDeliberationSolver(nn.Module):
|
|
210
|
+
"""
|
|
211
|
+
Запускает ODE решатель для deliberation.
|
|
212
|
+
|
|
213
|
+
При use_adjoint=True использует adjoint method для экономии VRAM:
|
|
214
|
+
градиенты считаются без хранения промежуточных состояний.
|
|
215
|
+
"""
|
|
216
|
+
|
|
217
|
+
def __init__(
|
|
218
|
+
self,
|
|
219
|
+
extractor: AgentEvidenceExtractor,
|
|
220
|
+
message_formulator,
|
|
221
|
+
communication: CounterargumentCommunication,
|
|
222
|
+
updater: AgentUpdater,
|
|
223
|
+
d_hidden: int,
|
|
224
|
+
solver: str = "euler",
|
|
225
|
+
atol: float = 1e-3,
|
|
226
|
+
rtol: float = 1e-3,
|
|
227
|
+
T: float = 4.0,
|
|
228
|
+
trajectory_save_every: int = 10,
|
|
229
|
+
use_adjoint: bool = False,
|
|
230
|
+
):
|
|
231
|
+
super().__init__()
|
|
232
|
+
self.dynamics = LiquidDeliberationDynamics(
|
|
233
|
+
extractor=extractor,
|
|
234
|
+
message_formulator=message_formulator,
|
|
235
|
+
communication=communication,
|
|
236
|
+
updater=updater,
|
|
237
|
+
d_hidden=d_hidden,
|
|
238
|
+
trajectory_save_every=trajectory_save_every,
|
|
239
|
+
T=T,
|
|
240
|
+
)
|
|
241
|
+
self.solver = solver
|
|
242
|
+
self.atol = atol
|
|
243
|
+
self.rtol = rtol
|
|
244
|
+
self.T = T
|
|
245
|
+
self.use_adjoint = use_adjoint and _HAS_ADJOINT
|
|
246
|
+
self._integrate = odeint_adj if self.use_adjoint else odeint
|
|
247
|
+
|
|
248
|
+
def forward(
|
|
249
|
+
self,
|
|
250
|
+
s_init: torch.Tensor,
|
|
251
|
+
memory: torch.Tensor,
|
|
252
|
+
mask: torch.Tensor,
|
|
253
|
+
) -> tuple:
|
|
254
|
+
"""
|
|
255
|
+
s_init: (B, M, d)
|
|
256
|
+
memory: (B, L, d)
|
|
257
|
+
mask: (B, L)
|
|
258
|
+
|
|
259
|
+
Returns:
|
|
260
|
+
s_final (B, M, d)
|
|
261
|
+
trajectory (dict)
|
|
262
|
+
"""
|
|
263
|
+
device = s_init.device
|
|
264
|
+
self.dynamics.set_context(memory, mask, s_init)
|
|
265
|
+
|
|
266
|
+
t = torch.tensor([0.0, self.T], device=device)
|
|
267
|
+
|
|
268
|
+
states = self._integrate(
|
|
269
|
+
self.dynamics,
|
|
270
|
+
s_init,
|
|
271
|
+
t,
|
|
272
|
+
method=self.solver,
|
|
273
|
+
atol=self.atol,
|
|
274
|
+
rtol=self.rtol,
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
s_final = states[-1]
|
|
278
|
+
nfe = self.dynamics.nfe
|
|
279
|
+
|
|
280
|
+
trajectory = {
|
|
281
|
+
"b_all": self.dynamics.b_trajectory,
|
|
282
|
+
"u_all": self.dynamics.u_trajectory,
|
|
283
|
+
"p_ptr_all": self.dynamics.p_ptr_trajectory,
|
|
284
|
+
"m_all": self.dynamics.m_trajectory,
|
|
285
|
+
"nfe": nfe,
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
self.dynamics.clear_context()
|
|
289
|
+
|
|
290
|
+
return s_final, trajectory
|
|
291
|
+
|
cida_plugin/losses.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""
|
|
2
|
+
losses.py — Упрощённая система потерь CIDA (3 компонента вместо 11)
|
|
3
|
+
|
|
4
|
+
Было (OmegaLossSystem):
|
|
5
|
+
task + calibration + debate + progress + budget + dominance
|
|
6
|
+
+ role + role_spec + orthogonality + ponder_cost + ttt_reg
|
|
7
|
+
= 11 компонентов, 11 гиперпараметров
|
|
8
|
+
|
|
9
|
+
Стало (CIDALoss):
|
|
10
|
+
task + calibration + anti_collapse
|
|
11
|
+
= 3 компонента, 2 гиперпараметра
|
|
12
|
+
|
|
13
|
+
Почему можно удалить остальные 8:
|
|
14
|
+
|
|
15
|
+
debate_loss → удалён: несогласие теперь структурное (role priors)
|
|
16
|
+
progress_loss → удалён: избыточен с task_loss при правильных priors
|
|
17
|
+
budget_loss → удалён: ACT не нужен если число раундов мало (≤3)
|
|
18
|
+
dominance_loss → удалён: симптом проблемы с PoE, которую мы исправили
|
|
19
|
+
role_loss → удалён: нет RoleEmbeddings → нет и orth-лосса для них
|
|
20
|
+
role_spec_loss → удалён: роли теперь structural, не statistical
|
|
21
|
+
orth_loss → заменён одной строкой в compute_role_orthogonality_loss()
|
|
22
|
+
ttt_reg → удалён: TTT убран из базовой архитектуры
|
|
23
|
+
|
|
24
|
+
anti_collapse_loss:
|
|
25
|
+
Единственный "новый" компонент, но он ПРОЩЕ и ПРАВИЛЬНЕЕ debate_loss.
|
|
26
|
+
Вместо "попади в целевой уровень несогласия на каждом раунде" →
|
|
27
|
+
"просто не позволяй агентам стать идентичными".
|
|
28
|
+
Это МИНИМАЛЬНОЕ требование, а не жёсткое целевое расписание.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
import torch
|
|
32
|
+
import torch.nn as nn
|
|
33
|
+
import torch.nn.functional as F
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class CIDALoss(nn.Module):
|
|
37
|
+
"""
|
|
38
|
+
Упрощённая система потерь для CIDA-Plugin.
|
|
39
|
+
|
|
40
|
+
Параметры
|
|
41
|
+
---------
|
|
42
|
+
lambda_cal : float
|
|
43
|
+
Вес calibration loss (Brier score). Рекомендуется 0.3–0.5.
|
|
44
|
+
lambda_ac : float
|
|
45
|
+
Вес anti-collapse loss. Рекомендуется 0.1–0.3.
|
|
46
|
+
Меньше → агенты могут сближаться (риск коллапса).
|
|
47
|
+
Больше → агенты разведены сильнее (риск потери точности).
|
|
48
|
+
min_disagreement : float
|
|
49
|
+
Минимальный порог несогласия. Рекомендуется 0.05–0.15.
|
|
50
|
+
При role priors реальное несогласие ≥ 0.21, поэтому loss
|
|
51
|
+
почти всегда = 0 (действует как safety net).
|
|
52
|
+
multi_label : bool
|
|
53
|
+
Если True — используется BCE вместо CE.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
def __init__(
|
|
57
|
+
self,
|
|
58
|
+
lambda_cal: float = 0.4,
|
|
59
|
+
lambda_ac: float = 0.2,
|
|
60
|
+
min_disagreement: float = 0.08,
|
|
61
|
+
multi_label: bool = False,
|
|
62
|
+
):
|
|
63
|
+
super().__init__()
|
|
64
|
+
self.lambda_cal = lambda_cal
|
|
65
|
+
self.lambda_ac = lambda_ac
|
|
66
|
+
self.min_disagreement = min_disagreement
|
|
67
|
+
self.multi_label = multi_label
|
|
68
|
+
|
|
69
|
+
# ── Компонент 1: Task Loss ────────────────────────────────────────────────
|
|
70
|
+
|
|
71
|
+
def task_loss(self, p_final: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
72
|
+
"""
|
|
73
|
+
Основная задача: классификация.
|
|
74
|
+
"""
|
|
75
|
+
eps = 1e-7
|
|
76
|
+
if self.multi_label:
|
|
77
|
+
return F.binary_cross_entropy(p_final.clamp(eps, 1.0 - eps), y)
|
|
78
|
+
# FIX v4: log-softmax стабильнее чем clamp().log() — согласовано с consensus.py
|
|
79
|
+
log_probs = F.log_softmax(p_final.log().clamp(-10, 10), dim=-1)
|
|
80
|
+
return F.nll_loss(log_probs, y)
|
|
81
|
+
|
|
82
|
+
# ── Компонент 2: Calibration Loss ─────────────────────────────────────────
|
|
83
|
+
|
|
84
|
+
def calibration_loss(self, p_final: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
85
|
+
"""
|
|
86
|
+
Brier score: MSE между вероятностями и метками.
|
|
87
|
+
"""
|
|
88
|
+
if self.multi_label:
|
|
89
|
+
probs = p_final
|
|
90
|
+
return ((probs - y) ** 2).mean()
|
|
91
|
+
|
|
92
|
+
# Single-label: one-hot targets
|
|
93
|
+
K = p_final.size(-1)
|
|
94
|
+
y_oh = F.one_hot(y, num_classes=K).float()
|
|
95
|
+
return ((p_final - y_oh) ** 2).mean()
|
|
96
|
+
|
|
97
|
+
# ── Компонент 3: Anti-Collapse Loss ──────────────────────────────────────
|
|
98
|
+
|
|
99
|
+
def anti_collapse_loss(self, b_final: torch.Tensor) -> torch.Tensor:
|
|
100
|
+
"""
|
|
101
|
+
Предотвращает коллапс агентов к идентичным убеждениям.
|
|
102
|
+
|
|
103
|
+
НЕ является аналогом debate_loss. Ключевые отличия:
|
|
104
|
+
|
|
105
|
+
debate_loss:
|
|
106
|
+
- Целевое расписание d_t на каждый раунд
|
|
107
|
+
- Штрафует если несогласие != цели
|
|
108
|
+
- Создаёт противоречие с consensus
|
|
109
|
+
|
|
110
|
+
anti_collapse_loss:
|
|
111
|
+
- Только минимальная граница min_disagreement
|
|
112
|
+
- Штрафует только при несогласии НИЖЕ порога
|
|
113
|
+
- Нулевой градиент при нормальном обучении (safety net)
|
|
114
|
+
|
|
115
|
+
b_final: (B, M, K) — финальные убеждения агентов
|
|
116
|
+
Returns: scalar (обычно ≈ 0 если role priors работают)
|
|
117
|
+
|
|
118
|
+
Математика:
|
|
119
|
+
d_ij = ||b_i - b_j||_1 — несогласие между агентами i и j
|
|
120
|
+
L_ac = mean(relu(δ - d_ij)) где δ = min_disagreement
|
|
121
|
+
Градиент ненулевой только когда d_ij < δ
|
|
122
|
+
"""
|
|
123
|
+
b_i = b_final.unsqueeze(2) # (B, M, 1, K)
|
|
124
|
+
b_j = b_final.unsqueeze(1) # (B, 1, M, K)
|
|
125
|
+
pairwise_dist = (b_i - b_j).abs().sum(dim=-1) # (B, M, M)
|
|
126
|
+
|
|
127
|
+
# Только нижняя треугольная часть (без диагонали)
|
|
128
|
+
M = b_final.size(1)
|
|
129
|
+
mask = torch.tril(torch.ones(M, M, device=b_final.device), diagonal=-1).bool()
|
|
130
|
+
off_diag = pairwise_dist[:, mask] # (B, n_pairs)
|
|
131
|
+
|
|
132
|
+
return F.relu(self.min_disagreement - off_diag).mean()
|
|
133
|
+
|
|
134
|
+
# ── Основной forward ──────────────────────────────────────────────────────
|
|
135
|
+
|
|
136
|
+
def forward(
|
|
137
|
+
self,
|
|
138
|
+
p_final: torch.Tensor,
|
|
139
|
+
y: torch.Tensor,
|
|
140
|
+
b_all: list,
|
|
141
|
+
orth_loss: torch.Tensor = None,
|
|
142
|
+
) -> tuple[torch.Tensor, dict]:
|
|
143
|
+
"""
|
|
144
|
+
p_final : (B, K) — финальное предсказание
|
|
145
|
+
y : (B,) или (B,K) — метки
|
|
146
|
+
b_all : list[(B,M,K)] — история убеждений по раундам
|
|
147
|
+
orth_loss: scalar или None — orth loss из representations (опционально)
|
|
148
|
+
|
|
149
|
+
Returns: (total_loss, components_dict)
|
|
150
|
+
"""
|
|
151
|
+
l_t = self.task_loss(p_final, y)
|
|
152
|
+
l_c = self.calibration_loss(p_final, y)
|
|
153
|
+
|
|
154
|
+
# Anti-collapse на финальных убеждениях
|
|
155
|
+
b_final = b_all[-1] if b_all else None
|
|
156
|
+
l_ac = self.anti_collapse_loss(b_final) if b_final is not None else torch.tensor(0.0)
|
|
157
|
+
|
|
158
|
+
total = l_t + self.lambda_cal * l_c + self.lambda_ac * l_ac
|
|
159
|
+
|
|
160
|
+
# Orth loss опционально (из representations, не из весов)
|
|
161
|
+
l_orth = torch.tensor(0.0)
|
|
162
|
+
if orth_loss is not None:
|
|
163
|
+
l_orth = orth_loss
|
|
164
|
+
total = total + 0.05 * l_orth # малый вес, это regularizer
|
|
165
|
+
|
|
166
|
+
components = {
|
|
167
|
+
"task": l_t.item(),
|
|
168
|
+
"calibration": l_c.item(),
|
|
169
|
+
"anti_collapse": l_ac.item(),
|
|
170
|
+
"orth": l_orth.item() if isinstance(l_orth, torch.Tensor) else l_orth,
|
|
171
|
+
"total": total.item(),
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
return total, components
|