aragora-client 2.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aragora_client/__init__.py +75 -0
- aragora_client/client.py +544 -0
- aragora_client/exceptions.py +67 -0
- aragora_client/py.typed +0 -0
- aragora_client/types.py +323 -0
- aragora_client/websocket.py +262 -0
- aragora_client-2.1.10.dist-info/METADATA +425 -0
- aragora_client-2.1.10.dist-info/RECORD +9 -0
- aragora_client-2.1.10.dist-info/WHEEL +4 -0
aragora_client/types.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""Type definitions for the Aragora SDK."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DebateStatus(str, Enum):
    """Status of a debate.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (``DebateStatus.RUNNING == "running"``).
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class VerificationStatus(str, Enum):
    """Status of a formal verification result.

    A ``str``-mixin enum: members compare equal to their lowercase values
    (``VerificationStatus.VALID == "valid"``).
    """

    VALID = "valid"
    INVALID = "invalid"
    UNKNOWN = "unknown"
    ERROR = "error"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ConsensusResult(BaseModel):
    """Result of consensus detection.

    Attributes:
        reached: Whether consensus was reached (required).
        conclusion: Conclusion text, if one was produced.
        confidence: Confidence score; defaults to 0.0.
        supporting_agents: Agents recorded as supporting; empty by default.
        dissenting_agents: Agents recorded as dissenting; empty by default.
        reasoning: Optional explanation text.
    """

    # Field order is part of the model's observable shape (dump/schema order).
    reached: bool
    conclusion: str | None = None
    confidence: float = 0.0
    supporting_agents: list[str] = Field(default_factory=list)
    dissenting_agents: list[str] = Field(default_factory=list)
    reasoning: str | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AgentMessage(BaseModel):
    """A message from an agent during debate.

    Attributes:
        agent_id: Identifier of the agent that sent the message.
        content: Message text.
        round_number: Round the message belongs to.
        timestamp: When the message was produced.
        metadata: Free-form extra data; empty dict by default.
    """

    agent_id: str
    content: str
    round_number: int
    timestamp: datetime
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Debate(BaseModel):
    """A debate instance.

    Attributes:
        id: Debate identifier.
        task: The task or question being debated.
        status: Current :class:`DebateStatus`.
        agents: Participating agents.
        rounds: Messages grouped per round (one inner list per round);
            empty by default.
        consensus: Consensus outcome, if any.
        created_at: Creation time.
        updated_at: Last update time.
        metadata: Free-form extra data; empty dict by default.
    """

    id: str
    task: str
    status: DebateStatus
    agents: list[str]
    rounds: list[list[AgentMessage]] = Field(default_factory=list)
    consensus: ConsensusResult | None = None
    created_at: datetime
    updated_at: datetime
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class GraphBranch(BaseModel):
    """A branch in a graph debate.

    Attributes:
        id: Branch identifier.
        parent_id: Identifier of the parent branch, or None (presumably for a
            root branch -- confirm against the server).
        approach: Description of the approach this branch explores.
        agents: Agents assigned to the branch.
        rounds: Messages grouped per round; empty by default.
        consensus: Consensus outcome for the branch, if any.
        divergence_score: Divergence score; defaults to 0.0.
    """

    # Declaration order is preserved exactly; a required field after an
    # optional one is valid for pydantic models.
    id: str
    parent_id: str | None = None
    approach: str
    agents: list[str]
    rounds: list[list[AgentMessage]] = Field(default_factory=list)
    consensus: ConsensusResult | None = None
    divergence_score: float = 0.0
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class GraphDebate(BaseModel):
    """A graph debate with branching.

    Attributes:
        id: Debate identifier.
        task: The task being debated.
        status: Current :class:`DebateStatus`.
        branches: The debate's :class:`GraphBranch` objects; empty by default.
        created_at: Creation time.
        updated_at: Last update time.
        metadata: Free-form extra data; empty dict by default.
    """

    id: str
    task: str
    status: DebateStatus
    branches: list[GraphBranch] = Field(default_factory=list)
    created_at: datetime
    updated_at: datetime
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class MatrixConclusion(BaseModel):
    """Conclusions from a matrix debate.

    Attributes:
        universal: Conclusions listed as universal; empty by default.
        conditional: Conclusions keyed by a string (presumably the condition or
            scenario name -- confirm); empty by default.
        contradictions: Contradictions found; empty by default.
    """

    universal: list[str] = Field(default_factory=list)
    conditional: dict[str, list[str]] = Field(default_factory=dict)
    contradictions: list[str] = Field(default_factory=list)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class MatrixScenario(BaseModel):
    """A scenario in a matrix debate.

    Attributes:
        name: Scenario name.
        parameters: Scenario parameters (required, free-form).
        is_baseline: Whether this is the baseline scenario; defaults to False.
        consensus: Consensus outcome for the scenario, if any.
    """

    name: str
    parameters: dict[str, Any]
    is_baseline: bool = False
    consensus: ConsensusResult | None = None
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class MatrixDebate(BaseModel):
    """A matrix debate across scenarios.

    Attributes:
        id: Debate identifier.
        task: The task being debated.
        status: Current :class:`DebateStatus`.
        scenarios: The :class:`MatrixScenario` entries; empty by default.
        conclusions: Aggregated :class:`MatrixConclusion`, if available.
        created_at: Creation time.
        updated_at: Last update time.
        metadata: Free-form extra data; empty dict by default.
    """

    id: str
    task: str
    status: DebateStatus
    scenarios: list[MatrixScenario] = Field(default_factory=list)
    conclusions: MatrixConclusion | None = None
    created_at: datetime
    updated_at: datetime
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class AgentProfile(BaseModel):
    """Profile of an agent.

    Attributes:
        id: Agent identifier.
        name: Display name.
        provider: Provider of the underlying model.
        elo_rating: ELO rating; defaults to 1500.0.
        matches_played: Number of matches played; defaults to 0.
        win_rate: Win rate; defaults to 0.5.
        specialties: Declared specialties; empty by default.
        metadata: Free-form extra data; empty dict by default.
    """

    id: str
    name: str
    provider: str
    elo_rating: float = 1500.0
    matches_played: int = 0
    win_rate: float = 0.5
    specialties: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class VerificationResult(BaseModel):
    """Result of formal verification.

    Attributes:
        status: Outcome as a :class:`VerificationStatus`.
        claim: The claim that was checked.
        formal_translation: Formalized form of the claim, if produced.
        proof: Proof text, if produced.
        counterexample: Counterexample text, if produced.
        backend: Name of the verification backend used (required).
        duration_ms: Duration in milliseconds; defaults to 0.
    """

    status: VerificationStatus
    claim: str
    formal_translation: str | None = None
    proof: str | None = None
    counterexample: str | None = None
    backend: str
    duration_ms: int = 0
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class GauntletFinding(BaseModel):
    """A finding from gauntlet validation.

    Attributes:
        severity: Severity label (free-form string).
        category: Finding category.
        description: Human-readable description.
        location: Where the finding applies, if known.
        suggestion: Suggested remediation, if any.
    """

    severity: str
    category: str
    description: str
    location: str | None = None
    suggestion: str | None = None
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class GauntletReceipt(BaseModel):
    """Receipt from gauntlet validation.

    Attributes:
        id: Receipt identifier.
        score: Overall score.
        findings: The :class:`GauntletFinding` entries; empty by default.
        persona: Persona the gauntlet ran with.
        created_at: Creation time.
        hash: Optional receipt hash (field name mirrors the API payload).
    """

    id: str
    score: float
    findings: list[GauntletFinding] = Field(default_factory=list)
    persona: str
    created_at: datetime
    hash: str | None = None
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class MemoryTierStats(BaseModel):
    """Statistics for a memory tier.

    Attributes:
        tier: Tier name.
        entries: Number of entries in the tier.
        size_bytes: Tier size in bytes.
        hit_rate: Hit rate for the tier.
        avg_age_seconds: Average entry age in seconds.
    """

    tier: str
    entries: int
    size_bytes: int
    hit_rate: float
    avg_age_seconds: float
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class MemoryAnalytics(BaseModel):
    """Analytics for the memory system.

    Attributes:
        total_entries: Total number of entries across tiers.
        total_size_bytes: Total size in bytes.
        learning_velocity: Learning-velocity metric as reported by the server.
        tiers: Per-tier :class:`MemoryTierStats`; empty by default.
        period_days: Length of the reporting period in days.
    """

    total_entries: int
    total_size_bytes: int
    learning_velocity: float
    tiers: list[MemoryTierStats] = Field(default_factory=list)
    period_days: int
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class HealthStatus(BaseModel):
    """Server health status.

    Attributes:
        status: Overall status string.
        version: Server version string.
        uptime_seconds: Server uptime in seconds.
        components: String-to-string mapping (presumably component name to
            component status -- confirm); empty by default.
    """

    status: str
    version: str
    uptime_seconds: float
    components: dict[str, str] = Field(default_factory=dict)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class DebateEvent(BaseModel):
    """A WebSocket event from a debate.

    Attributes:
        type: Event type string (field name mirrors the wire format).
        data: Event payload; empty dict by default.
        timestamp: Defaults to ``datetime.now()`` at construction, i.e. a naive
            local-time value, when not supplied explicitly.
        loop_id: Optional identifier associating the event with a debate loop.
    """

    type: str
    data: dict[str, Any] = Field(default_factory=dict)
    timestamp: datetime = Field(default_factory=datetime.now)
    loop_id: str | None = None
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class ScorerInfo(BaseModel):
    """Information about a scorer plugin.

    Attributes:
        name: Plugin name.
        description: Human-readable description.
    """

    name: str
    description: str
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class TeamSelectorInfo(BaseModel):
    """Information about a team selector plugin.

    Attributes:
        name: Plugin name.
        description: Human-readable description.
    """

    name: str
    description: str
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class RoleAssignerInfo(BaseModel):
    """Information about a role assigner plugin.

    Attributes:
        name: Plugin name.
        description: Human-readable description.
    """

    name: str
    description: str
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class SelectionPlugins(BaseModel):
    """Available selection plugins, grouped by plugin kind.

    Attributes:
        scorers: Available :class:`ScorerInfo` plugins; empty by default.
        team_selectors: Available :class:`TeamSelectorInfo` plugins; empty by default.
        role_assigners: Available :class:`RoleAssignerInfo` plugins; empty by default.
    """

    scorers: list[ScorerInfo] = Field(default_factory=list)
    team_selectors: list[TeamSelectorInfo] = Field(default_factory=list)
    role_assigners: list[RoleAssignerInfo] = Field(default_factory=list)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class AgentScore(BaseModel):
    """Score for an agent.

    Attributes:
        name: Agent name.
        score: Overall score.
        elo_rating: The agent's ELO rating.
        breakdown: Per-component scores keyed by name; empty by default.
    """

    name: str
    score: float
    elo_rating: float
    breakdown: dict[str, float] = Field(default_factory=dict)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class TeamMember(BaseModel):
    """A member of a selected team.

    Attributes:
        name: Agent name.
        role: Role assigned to the agent.
        score: The agent's score.
    """

    name: str
    role: str
    score: float
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class TeamSelection(BaseModel):
    """Result of team selection.

    Attributes:
        agents: Selected :class:`TeamMember` entries; empty by default.
        expected_quality: Expected-quality metric for the team.
        diversity_score: Diversity metric for the team.
        rationale: Explanation of the selection.
    """

    agents: list[TeamMember] = Field(default_factory=list)
    expected_quality: float
    diversity_score: float
    rationale: str
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# Request payload types for API calls
|
|
260
|
+
class CreateDebateRequest(BaseModel):
    """Request to create a debate.

    Attributes:
        task: The task or question to debate (required).
        agents: Agents to use; None leaves the choice to the server.
        max_rounds: Round limit; defaults to 5.
        consensus_threshold: Consensus threshold; defaults to 0.8.
        metadata: Free-form extra data; empty dict by default.
    """

    task: str
    agents: list[str] | None = None
    max_rounds: int = 5
    consensus_threshold: float = 0.8
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class CreateGraphDebateRequest(BaseModel):
    """Request to create a graph debate.

    Attributes:
        task: The task to debate (required).
        agents: Agents to use; None leaves the choice to the server.
        max_rounds: Round limit; defaults to 5.
        branch_threshold: Branching threshold; defaults to 0.5.
        max_branches: Branch limit; defaults to 10.
    """

    task: str
    agents: list[str] | None = None
    max_rounds: int = 5
    branch_threshold: float = 0.5
    max_branches: int = 10
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class CreateMatrixDebateRequest(BaseModel):
    """Request to create a matrix debate.

    Attributes:
        task: The task to debate (required).
        scenarios: Scenario definitions as free-form dicts (required).
        agents: Agents to use; None leaves the choice to the server.
        max_rounds: Round limit; defaults to 3.
    """

    task: str
    scenarios: list[dict[str, Any]]
    agents: list[str] | None = None
    max_rounds: int = 3
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class VerifyClaimRequest(BaseModel):
    """Request to verify a claim.

    Attributes:
        claim: The claim to verify (required).
        backend: Verification backend name; defaults to "z3".
        timeout: Timeout value; defaults to 30 (units not stated here --
            presumably seconds, confirm against the server API).
    """

    claim: str
    backend: str = "z3"
    timeout: int = 30
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class RunGauntletRequest(BaseModel):
    """Request to run gauntlet validation.

    Attributes:
        input_content: The content to validate (required).
        input_type: Kind of input; defaults to "spec".
        persona: Validation persona; defaults to "security".
    """

    input_content: str
    input_type: str = "spec"
    persona: str = "security"
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
class ScoreAgentsRequest(BaseModel):
    """Request to score agents.

    Attributes:
        task_description: Description of the task to score agents for.
        primary_domain: Optional primary domain hint.
        scorer: Optional scorer plugin name; None leaves it to the server.
    """

    task_description: str
    primary_domain: str | None = None
    scorer: str | None = None
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class SelectTeamRequest(BaseModel):
    """Request to select a team.

    Attributes:
        task_description: Description of the task to staff.
        min_agents: Minimum team size; defaults to 2.
        max_agents: Maximum team size; defaults to 5.
        diversity_preference: Diversity weighting; defaults to 0.5.
        quality_priority: Quality weighting; defaults to 0.5.
        scorer: Optional scorer plugin name.
        team_selector: Optional team selector plugin name.
        role_assigner: Optional role assigner plugin name.
    """

    task_description: str
    min_agents: int = 2
    max_agents: int = 5
    diversity_preference: float = 0.5
    quality_priority: float = 0.5
    # Plugin overrides; None leaves the choice to the server.
    scorer: str | None = None
    team_selector: str | None = None
    role_assigner: str | None = None
|
|
aragora_client/websocket.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""WebSocket streaming for the Aragora SDK."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
from collections.abc import AsyncIterator, Callable
|
|
8
|
+
|
|
9
|
+
import websockets
|
|
10
|
+
from websockets.client import WebSocketClientProtocol
|
|
11
|
+
|
|
12
|
+
from aragora_client.types import DebateEvent
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DebateStream:
    """
    WebSocket stream for debate events.

    Example:
        >>> stream = DebateStream("ws://localhost:8765", "debate-123")
        >>> stream.on("agent_message", lambda e: print(e.data))
        >>> stream.on("consensus", lambda e: print("Consensus reached!"))
        >>> await stream.connect()
    """

    def __init__(
        self,
        base_url: str,
        debate_id: str,
        *,
        reconnect: bool = True,
        reconnect_interval: float = 1.0,
        max_reconnect_attempts: int = 5,
        heartbeat_interval: float = 30.0,
    ) -> None:
        """
        Initialize the debate stream.

        Args:
            base_url: WebSocket server URL. A leading ``http://``/``https://`` is
                rewritten to ``ws://``/``wss://``; a URL with no scheme gets
                ``ws://``. Trailing slashes are stripped.
            debate_id: ID of the debate to stream.
            reconnect: Whether to auto-reconnect on disconnect.
            reconnect_interval: Base reconnect delay in seconds; doubled after
                each consecutive reconnect attempt.
            max_reconnect_attempts: Maximum consecutive reconnect attempts.
            heartbeat_interval: Heartbeat ping interval in seconds.
        """
        # Rewrite only the scheme prefix: str.replace() would also rewrite any
        # later "http://" occurrence (e.g. inside a query string).
        if base_url.startswith("http://"):
            base_url = "ws://" + base_url[len("http://"):]
        elif base_url.startswith("https://"):
            base_url = "wss://" + base_url[len("https://"):]
        elif not base_url.startswith(("ws://", "wss://")):
            base_url = f"ws://{base_url}"

        self.base_url = base_url.rstrip("/")
        self.debate_id = debate_id
        self.reconnect = reconnect
        self.reconnect_interval = reconnect_interval
        self.max_reconnect_attempts = max_reconnect_attempts
        self.heartbeat_interval = heartbeat_interval

        self._ws: WebSocketClientProtocol | None = None
        self._handlers: dict[str, list[Callable[[DebateEvent], None]]] = {}
        self._error_handlers: list[Callable[[Exception], None]] = []
        self._connected = False
        self._should_stop = False
        self._reconnect_attempts = 0

    def on(
        self, event_type: str, handler: Callable[[DebateEvent], None]
    ) -> DebateStream:
        """
        Register an event handler.

        Args:
            event_type: Event type to handle (e.g., "agent_message", "consensus"),
                or "*" to receive every event.
            handler: Callback function to handle the event.

        Returns:
            Self for chaining.
        """
        self._handlers.setdefault(event_type, []).append(handler)
        return self

    def on_error(self, handler: Callable[[Exception], None]) -> DebateStream:
        """
        Register an error handler.

        Args:
            handler: Callback function to handle errors.

        Returns:
            Self for chaining.
        """
        self._error_handlers.append(handler)
        return self

    async def connect(self) -> None:
        """
        Connect to the WebSocket and dispatch events until stopped.

        Runs until :meth:`disconnect` is called, or until the connection ends and
        reconnection is disabled or exhausted. Every way a connection can end
        (clean server close, abnormal close, or a connection error) goes through
        the same exponential-backoff reconnect policy. Errors other than
        connection closures are reported to the error handlers.
        """
        self._should_stop = False
        ws_url = f"{self.base_url}/ws/debates/{self.debate_id}"

        while not self._should_stop:
            try:
                async with websockets.connect(ws_url) as ws:
                    self._ws = ws
                    self._connected = True
                    self._reconnect_attempts = 0
                    await self._receive(ws)
            except websockets.ConnectionClosed:
                # Closure is not reported as an error; the policy below decides
                # whether to reconnect.
                pass
            except Exception as e:
                self._emit_error(e)
            finally:
                self._connected = False
                self._ws = None

            # `async for` over a connection ends *without* raising on a clean
            # close, so the reconnect policy must run on every exit path.
            # Otherwise a clean close would reconnect immediately in a tight
            # loop, even with reconnect=False.
            if not await self._wait_before_reconnect():
                break

    async def _wait_before_reconnect(self) -> bool:
        """
        Apply the reconnect policy after a connection has ended.

        Returns:
            True, after sleeping the backoff delay, if another connection attempt
            should be made. False if the connect loop should exit because
            reconnection is disabled, a stop was requested, or the attempt budget
            is exhausted (the last case is reported to the error handlers).
        """
        if not self.reconnect or self._should_stop:
            return False
        if self._reconnect_attempts >= self.max_reconnect_attempts:
            max_attempts = self.max_reconnect_attempts
            self._emit_error(
                Exception(f"Max reconnect attempts ({max_attempts}) exceeded")
            )
            return False
        self._reconnect_attempts += 1
        # Exponential backoff: interval, 2 * interval, 4 * interval, ...
        delay = self.reconnect_interval * (2 ** (self._reconnect_attempts - 1))
        await asyncio.sleep(delay)
        return True

    async def _receive(self, ws: WebSocketClientProtocol) -> None:
        """Dispatch messages from *ws* while a heartbeat task runs alongside."""
        heartbeat_task = asyncio.create_task(self._heartbeat())
        try:
            async for message in ws:
                if self._should_stop:
                    break
                await self._handle_message(message)
        finally:
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

    async def disconnect(self) -> None:
        """Stop the connect loop and close the current WebSocket, if any."""
        self._should_stop = True
        self._connected = False
        # Detach the socket before awaiting close() so that connect()'s own
        # cleanup cannot interleave with it.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    async def _heartbeat(self) -> None:
        """Ping periodically while connected; stop on the first failure."""
        while self._connected and self._ws:
            try:
                await asyncio.sleep(self.heartbeat_interval)
                if self._ws and self._connected:
                    await self._ws.ping()
            except Exception:
                break

    @staticmethod
    def _parse_event(message: str | bytes) -> DebateEvent:
        """
        Decode a raw WebSocket message into a DebateEvent.

        The message must be a JSON object. A missing ``type`` becomes "unknown".
        A missing or null ``data`` becomes an empty dict. ``loop_id`` falls back
        to ``data["debate_id"]``.

        Raises:
            ValueError: If the bytes are not UTF-8, the text is not JSON, the
                payload or its ``data`` field is not a JSON object, or model
                validation fails (pydantic's ValidationError subclasses
                ValueError).
        """
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        payload = json.loads(message)
        if not isinstance(payload, dict):
            raise ValueError(
                f"Expected a JSON object message, got {type(payload).__name__}"
            )
        data = payload.get("data")
        if data is None:
            data = {}
        elif not isinstance(data, dict):
            raise ValueError(
                f"Expected 'data' to be a JSON object, got {type(data).__name__}"
            )
        return DebateEvent(
            type=payload.get("type", "unknown"),
            data=data,
            loop_id=payload.get("loop_id") or data.get("debate_id"),
        )

    async def _handle_message(self, message: str | bytes) -> None:
        """
        Parse an incoming WebSocket message and dispatch it to handlers.

        Handlers registered for the event's type run first, then "*" handlers.
        Malformed messages and handler exceptions are reported to the error
        handlers instead of propagating, so one bad message cannot tear down
        the connection.
        """
        try:
            event = self._parse_event(message)
        except (ValueError, TypeError) as e:
            # UnicodeDecodeError, JSONDecodeError and pydantic ValidationError
            # are all ValueError subclasses.
            self._emit_error(e)
            return

        # Snapshot both lists so a handler that registers handlers does not
        # change the dispatch of the current event.
        handlers = [*self._handlers.get(event.type, ()), *self._handlers.get("*", ())]
        for handler in handlers:
            try:
                handler(event)
            except Exception as e:
                self._emit_error(e)

    def _emit_error(self, error: Exception) -> None:
        """Emit an error to all error handlers."""
        for handler in self._error_handlers:
            try:
                handler(error)
            except Exception:
                # Deliberately ignored: re-emitting an error handler's own
                # failure could recurse endlessly.
                pass

    @property
    def connected(self) -> bool:
        """Whether the stream is currently connected."""
        return self._connected
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
async def stream_debate(
    base_url: str,
    debate_id: str,
    *,
    reconnect: bool = True,
    reconnect_interval: float = 1.0,
    max_reconnect_attempts: int = 5,
) -> AsyncIterator[DebateEvent]:
    """
    Stream debate events as an async iterator.

    Example:
        >>> async for event in stream_debate("ws://localhost:8765", "debate-123"):
        ...     print(event.type, event.data)
        ...     if event.type == "debate_end":
        ...         break

    Args:
        base_url: WebSocket server URL. A leading ``http://``/``https://`` is
            rewritten to ``ws://``/``wss://``; a URL with no scheme gets ``ws://``.
        debate_id: ID of the debate to stream.
        reconnect: Whether to auto-reconnect after an abnormal disconnect. A
            clean server close always ends the iteration.
        reconnect_interval: Base reconnect delay in seconds; doubled after each
            consecutive reconnect attempt.
        max_reconnect_attempts: Maximum consecutive reconnect attempts before the
            last connection error is re-raised.

    Yields:
        DebateEvent objects. Messages that are not valid UTF-8 JSON objects of
        the expected shape are skipped.

    Raises:
        websockets.ConnectionClosedError: If the connection drops abnormally and
            reconnection is disabled or exhausted.
        OSError: If reconnection keeps failing. Errors from the very first
            connection attempt propagate unchanged.
    """
    # Rewrite only the scheme prefix, never later "http://" occurrences.
    if base_url.startswith("http://"):
        base_url = "ws://" + base_url[len("http://"):]
    elif base_url.startswith("https://"):
        base_url = "wss://" + base_url[len("https://"):]
    elif not base_url.startswith(("ws://", "wss://")):
        base_url = f"ws://{base_url}"

    ws_url = f"{base_url.rstrip('/')}/ws/debates/{debate_id}"
    attempts = 0

    while True:
        try:
            async with websockets.connect(ws_url) as ws:
                attempts = 0
                async for message in ws:
                    try:
                        if isinstance(message, bytes):
                            message = message.decode("utf-8")
                        payload = json.loads(message)
                        if not isinstance(payload, dict):
                            continue
                        data = payload.get("data")
                        if data is None:
                            data = {}
                        elif not isinstance(data, dict):
                            continue
                        event = DebateEvent(
                            type=payload.get("type", "unknown"),
                            data=data,
                            loop_id=payload.get("loop_id") or data.get("debate_id"),
                        )
                    except (ValueError, TypeError):
                        # Undecodable/invalid JSON or a validation failure
                        # (all ValueError subclasses): skip the message.
                        continue
                    yield event
            # The iterator ended without raising: a clean close ends the stream.
            return
        except (websockets.ConnectionClosedError, OSError) as exc:
            dropped = isinstance(exc, websockets.ConnectionClosedError)
            # Retry abnormal drops, and failures of reconnection attempts, but
            # never a failure of the very first connection attempt.
            if (
                not reconnect
                or attempts >= max_reconnect_attempts
                or (attempts == 0 and not dropped)
            ):
                raise
            attempts += 1
            await asyncio.sleep(reconnect_interval * (2 ** (attempts - 1)))
|