robosystems-client 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of robosystems-client might be problematic. Click here for more details.
- robosystems_client/api/agent/auto_select_agent.py +81 -9
- robosystems_client/api/agent/execute_specific_agent.py +73 -5
- robosystems_client/api/agent/get_agent_metadata.py +1 -1
- robosystems_client/api/agent/list_agents.py +1 -1
- robosystems_client/api/billing/{update_org_payment_method.py → create_portal_session.py} +57 -47
- robosystems_client/api/credits_/list_credit_transactions.py +4 -4
- robosystems_client/api/subgraphs/create_subgraph.py +4 -4
- robosystems_client/api/subgraphs/delete_subgraph.py +8 -8
- robosystems_client/extensions/__init__.py +25 -0
- robosystems_client/extensions/agent_client.py +526 -0
- robosystems_client/extensions/extensions.py +3 -0
- robosystems_client/models/__init__.py +2 -6
- robosystems_client/models/checkout_status_response.py +2 -1
- robosystems_client/models/create_checkout_request.py +2 -1
- robosystems_client/models/create_checkout_request_resource_config.py +4 -1
- robosystems_client/models/create_subgraph_request.py +5 -26
- robosystems_client/models/graph_subscription_response.py +21 -0
- robosystems_client/models/list_subgraphs_response.py +9 -0
- robosystems_client/models/{update_payment_method_request.py → portal_session_response.py} +12 -12
- {robosystems_client-0.2.12.dist-info → robosystems_client-0.2.14.dist-info}/METADATA +1 -1
- {robosystems_client-0.2.12.dist-info → robosystems_client-0.2.14.dist-info}/RECORD +23 -25
- robosystems_client/api/subscriptions/cancel_subscription.py +0 -193
- robosystems_client/models/cancellation_response.py +0 -76
- robosystems_client/models/update_payment_method_response.py +0 -74
- {robosystems_client-0.2.12.dist-info → robosystems_client-0.2.14.dist-info}/WHEEL +0 -0
- {robosystems_client-0.2.12.dist-info → robosystems_client-0.2.14.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
"""Enhanced Agent Client with SSE support
|
|
2
|
+
|
|
3
|
+
Provides intelligent agent execution with automatic strategy selection.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Dict, Any, Optional, Callable
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
|
|
10
|
+
from ..api.agent.auto_select_agent import sync_detailed as auto_select_agent
|
|
11
|
+
from ..api.agent.execute_specific_agent import sync_detailed as execute_specific_agent
|
|
12
|
+
from ..models.agent_request import AgentRequest
|
|
13
|
+
from ..models.agent_message import AgentMessage
|
|
14
|
+
from .sse_client import SSEClient, SSEConfig, EventType
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class AgentQueryRequest:
    """Payload describing a single agent query.

    Only ``message`` is mandatory. Every other field defaults to ``None`` and
    is copied as-is onto the generated ``AgentRequest`` model by the client.
    """

    # Natural-language question or instruction for the agent.
    message: str
    # Prior conversation turns; each entry is a mapping with "role" and
    # "content" keys.
    history: Optional[list] = None
    # Free-form context dictionary sent along with the request.
    context: Optional[Dict[str, Any]] = None
    # Requested mode: 'quick', 'standard', 'extended', 'streaming'.
    mode: Optional[str] = None
    # Retrieval-augmented-generation toggle.
    enable_rag: Optional[bool] = None
    # Ask the server to force an extended analysis.
    force_extended_analysis: Optional[bool] = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class AgentOptions:
    """Client-side settings that control how an agent call is carried out."""

    # Execution strategy: 'auto', 'sync' or 'async'.
    # NOTE(review): AgentClient does not currently read this field.
    mode: Optional[str] = "auto"
    # 0 means "do not wait": a queued execution raises QueuedAgentError
    # instead of being followed over SSE.
    max_wait: Optional[int] = None
    # Progress callback, invoked as on_progress(message, percentage).
    on_progress: Optional[Callable[[str, Optional[int]], None]] = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class AgentResult:
    """Outcome of an agent execution, whether returned directly or via SSE."""

    # Answer text produced by the agent.
    content: str
    # Identifier of the agent that handled the request.
    agent_used: str
    # Execution mode reported back by the server (e.g. 'standard').
    mode_used: str
    # Optional extras reported by the server; None when not provided.
    metadata: Optional[Dict[str, Any]] = None
    tokens_used: Optional[Dict[str, int]] = None
    confidence_score: Optional[float] = None
    execution_time: Optional[float] = None
    # ISO-8601 timestamp string.
    timestamp: Optional[str] = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class QueuedAgentResponse:
    """Details returned when the server queues an agent run instead of answering."""

    # Status string reported for the queued run.
    status: str
    # Operation identifier; the client subscribes to it over SSE.
    operation_id: str
    # Human-readable description of the queued state.
    message: str
    # SSE endpoint advertised by the server, if any.
    sse_endpoint: Optional[str] = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class QueuedAgentError(Exception):
    """Raised when an agent run is queued and ``AgentOptions.max_wait`` is 0.

    The queue details are kept on ``queue_info`` so the caller can inspect
    the queued operation.
    """

    def __init__(self, queue_info: "QueuedAgentResponse"):
        # Store the queue details first; Exception.__init__ does not touch them.
        self.queue_info = queue_info
        super().__init__("Agent execution was queued")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class AgentClient:
    """Enhanced agent client with SSE streaming support.

    Wraps the generated ``auto_select_agent`` and ``execute_specific_agent``
    endpoints. An immediate response is converted into an ``AgentResult``.
    A queued response either raises ``QueuedAgentError`` (when
    ``options.max_wait == 0``) or is followed to completion over SSE.
    """

    def __init__(self, config: Dict[str, Any]):
        """Create the client from the shared extensions config.

        Args:
            config: Must contain ``base_url``. ``headers`` (extra HTTP
                headers) and ``token`` (sent as the ``X-API-Key`` header)
                are optional.

        Raises:
            KeyError: If ``config`` has no ``base_url``.
        """
        self.config = config
        self.base_url = config["base_url"]
        self.headers = config.get("headers", {})
        self.token = config.get("token")
        # SSE client of the queued execution currently being awaited, if any.
        # Tracked so that close() can actually cancel it.
        self.sse_client: Optional[SSEClient] = None

    def execute_query(
        self,
        graph_id: str,
        request: AgentQueryRequest,
        options: Optional[AgentOptions] = None,
    ) -> AgentResult:
        """Execute agent query with automatic agent selection.

        Args:
            graph_id: Graph to run the query against.
            request: Query payload.
            options: Execution options; defaults to ``AgentOptions()``.

        Returns:
            The agent's result, either immediate or collected over SSE.

        Raises:
            QueuedAgentError: The run was queued and ``options.max_wait == 0``.
            Exception: No API key is configured, authentication failed, the
                execution failed, or the response had an unexpected format.
        """

        def call(client, body):
            return auto_select_agent(graph_id=graph_id, client=client, body=body)

        return self._execute(call, request, options)

    def execute_agent(
        self,
        graph_id: str,
        agent_type: str,
        request: AgentQueryRequest,
        options: Optional[AgentOptions] = None,
    ) -> AgentResult:
        """Execute a specific agent type (e.g. 'financial', 'research', 'rag').

        Args, return value and raised exceptions are as for ``execute_query``;
        ``agent_type`` selects the agent explicitly.
        """

        def call(client, body):
            return execute_specific_agent(
                graph_id=graph_id,
                agent_type=agent_type,
                client=client,
                body=body,
            )

        return self._execute(call, request, options)

    # ------------------------------------------------------------------
    # Shared request/response plumbing
    # ------------------------------------------------------------------

    def _execute(
        self,
        call: Callable[..., Any],
        request: AgentQueryRequest,
        options: Optional[AgentOptions],
    ) -> AgentResult:
        """Run one endpoint call via ``call(client, body)`` and interpret it."""
        if options is None:
            options = AgentOptions()

        agent_request = self._build_agent_request(request)
        # Raised outside the try below, so a missing key is reported as-is.
        client = self._create_client()

        try:
            # A fresh client is created per call. Using it as a context
            # manager closes its HTTP connection pool afterwards.
            with client:
                response = call(client, agent_request)
            result = self._handle_response(response, options)
        except QueuedAgentError:
            raise
        except Exception as e:
            error_msg = str(e)
            if (
                "401" in error_msg
                or "403" in error_msg
                or "unauthorized" in error_msg.lower()
            ):
                raise Exception(
                    f"Authentication failed during agent execution: {error_msg}"
                ) from e
            raise Exception(f"Agent execution failed: {error_msg}") from e

        if result is not None:
            return result

        # Nothing usable was parsed. The generated client does not raise on
        # HTTP errors, so report the status code when one is available.
        status = getattr(response, "status_code", None)
        status_code = int(status) if isinstance(status, int) else None
        if status_code in (401, 403):
            raise Exception(
                f"Authentication failed during agent execution: HTTP {status_code}"
            )
        detail = f" (HTTP {status_code})" if status_code is not None else ""
        raise Exception(f"Unexpected response format from agent endpoint{detail}")

    @staticmethod
    def _build_agent_request(request: AgentQueryRequest) -> AgentRequest:
        """Convert an ``AgentQueryRequest`` into the generated ``AgentRequest``.

        History entries may be mappings with "role"/"content" keys or
        ready-made ``AgentMessage`` instances.
        """
        history = [
            msg
            if isinstance(msg, AgentMessage)
            else AgentMessage(role=msg["role"], content=msg["content"])
            for msg in (request.history or [])
        ]
        # NOTE(review): None-valued options are passed through unchanged
        # (not converted to UNSET), so they may serialize as JSON null.
        # Confirm against the AgentRequest model/server schema.
        return AgentRequest(
            message=request.message,
            history=history,
            context=request.context,
            mode=request.mode,
            enable_rag=request.enable_rag,
            force_extended_analysis=request.force_extended_analysis,
        )

    def _create_client(self):
        """Build an ``AuthenticatedClient`` that sends the token as X-API-Key.

        Raises:
            Exception: If no token is configured.
        """
        from ..client import AuthenticatedClient

        if not self.token:
            raise Exception("No API key provided. Set X-API-Key in headers.")

        return AuthenticatedClient(
            base_url=self.base_url,
            token=self.token,
            prefix="",
            auth_header_name="X-API-Key",
            headers=self.headers,
        )

    def _handle_response(
        self, response: Any, options: AgentOptions
    ) -> Optional[AgentResult]:
        """Turn a generated-client response into an ``AgentResult``.

        Returns:
            The result, or ``None`` when the parsed body is empty or matches
            neither the immediate-result shape nor the queued shape.

        Raises:
            QueuedAgentError: The run was queued and ``options.max_wait == 0``.
        """
        data = getattr(response, "parsed", None)
        if not data:
            return None

        if isinstance(data, dict):
            if "content" in data and "agent_used" in data:
                return self._result_from_mapping(data)
            if "operation_id" not in data:
                return None
            queued = QueuedAgentResponse(
                status=data.get("status", "queued"),
                operation_id=data["operation_id"],
                message=data.get("message", "Agent execution queued"),
                sse_endpoint=data.get("sse_endpoint"),
            )
        else:
            # Generated attrs model.
            if hasattr(data, "content") and hasattr(data, "agent_used"):
                return self._result_from_model(data)
            if not hasattr(data, "operation_id"):
                return None
            queued = QueuedAgentResponse(
                status=self._value_or(getattr(data, "status", None), "queued"),
                operation_id=data.operation_id,
                message=self._value_or(
                    getattr(data, "message", None), "Agent execution queued"
                ),
                sse_endpoint=self._value_or(getattr(data, "sse_endpoint", None), None),
            )

        # The caller does not want to wait: hand back the queue details.
        if options.max_wait == 0:
            raise QueuedAgentError(queued)

        return self._wait_for_agent_completion(queued.operation_id, options)

    @staticmethod
    def _value_or(value: Any, default: Any) -> Any:
        """Return ``value`` unless it is UNSET or None, otherwise ``default``."""
        from ..types import UNSET

        return default if value is UNSET or value is None else value

    @staticmethod
    def _result_from_mapping(data: Dict[str, Any]) -> AgentResult:
        """Build an ``AgentResult`` from a plain dict (HTTP body or SSE event)."""
        return AgentResult(
            content=data.get("content", ""),
            agent_used=data.get("agent_used", "unknown"),
            mode_used=data.get("mode_used", "standard"),
            metadata=data.get("metadata"),
            tokens_used=data.get("tokens_used"),
            confidence_score=data.get("confidence_score"),
            execution_time=data.get("execution_time"),
            timestamp=data.get("timestamp") or datetime.now().isoformat(),
        )

    @classmethod
    def _result_from_model(cls, data: Any) -> AgentResult:
        """Build an ``AgentResult`` from a generated attrs response model."""
        mode_used = cls._value_or(getattr(data, "mode_used", None), "standard")
        return AgentResult(
            content=cls._value_or(getattr(data, "content", None), ""),
            agent_used=cls._value_or(getattr(data, "agent_used", None), "unknown"),
            # mode_used may be a generated enum; expose its plain value.
            mode_used=getattr(mode_used, "value", mode_used),
            metadata=cls._value_or(getattr(data, "metadata", None), None),
            tokens_used=cls._value_or(getattr(data, "tokens_used", None), None),
            confidence_score=cls._value_or(
                getattr(data, "confidence_score", None), None
            ),
            execution_time=cls._value_or(getattr(data, "execution_time", None), None),
            timestamp=cls._value_or(
                getattr(data, "timestamp", None), datetime.now().isoformat()
            ),
        )

    def _wait_for_agent_completion(
        self, operation_id: str, options: AgentOptions
    ) -> AgentResult:
        """Follow a queued operation over SSE and return its final result.

        Raises:
            Exception: The operation reported an error, was cancelled, or the
                connection was closed via ``close()`` before completion.
        """
        import time

        result: Optional[AgentResult] = None
        error: Optional[Exception] = None
        completed = False

        sse_config = SSEConfig(base_url=self.base_url, headers=self.headers)
        sse_client = SSEClient(sse_config)

        def on_progress(data):
            if options.on_progress:
                options.on_progress(
                    data.get("message", "Processing..."), data.get("percentage")
                )

        def on_agent_started(data):
            if options.on_progress:
                options.on_progress(f"Agent {data.get('agent_type')} started", 0)

        def on_agent_initialized(data):
            if options.on_progress:
                options.on_progress(f"{data.get('agent_name')} initialized", 10)

        def on_agent_completed(data):
            nonlocal result, completed
            result = self._result_from_mapping(data)
            completed = True

        def on_completed(data):
            nonlocal result, completed
            if result is None:
                # Fall back to the generic completion event's payload.
                payload = data.get("result", data)
                if not isinstance(payload, dict):
                    payload = data
                result = self._result_from_mapping(payload)
            completed = True

        def on_error(err):
            nonlocal error, completed
            if isinstance(err, dict):
                message = err.get("message", err.get("error", "Unknown error"))
            else:
                message = str(err) or "Unknown error"
            error = Exception(message)
            completed = True

        def on_cancelled(_data=None):
            # Accepts an optional payload so it works whether or not the SSE
            # client passes event data to handlers.
            nonlocal error, completed
            error = Exception("Agent execution cancelled")
            completed = True

        # Register event handlers
        sse_client.on(EventType.OPERATION_PROGRESS.value, on_progress)
        sse_client.on("agent_started", on_agent_started)
        sse_client.on("agent_initialized", on_agent_initialized)
        sse_client.on("progress", on_progress)
        sse_client.on("agent_completed", on_agent_completed)
        sse_client.on(EventType.OPERATION_COMPLETED.value, on_completed)
        sse_client.on(EventType.OPERATION_ERROR.value, on_error)
        sse_client.on("error", on_error)
        sse_client.on(EventType.OPERATION_CANCELLED.value, on_cancelled)

        # Publish the connection so close() can cancel it.
        self.sse_client = sse_client
        try:
            sse_client.connect(operation_id)
            # NOTE(review): options.max_wait is only honoured for 0 (by the
            # caller). Other values are not enforced as a timeout here, and
            # their unit is not defined in this file.
            while not completed and self.sse_client is sse_client:
                time.sleep(0.1)
        finally:
            # Close only if close() has not already taken over the connection.
            if self.sse_client is sse_client:
                self.sse_client = None
                sse_client.close()

        # Checked after the loop: error handlers set `completed` as well, so
        # checking only inside the loop would silently drop the error.
        if error is not None:
            raise error
        if not completed or result is None:
            raise Exception("Agent execution cancelled")
        return result

    # ------------------------------------------------------------------
    # Convenience wrappers
    # ------------------------------------------------------------------

    def query(
        self, graph_id: str, message: str, context: Optional[Dict[str, Any]] = None
    ) -> AgentResult:
        """Convenience method for simple agent queries with auto-selection"""
        request = AgentQueryRequest(message=message, context=context)
        return self.execute_query(graph_id, request, AgentOptions(mode="auto"))

    def analyze_financials(
        self,
        graph_id: str,
        message: str,
        on_progress: Optional[Callable[[str, Optional[int]], None]] = None,
    ) -> AgentResult:
        """Execute financial agent for financial analysis"""
        request = AgentQueryRequest(message=message)
        return self.execute_agent(
            graph_id, "financial", request, AgentOptions(on_progress=on_progress)
        )

    def research(
        self,
        graph_id: str,
        message: str,
        on_progress: Optional[Callable[[str, Optional[int]], None]] = None,
    ) -> AgentResult:
        """Execute research agent for deep research"""
        request = AgentQueryRequest(message=message)
        return self.execute_agent(
            graph_id, "research", request, AgentOptions(on_progress=on_progress)
        )

    def rag(
        self,
        graph_id: str,
        message: str,
        on_progress: Optional[Callable[[str, Optional[int]], None]] = None,
    ) -> AgentResult:
        """Execute RAG agent for fast retrieval"""
        request = AgentQueryRequest(message=message)
        return self.execute_agent(
            graph_id, "rag", request, AgentOptions(on_progress=on_progress)
        )

    def close(self):
        """Cancel the SSE connection of an in-flight queued execution, if any."""
        if self.sse_client:
            self.sse_client.close()
            self.sse_client = None
|
|
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|
|
7
7
|
from typing import Dict, Any, Optional, Callable
|
|
8
8
|
|
|
9
9
|
from .query_client import QueryClient
|
|
10
|
+
from .agent_client import AgentClient
|
|
10
11
|
from .operation_client import OperationClient
|
|
11
12
|
from .table_ingest_client import TableIngestClient
|
|
12
13
|
from .graph_client import GraphClient
|
|
@@ -58,6 +59,7 @@ class RoboSystemsExtensions:
|
|
|
58
59
|
|
|
59
60
|
# Initialize clients
|
|
60
61
|
self.query = QueryClient(self.config)
|
|
62
|
+
self.agent = AgentClient(self.config)
|
|
61
63
|
self.operations = OperationClient(self.config)
|
|
62
64
|
self.tables = TableIngestClient(self.config)
|
|
63
65
|
self.graphs = GraphClient(self.config)
|
|
@@ -88,6 +90,7 @@ class RoboSystemsExtensions:
|
|
|
88
90
|
def close(self):
|
|
89
91
|
"""Clean up all active connections"""
|
|
90
92
|
self.query.close()
|
|
93
|
+
self.agent.close()
|
|
91
94
|
self.operations.close_all()
|
|
92
95
|
self.tables.close()
|
|
93
96
|
self.graphs.close()
|
|
@@ -45,7 +45,6 @@ from .bulk_ingest_response import BulkIngestResponse
|
|
|
45
45
|
from .cancel_operation_response_canceloperation import (
|
|
46
46
|
CancelOperationResponseCanceloperation,
|
|
47
47
|
)
|
|
48
|
-
from .cancellation_response import CancellationResponse
|
|
49
48
|
from .check_credit_balance_response_checkcreditbalance import (
|
|
50
49
|
CheckCreditBalanceResponseCheckcreditbalance,
|
|
51
50
|
)
|
|
@@ -208,6 +207,7 @@ from .plaid_connection_config_accounts_type_0_item import (
|
|
|
208
207
|
from .plaid_connection_config_institution_type_0 import (
|
|
209
208
|
PlaidConnectionConfigInstitutionType0,
|
|
210
209
|
)
|
|
210
|
+
from .portal_session_response import PortalSessionResponse
|
|
211
211
|
from .query_limits import QueryLimits
|
|
212
212
|
from .quick_books_connection_config import QuickBooksConnectionConfig
|
|
213
213
|
from .rate_limits import RateLimits
|
|
@@ -279,8 +279,6 @@ from .update_file_status_response_updatefilestatus import (
|
|
|
279
279
|
from .update_member_role_request import UpdateMemberRoleRequest
|
|
280
280
|
from .update_org_request import UpdateOrgRequest
|
|
281
281
|
from .update_password_request import UpdatePasswordRequest
|
|
282
|
-
from .update_payment_method_request import UpdatePaymentMethodRequest
|
|
283
|
-
from .update_payment_method_response import UpdatePaymentMethodResponse
|
|
284
282
|
from .update_user_request import UpdateUserRequest
|
|
285
283
|
from .upgrade_subscription_request import UpgradeSubscriptionRequest
|
|
286
284
|
from .user_graphs_response import UserGraphsResponse
|
|
@@ -326,7 +324,6 @@ __all__ = (
|
|
|
326
324
|
"BillingCustomer",
|
|
327
325
|
"BulkIngestRequest",
|
|
328
326
|
"BulkIngestResponse",
|
|
329
|
-
"CancellationResponse",
|
|
330
327
|
"CancelOperationResponseCanceloperation",
|
|
331
328
|
"CheckCreditBalanceResponseCheckcreditbalance",
|
|
332
329
|
"CheckoutResponse",
|
|
@@ -462,6 +459,7 @@ __all__ = (
|
|
|
462
459
|
"PlaidConnectionConfig",
|
|
463
460
|
"PlaidConnectionConfigAccountsType0Item",
|
|
464
461
|
"PlaidConnectionConfigInstitutionType0",
|
|
462
|
+
"PortalSessionResponse",
|
|
465
463
|
"QueryLimits",
|
|
466
464
|
"QuickBooksConnectionConfig",
|
|
467
465
|
"RateLimits",
|
|
@@ -519,8 +517,6 @@ __all__ = (
|
|
|
519
517
|
"UpdateMemberRoleRequest",
|
|
520
518
|
"UpdateOrgRequest",
|
|
521
519
|
"UpdatePasswordRequest",
|
|
522
|
-
"UpdatePaymentMethodRequest",
|
|
523
|
-
"UpdatePaymentMethodResponse",
|
|
524
520
|
"UpdateUserRequest",
|
|
525
521
|
"UpgradeSubscriptionRequest",
|
|
526
522
|
"UserGraphsResponse",
|
|
@@ -16,7 +16,8 @@ class CheckoutStatusResponse:
|
|
|
16
16
|
Attributes:
|
|
17
17
|
status (str): Checkout status: 'pending_payment', 'provisioning', 'completed', 'failed'
|
|
18
18
|
subscription_id (str): Internal subscription ID
|
|
19
|
-
resource_id (Union[None, Unset, str]): Resource ID (graph_id
|
|
19
|
+
resource_id (Union[None, Unset, str]): Resource ID (graph_id for both graphs and repositories) once provisioned.
|
|
20
|
+
For repositories, this is the repository slug (e.g., 'sec')
|
|
20
21
|
operation_id (Union[None, Unset, str]): SSE operation ID for monitoring provisioning progress
|
|
21
22
|
error (Union[None, Unset, str]): Error message if checkout failed
|
|
22
23
|
"""
|
|
@@ -20,7 +20,8 @@ class CreateCheckoutRequest:
|
|
|
20
20
|
Attributes:
|
|
21
21
|
plan_name (str): Billing plan name (e.g., 'kuzu-standard')
|
|
22
22
|
resource_type (str): Resource type ('graph' or 'repository')
|
|
23
|
-
resource_config (CreateCheckoutRequestResourceConfig): Configuration for the resource to be provisioned
|
|
23
|
+
resource_config (CreateCheckoutRequestResourceConfig): Configuration for the resource to be provisioned. For
|
|
24
|
+
repositories: {'repository_name': 'graph_id'} where graph_id is the repository slug (e.g., 'sec')
|
|
24
25
|
"""
|
|
25
26
|
|
|
26
27
|
plan_name: str
|
|
@@ -9,7 +9,10 @@ T = TypeVar("T", bound="CreateCheckoutRequestResourceConfig")
|
|
|
9
9
|
|
|
10
10
|
@_attrs_define
|
|
11
11
|
class CreateCheckoutRequestResourceConfig:
|
|
12
|
-
"""Configuration for the resource to be provisioned
|
|
12
|
+
"""Configuration for the resource to be provisioned. For repositories: {'repository_name': 'graph_id'} where graph_id
|
|
13
|
+
is the repository slug (e.g., 'sec')
|
|
14
|
+
|
|
15
|
+
"""
|
|
13
16
|
|
|
14
17
|
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
|
|
15
18
|
|
|
@@ -24,8 +24,7 @@ class CreateSubgraphRequest:
|
|
|
24
24
|
name (str): Alphanumeric name for the subgraph (e.g., dev, staging, prod1)
|
|
25
25
|
display_name (str): Human-readable display name for the subgraph
|
|
26
26
|
description (Union[None, Unset, str]): Optional description of the subgraph's purpose
|
|
27
|
-
schema_extensions (Union[
|
|
28
|
-
default)
|
|
27
|
+
schema_extensions (Union[Unset, list[str]]): Schema extensions to include (inherits from parent by default)
|
|
29
28
|
subgraph_type (Union[Unset, SubgraphType]): Types of subgraphs.
|
|
30
29
|
metadata (Union['CreateSubgraphRequestMetadataType0', None, Unset]): Additional metadata for the subgraph
|
|
31
30
|
"""
|
|
@@ -33,7 +32,7 @@ class CreateSubgraphRequest:
|
|
|
33
32
|
name: str
|
|
34
33
|
display_name: str
|
|
35
34
|
description: Union[None, Unset, str] = UNSET
|
|
36
|
-
schema_extensions: Union[
|
|
35
|
+
schema_extensions: Union[Unset, list[str]] = UNSET
|
|
37
36
|
subgraph_type: Union[Unset, SubgraphType] = UNSET
|
|
38
37
|
metadata: Union["CreateSubgraphRequestMetadataType0", None, Unset] = UNSET
|
|
39
38
|
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
|
|
@@ -53,13 +52,8 @@ class CreateSubgraphRequest:
|
|
|
53
52
|
else:
|
|
54
53
|
description = self.description
|
|
55
54
|
|
|
56
|
-
schema_extensions: Union[
|
|
57
|
-
if isinstance(self.schema_extensions, Unset):
|
|
58
|
-
schema_extensions = UNSET
|
|
59
|
-
elif isinstance(self.schema_extensions, list):
|
|
60
|
-
schema_extensions = self.schema_extensions
|
|
61
|
-
|
|
62
|
-
else:
|
|
55
|
+
schema_extensions: Union[Unset, list[str]] = UNSET
|
|
56
|
+
if not isinstance(self.schema_extensions, Unset):
|
|
63
57
|
schema_extensions = self.schema_extensions
|
|
64
58
|
|
|
65
59
|
subgraph_type: Union[Unset, str] = UNSET
|
|
@@ -113,22 +107,7 @@ class CreateSubgraphRequest:
|
|
|
113
107
|
|
|
114
108
|
description = _parse_description(d.pop("description", UNSET))
|
|
115
109
|
|
|
116
|
-
|
|
117
|
-
if data is None:
|
|
118
|
-
return data
|
|
119
|
-
if isinstance(data, Unset):
|
|
120
|
-
return data
|
|
121
|
-
try:
|
|
122
|
-
if not isinstance(data, list):
|
|
123
|
-
raise TypeError()
|
|
124
|
-
schema_extensions_type_0 = cast(list[str], data)
|
|
125
|
-
|
|
126
|
-
return schema_extensions_type_0
|
|
127
|
-
except: # noqa: E722
|
|
128
|
-
pass
|
|
129
|
-
return cast(Union[None, Unset, list[str]], data)
|
|
130
|
-
|
|
131
|
-
schema_extensions = _parse_schema_extensions(d.pop("schema_extensions", UNSET))
|
|
110
|
+
schema_extensions = cast(list[str], d.pop("schema_extensions", UNSET))
|
|
132
111
|
|
|
133
112
|
_subgraph_type = d.pop("subgraph_type", UNSET)
|
|
134
113
|
subgraph_type: Union[Unset, SubgraphType]
|