cida-plugin 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cida_plugin/__init__.py +62 -0
- cida_plugin/agent.py +198 -0
- cida_plugin/config.py +228 -0
- cida_plugin/consensus.py +178 -0
- cida_plugin/core.py +451 -0
- cida_plugin/deliberation.py +376 -0
- cida_plugin/diagnostics.py +124 -0
- cida_plugin/hf.py +194 -0
- cida_plugin/liquid_dynamics.py +291 -0
- cida_plugin/losses.py +174 -0
- cida_plugin/translator.py +74 -0
- cida_plugin/ttt.py +348 -0
- cida_plugin/vision_backbone.py +43 -0
- cida_plugin-1.0.0.dist-info/METADATA +167 -0
- cida_plugin-1.0.0.dist-info/RECORD +18 -0
- cida_plugin-1.0.0.dist-info/WHEEL +5 -0
- cida_plugin-1.0.0.dist-info/licenses/LICENSE +201 -0
- cida_plugin-1.0.0.dist-info/top_level.txt +1 -0
cida_plugin/core.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CIDAPlugin — Universal Multi-Agent Deliberation Layer (v3 Simplified).
|
|
3
|
+
|
|
4
|
+
Architecture:
|
|
5
|
+
1. Input Projection: d_input → d_hidden
|
|
6
|
+
2. Agent Initialization with Role Priors (structural disagreement)
|
|
7
|
+
3. Deliberation Loop (evidence → messages → communication → update)
|
|
8
|
+
4. Weighted Mean Consensus (correct for correlated agents)
|
|
9
|
+
5. Uncertainty = f(disagreement) — observed, not learned
|
|
10
|
+
"""
|
|
11
|
+
import os
|
|
12
|
+
import warnings
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
import torch.nn.functional as F
|
|
16
|
+
|
|
17
|
+
from .config import CIDAPluginConfig
|
|
18
|
+
from .agent import AgentState, apply_role_prior, apply_role_priors_batched, compute_role_orthogonality_loss
|
|
19
|
+
from .deliberation import (
|
|
20
|
+
AgentEvidenceExtractor,
|
|
21
|
+
MessageFormulator,
|
|
22
|
+
CounterargumentCommunication,
|
|
23
|
+
AgentUpdater,
|
|
24
|
+
PerspectiveProjector,
|
|
25
|
+
)
|
|
26
|
+
from .consensus import ConsensusAggregator, HaltingPredictor
|
|
27
|
+
from .ttt import TestTimeTrainer
|
|
28
|
+
from .liquid_dynamics import LiquidDeliberationSolver
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CIDAPlugin(nn.Module):
    """
    Universal CIDA-Plugin.

    Accepts pooled_output from ANY encoder via input_proj,
    runs multi-agent deliberation with structural role priors,
    returns calibrated class distribution with uncertainty.

    Example:
        cfg = CIDAPluginConfig(d_input=768, num_classes=2)
        plugin = CIDAPlugin(cfg)
        out = plugin(pooled)          # (B, 768) → dict
        p_final = out["p_final"]      # (B, 2)
    """

    def __init__(self, config: "CIDAPluginConfig"):
        """Build every sub-module of the plugin from ``config``.

        Optional components (perspective projector, test-time trainer,
        liquid ODE solver) are only instantiated when the corresponding
        config flag is set.
        """
        super().__init__()
        self.config = config
        cfg = config

        # ── Input projection: d_input → d_hidden ──────────────────────────────
        self.input_proj = nn.Sequential(
            nn.Linear(cfg.d_input, cfg.d_hidden),
            nn.LayerNorm(cfg.d_hidden),
            nn.SiLU(),
        )
        if cfg.freeze_input_proj:
            for p in self.input_proj.parameters():
                p.requires_grad = False

        # Projection for sequence output (evidence pointers); a no-op when the
        # encoder width already equals d_hidden.
        self.seq_proj = (
            nn.Linear(cfg.d_input, cfg.d_hidden)
            if cfg.d_input != cfg.d_hidden
            else nn.Identity()
        )

        # ── Perspective Projector (optional per-agent projections) ─────────────
        if cfg.use_perspective_projector:
            self.perspective_projector = PerspectiveProjector(
                cfg.num_agents, cfg.d_input, cfg.d_hidden
            )
        else:
            self.perspective_projector = None

        # ── Deliberation components ───────────────────────────────────────────
        self.evidence_extractor = AgentEvidenceExtractor(cfg.d_hidden)
        self.message_formulator = MessageFormulator(
            cfg.d_hidden, cfg.num_classes, cfg.d_message,
            anonymous=cfg.anonymous_messages,
        )
        self.communication = CounterargumentCommunication(
            cfg.d_hidden, cfg.d_message
        )
        self.updater = AgentUpdater(
            cfg.d_hidden,
            cfg.num_classes,
            num_heads=cfg.num_attn_heads,
            multi_label=cfg.multi_label,
        )

        # ── Consensus (Weighted Mean for correlated agents) ───────────────────
        self.consensus = ConsensusAggregator(
            num_agents=cfg.num_agents,
            multi_label=cfg.multi_label,
            use_dynamic_trust=cfg.use_dynamic_trust,
            trust_ema_gamma=cfg.trust_ema_gamma,
        )
        self.halter = HaltingPredictor()

        # ── Gated communication warm-up ───────────────────────────────────────
        self.comm_gate = nn.Sequential(
            nn.Linear(cfg.d_hidden, 1),
            nn.Sigmoid(),
        )
        # Negative bias: the gate sits below 0.5 for a zero pre-activation.
        nn.init.constant_(self.comm_gate[0].bias, -0.5)

        # ── Test-Time Training (optional) ─────────────────────────────────────
        if cfg.use_ttt:
            self.ttt_trainer = TestTimeTrainer(
                d_hidden=cfg.d_hidden,
                ttt_steps=cfg.ttt_steps,
                ttt_lr=cfg.ttt_lr,
                mask_ratio=cfg.ttt_mask_ratio,
            )

        # ── Liquid Neural ODE Dynamics (optional) ─────────────────────────────
        if cfg.use_liquid_dynamics:
            # The solver is handed the very same sub-modules used by the
            # discrete loop rather than fresh copies.
            self.liquid_solver = LiquidDeliberationSolver(
                extractor=self.evidence_extractor,
                message_formulator=self.message_formulator,
                communication=self.communication,
                updater=self.updater,
                d_hidden=cfg.d_hidden,
                solver=cfg.liquid_solver,
                atol=cfg.liquid_atol,
                rtol=cfg.liquid_rtol,
                T=float(cfg.max_rounds),
                trajectory_save_every=cfg.trajectory_save_every,
            )

    # ── Hugging Face Serialization ────────────────────────────────────────────

    def save_pretrained(self, save_directory: str):
        """Save config and weights for later loading or HF Hub publishing.

        Creates ``save_directory`` if needed, delegates config serialization
        to ``self.config.save_pretrained`` and writes the module state dict
        to ``pytorch_model.bin`` inside that directory.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        weights_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), weights_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        """Load a pretrained plugin (config + weights).

        ``pretrained_model_name_or_path`` is either a local directory that
        contains ``pytorch_model.bin`` or, when it is not an existing
        directory, a Hugging Face Hub repo id from which that file is
        downloaded.
        """
        config = CIDAPluginConfig.from_pretrained(
            pretrained_model_name_or_path
        )
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            weights_path = os.path.join(
                pretrained_model_name_or_path, "pytorch_model.bin"
            )
        else:
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="pytorch_model.bin",
            )

        # SECURITY: the checkpoint may come from an arbitrary Hub repo, and a
        # full pickle load can execute code. A state dict only needs tensors,
        # so ask torch for the restricted unpickler.
        try:
            state_dict = torch.load(
                weights_path, map_location="cpu", weights_only=True
            )
        except TypeError:
            # Older torch releases do not know the ``weights_only`` keyword.
            warnings.warn(
                "Installed torch does not support weights_only=True; "
                "falling back to full unpickling. Only load checkpoints "
                "from trusted sources.",
                UserWarning,
            )
            state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    # ── Forward ───────────────────────────────────────────────────────────────

    def forward(
        self,
        pooled: torch.Tensor,
        seq_output: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> dict:
        """
        Parameters
        ----------
        pooled     : (B, d_input)           pooled representation from encoder
        seq_output : (B, seq_len, d_input)  optional — for evidence pointers
        mask       : (B, seq_len)           padding mask

        Returns dict:
            p_final      : (B, K)  — final calibrated distribution
            b_all        : list    — belief history per round
            u_all        : list    — uncertainty history
            m_all        : list    — message history
            p_ptr_all    : list    — evidence pointer history
            uncertainty  : (B,)    — final uncertainty
            disagreement : (B,)    — final agent disagreement
            u_epi        : ``u_epi`` attribute of the final agent state
            u_alea       : ``u_alea`` attribute of the final agent state
            L_orth       : scalar  — orthogonality loss (representations)
            state_final  : AgentState — final agent state
            rounds_used  : int     — actual number of rounds used
                                     (solver NFE under liquid dynamics)
            ponder_cost  : (B,) or None — NFE for liquid dynamics
            ttt_info     : dict or None — TTT diagnostics
        """
        cfg = self.config
        B = pooled.size(0)
        M = cfg.num_agents
        device = pooled.device

        # ── Project input ─────────────────────────────────────────────────────
        h_cls = self.input_proj(pooled)          # (B, d_hidden)

        H = None
        if seq_output is not None:
            H = self.seq_proj(seq_output)        # (B, seq_len, d_hidden)

        # ── Initialize agents ─────────────────────────────────────────────────
        # Structural diversity comes from the role priors applied to beliefs.
        if self.perspective_projector is not None:
            # Per-agent projections for diverse initial views
            s_t = self.perspective_projector(pooled)   # (B, M, d_hidden)
        else:
            # All agents start from the same representation
            s_t = h_cls.unsqueeze(1).expand(B, M, cfg.d_hidden).clone()

        seq_len = H.size(1) if H is not None else 1
        e_init = torch.zeros(B, M, cfg.d_hidden, device=device)
        r_init = torch.zeros_like(e_init)

        if cfg.multi_label:
            b_init = torch.full((B, M, cfg.num_classes), 0.5, device=device)
        else:
            b_init = torch.zeros(B, M, cfg.num_classes, device=device)

        state_t = AgentState(
            s=s_t,
            b=b_init,
            u=torch.ones(B, M, 1, device=device),
            p=torch.zeros(B, M, seq_len, device=device),
            e=e_init,
            alpha=torch.ones(B, M, cfg.num_classes, device=device),
        )

        # Initial update to get first beliefs from data
        state_t = self.updater(state_t, r_init, e_init)

        # Apply structural role priors to initial beliefs (Batched)
        state_t.b = apply_role_priors_batched(
            state_t.b,
            M,
            multi_label=cfg.multi_label,
            positive_class_idx=cfg.positive_class_idx,
        )

        # ── TTT: adapt before deliberation ────────────────────────────────────
        ttt_info = None
        _ttt_saved_updater = None
        _ttt_saved_head = None
        if cfg.use_ttt:
            ttt_info = self.ttt_trainer.adapt_and_get_info(
                self.updater, state_t.s
            )
            _ttt_saved_updater = ttt_info.pop("saved_updater_state")
            _ttt_saved_head = ttt_info.pop("saved_head_state")

        try:
            out = self._deliberate(state_t, H, mask, pooled, seq_len)
        finally:
            # ── TTT: restore weights ──────────────────────────────────────────
            # In ``finally`` so a failure during deliberation cannot leave the
            # updater stuck with the per-batch adapted weights.
            if cfg.use_ttt and _ttt_saved_updater is not None:
                self.ttt_trainer.restore_weights(
                    self.updater, _ttt_saved_updater, _ttt_saved_head
                )

        out["ttt_info"] = ttt_info
        return out

    def _deliberate(self, state_t, H, mask, pooled, seq_len) -> dict:
        """Run deliberation (liquid or discrete) and aggregate the consensus.

        Returns the ``forward`` output dict without the ``ttt_info`` entry.
        """
        cfg = self.config

        # ── Deliberation ──────────────────────────────────────────────────────
        if cfg.use_liquid_dynamics and H is not None:
            result = self._deliberate_liquid(state_t, H, mask, pooled)
        else:
            if cfg.use_liquid_dynamics and H is None:
                # ── Silent fallback warning ───────────────────────────────────
                warnings.warn(
                    "use_liquid_dynamics=True but seq_output is None. "
                    "Falling back to discrete deliberation mode. "
                    "Provide seq_output for continuous ODE dynamics.",
                    UserWarning,
                )
            result = self._deliberate_discrete(
                state_t, H, mask, pooled, seq_len
            )
        (state_t, b_all, u_all, m_all, p_ptr_all,
         rounds_used, ponder_cost) = result

        # ── Final Consensus ───────────────────────────────────────────────────
        p_final, uncertainty, disagreement = self.consensus(
            state_t.b, state_t.u
        )

        # ── Orthogonality loss (on representations, not weights) ──────────────
        L_orth = compute_role_orthogonality_loss(state_t.s)

        return {
            "p_final": p_final,
            "b_all": b_all,
            "u_all": u_all,
            "m_all": m_all,
            "p_ptr_all": p_ptr_all,
            "uncertainty": uncertainty,
            "disagreement": disagreement,
            "u_epi": state_t.u_epi,
            "u_alea": state_t.u_alea,
            "L_orth": L_orth,
            "state_final": state_t,
            "rounds_used": rounds_used,
            "ponder_cost": ponder_cost,
        }

    def _deliberate_liquid(self, state_t, H, mask, pooled):
        """Continuous deliberation: integrate the ODE, then refresh beliefs.

        Returns ``(state, b_all, u_all, m_all, p_ptr_all, rounds_used,
        ponder_cost)``.
        """
        cfg = self.config
        B = pooled.size(0)
        M = cfg.num_agents
        device = pooled.device

        # ── Continuous dynamics via ODE ───────────────────────────────────────
        s_final, traj = self.liquid_solver(state_t.s, H, mask)
        b_all = traj["b_all"]
        u_all = traj["u_all"]
        p_ptr_all = traj["p_ptr_all"]
        m_all = traj["m_all"]

        if len(b_all) > 0:
            state_t.s = s_final
            # Final update to get beliefs from final state
            e_final = s_final
            if H is not None and not cfg.abl_no_pointers:
                _, e_final = self.evidence_extractor(s_final, H, mask)

            r_final = torch.zeros_like(s_final)
            if not cfg.abl_no_messages and not cfg.abl_no_communication:
                # Class count is read from the updater's last ``g`` layer.
                num_classes = self.updater.g[-1].out_features
                dummy_b = torch.zeros(B, M, num_classes, device=device)
                dummy_u = torch.ones(B, M, 1, device=device)
                m_final = self.message_formulator(
                    s_final, e_final, dummy_b, dummy_u
                )
                r_final = self.communication(
                    s_final, m_final, dummy_b, e_final
                )

            dummy_state = AgentState(
                s=s_final,
                b=b_all[-1] if b_all else state_t.b,
                u=u_all[-1] if u_all else state_t.u,
                p=torch.zeros(B, M, 1, device=device),
                e=e_final,
                alpha=torch.ones(B, M, cfg.num_classes, device=device),
            )
            state_t = self.updater(dummy_state, r_final, e_final)

            # Apply role priors to final beliefs (Batched)
            state_t.b = apply_role_priors_batched(
                state_t.b,
                M,
                multi_label=cfg.multi_label,
                positive_class_idx=cfg.positive_class_idx,
            )
            # Keep the recorded history consistent with the refreshed beliefs.
            b_all[-1] = state_t.b

        ponder_cost = torch.full((B,), float(traj["nfe"]), device=device)
        rounds_used = traj["nfe"]
        return state_t, b_all, u_all, m_all, p_ptr_all, rounds_used, ponder_cost

    def _deliberate_discrete(self, state_t, H, mask, pooled, seq_len):
        """Discrete deliberation loop with optional rollback and early stop.

        Returns ``(state, b_all, u_all, m_all, p_ptr_all, rounds_used,
        ponder_cost)``; ``ponder_cost`` is always ``None`` in this mode.
        """
        from torch.utils.checkpoint import checkpoint

        cfg = self.config
        B = pooled.size(0)
        M = cfg.num_agents
        device = pooled.device

        b_all, u_all, m_all, p_ptr_all = [], [], [], []
        rounds_used = 0

        def deliberation_round(s_in, b_in, u_in, p_in, e_in):
            """One round: evidence → message → communication → update."""
            state = AgentState(
                s=s_in, b=b_in, u=u_in, p=p_in, e=e_in,
                alpha=torch.ones_like(b_in),
            )

            # Evidence Pointers
            if cfg.abl_no_pointers or H is None:
                # Uniform pointers; the agent state stands in for evidence.
                p_new = (torch.ones(B, M, seq_len, device=device) / seq_len)
                e_new = state.s
            else:
                p_new, e_new = self.evidence_extractor(state.s, H, mask)
            state.p = p_new
            state.e = e_new

            # Message Formulation
            if cfg.abl_no_messages:
                m_new = torch.zeros(B, M, cfg.d_message, device=device)
            else:
                m_new = self.message_formulator(
                    state.s, state.e, state.b, state.u
                )

            # Communication
            if cfg.abl_no_communication:
                r_new = torch.zeros_like(state.s)
            else:
                r_new = self.communication(
                    state.s, m_new, state.b, state.e,
                    sparse_k=cfg.sparse_comm_k,
                )
                gate = self.comm_gate(state.s)
                r_new = gate * r_new
                if self.training and cfg.comm_dropout > 0:
                    r_new = F.dropout(
                        r_new, p=cfg.comm_dropout, training=True
                    )

            # Agent Update
            state = self.updater(state, r_new, e_new)

            # Apply role priors to updated beliefs
            state.b = apply_role_priors_batched(
                state.b, M,
                multi_label=cfg.multi_label,
                positive_class_idx=cfg.positive_class_idx,
            )

            return state.s, state.b, state.u, state.p, state.e, m_new

        prev_disagreement = float('inf')

        for t in range(cfg.max_rounds):
            rounds_used = t + 1

            # Activation checkpointing only while training on CUDA with a
            # state that actually requires gradients.
            if self.training and state_t.s.requires_grad and pooled.is_cuda:
                s_t, b_t, u_t, p_t, e_t, m_t = checkpoint(
                    deliberation_round,
                    state_t.s, state_t.b, state_t.u, state_t.p, state_t.e,
                    use_reentrant=False,
                )
            else:
                s_t, b_t, u_t, p_t, e_t, m_t = deliberation_round(
                    state_t.s, state_t.b, state_t.u, state_t.p, state_t.e
                )

            # ── Kairos Dynamic Deliberation: Time-Travel (Rollback) ────────────
            if getattr(cfg, 'enable_rollback', False) and not self.training:
                with torch.no_grad():
                    _, _, current_d = self.consensus(b_t)
                    current_d_mean = current_d.mean().item()

                if current_d_mean > prev_disagreement:
                    # Disagreement increased -> discard this round's result,
                    # perturb the previous state and retry.
                    noise_std = getattr(cfg, 'rollback_noise_std', 0.05)
                    noise = torch.randn_like(state_t.s) * noise_std
                    state_t.s = state_t.s + noise
                    prev_disagreement = float('inf')  # Reset to allow new path
                    continue  # Skip update and try next round

                prev_disagreement = current_d_mean

            # Update main state
            state_t.s, state_t.b, state_t.u, state_t.p, state_t.e = (
                s_t, b_t, u_t, p_t, e_t
            )

            p_ptr_all.append(state_t.p)
            m_all.append(m_t)
            b_all.append(state_t.b)
            u_all.append(state_t.u)

            # Confidence-based Early Stopping (inference only)
            if not self.training and cfg.early_stop_threshold is not None:
                with torch.no_grad():
                    tmp_p, _, tmp_d = self.consensus(state_t.b)
                if self.halter.should_halt(tmp_d).all():
                    break
                if (
                    tmp_p.max(dim=-1).values.mean().item()
                    >= cfg.early_stop_threshold
                ):
                    break

        return state_t, b_all, u_all, m_all, p_ptr_all, rounds_used, None
|