quantalogic 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/__init__.py +20 -0
- quantalogic/agent.py +638 -0
- quantalogic/agent_config.py +138 -0
- quantalogic/coding_agent.py +83 -0
- quantalogic/event_emitter.py +223 -0
- quantalogic/generative_model.py +226 -0
- quantalogic/interactive_text_editor.py +190 -0
- quantalogic/main.py +185 -0
- quantalogic/memory.py +217 -0
- quantalogic/model_names.py +19 -0
- quantalogic/print_event.py +66 -0
- quantalogic/prompts.py +99 -0
- quantalogic/server/__init__.py +3 -0
- quantalogic/server/agent_server.py +633 -0
- quantalogic/server/models.py +60 -0
- quantalogic/server/routes.py +117 -0
- quantalogic/server/state.py +199 -0
- quantalogic/server/static/js/event_visualizer.js +430 -0
- quantalogic/server/static/js/quantalogic.js +571 -0
- quantalogic/server/templates/index.html +134 -0
- quantalogic/tool_manager.py +68 -0
- quantalogic/tools/__init__.py +46 -0
- quantalogic/tools/agent_tool.py +88 -0
- quantalogic/tools/download_http_file_tool.py +64 -0
- quantalogic/tools/edit_whole_content_tool.py +70 -0
- quantalogic/tools/elixir_tool.py +240 -0
- quantalogic/tools/execute_bash_command_tool.py +116 -0
- quantalogic/tools/input_question_tool.py +57 -0
- quantalogic/tools/language_handlers/__init__.py +21 -0
- quantalogic/tools/language_handlers/c_handler.py +33 -0
- quantalogic/tools/language_handlers/cpp_handler.py +33 -0
- quantalogic/tools/language_handlers/go_handler.py +33 -0
- quantalogic/tools/language_handlers/java_handler.py +37 -0
- quantalogic/tools/language_handlers/javascript_handler.py +42 -0
- quantalogic/tools/language_handlers/python_handler.py +29 -0
- quantalogic/tools/language_handlers/rust_handler.py +33 -0
- quantalogic/tools/language_handlers/scala_handler.py +33 -0
- quantalogic/tools/language_handlers/typescript_handler.py +42 -0
- quantalogic/tools/list_directory_tool.py +123 -0
- quantalogic/tools/llm_tool.py +119 -0
- quantalogic/tools/markitdown_tool.py +105 -0
- quantalogic/tools/nodejs_tool.py +515 -0
- quantalogic/tools/python_tool.py +469 -0
- quantalogic/tools/read_file_block_tool.py +140 -0
- quantalogic/tools/read_file_tool.py +79 -0
- quantalogic/tools/replace_in_file_tool.py +300 -0
- quantalogic/tools/ripgrep_tool.py +353 -0
- quantalogic/tools/search_definition_names.py +419 -0
- quantalogic/tools/task_complete_tool.py +35 -0
- quantalogic/tools/tool.py +146 -0
- quantalogic/tools/unified_diff_tool.py +387 -0
- quantalogic/tools/write_file_tool.py +97 -0
- quantalogic/utils/__init__.py +17 -0
- quantalogic/utils/ask_user_validation.py +12 -0
- quantalogic/utils/download_http_file.py +77 -0
- quantalogic/utils/get_coding_environment.py +15 -0
- quantalogic/utils/get_environment.py +26 -0
- quantalogic/utils/get_quantalogic_rules_content.py +19 -0
- quantalogic/utils/git_ls.py +121 -0
- quantalogic/utils/read_file.py +54 -0
- quantalogic/utils/read_http_text_content.py +101 -0
- quantalogic/xml_parser.py +242 -0
- quantalogic/xml_tool_parser.py +99 -0
- quantalogic-0.2.0.dist-info/LICENSE +201 -0
- quantalogic-0.2.0.dist-info/METADATA +1034 -0
- quantalogic-0.2.0.dist-info/RECORD +68 -0
- quantalogic-0.2.0.dist-info/WHEEL +4 -0
- quantalogic-0.2.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,633 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
"""FastAPI server for the QuantaLogic agent."""
|
3
|
+
|
4
|
+
import asyncio
import functools
import json
import signal
import sys
import time
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from queue import Empty, Queue
from threading import Lock
from typing import Any, AsyncGenerator, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from loguru import logger
from pydantic import BaseModel
from rich.console import Console

from quantalogic.agent_config import (
    MODEL_NAME,
    create_agent,
    create_coding_agent,  # noqa: F401
    create_orchestrator_agent,  # noqa: F401
)
from quantalogic.print_event import console_print_events
|
34
|
+
|
35
|
+
# Route loguru output to stderr with a coloured, column-aligned layout.
# The default sink is removed first so every message is emitted only once.
_LOG_FORMAT = (
    "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
    "<level>{message}</level>"
)
logger.remove()
logger.add(sys.stderr, level="INFO", format=_LOG_FORMAT)

# Timeouts, expressed in seconds.
SHUTDOWN_TIMEOUT = 5.0  # seconds allowed for shutdown cleanup
VALIDATION_TIMEOUT = 30.0  # seconds a user-validation prompt waits for an answer
|
46
|
+
|
47
|
+
|
48
|
+
def handle_sigterm(signum, frame):
    """Turn a SIGTERM into a clean interpreter exit.

    Raising ``SystemExit(0)`` lets ``finally`` blocks and context managers
    unwind instead of the process being killed outright.

    Args:
        signum: Signal number delivered by the OS (unused).
        frame: Interrupted stack frame (unused).
    """
    logger.info("Received SIGTERM signal")
    exit_code = 0
    raise SystemExit(exit_code)


# Registered at import time. While the server runs, the FastAPI lifespan may
# install its own event-loop handler for SIGTERM on top of this one.
signal.signal(signal.SIGTERM, handle_sigterm)
|
55
|
+
|
56
|
+
|
57
|
+
def get_version() -> str:
    """Return a human-readable version string for the installed package.

    The version is read from the installed distribution metadata instead of
    being hard-coded, so it always matches the wheel actually installed.
    When the ``quantalogic`` distribution is not installed (for example when
    running from a bare source tree), the version is reported as "unknown".

    Returns:
        A string of the form ``"QuantaLogic version: <version>"``.
    """
    from importlib import metadata

    try:
        version = metadata.version("quantalogic")
    except metadata.PackageNotFoundError:
        version = "unknown"
    return f"QuantaLogic version: {version}"
|
60
|
+
|
61
|
+
|
62
|
+
class ServerState:
    """Global server state management.

    Tracks interrupt presses and whether a shutdown is in progress. When the
    module's ``__main__`` entry point registers the ``uvicorn.Server`` in
    ``server``, a graceful shutdown request also tells uvicorn to exit.
    """

    def __init__(self):
        """Initialize the global server state."""
        self.interrupt_count = 0
        self.force_exit = False
        self.is_shutting_down = False
        self.shutdown_initiated = asyncio.Event()
        self.shutdown_complete = asyncio.Event()
        # uvicorn.Server instance; stays None unless someone assigns it.
        self.server = None
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._background_tasks: set = set()

    async def initiate_shutdown(self, force: bool = False):
        """Initiate the shutdown process.

        Marks the server as shutting down (once, unless ``force`` is set) and,
        if a ``uvicorn.Server`` is registered, sets its ``should_exit`` flag so
        uvicorn runs its own graceful shutdown. With ``force=True`` the process
        exits immediately with status 1.

        This coroutine returns as soon as shutdown has been requested. It does
        not wait for ``shutdown_complete``, because the lifespan handler that
        sets that event does so only after calling this coroutine. Waiting here
        would block forever.

        Args:
            force: Exit the process immediately instead of shutting down gracefully.
        """
        if self.is_shutting_down and not force:
            return

        logger.info("Initiating server shutdown...")
        self.is_shutting_down = True
        self.force_exit = force
        self.shutdown_initiated.set()

        if self.server is not None:
            # Without this, intercepting SIGINT/SIGTERM in the lifespan means
            # uvicorn itself would never be asked to stop.
            self.server.should_exit = True

        if force:
            # Force exit immediately
            logger.warning("Forcing immediate shutdown...")
            sys.exit(1)

    def handle_interrupt(self):
        """Handle an interrupt signal.

        The first interrupt starts a graceful shutdown. Any further interrupt
        forces an immediate exit. Must be called while an event loop is running.
        """
        self.interrupt_count += 1
        if self.interrupt_count == 1:
            logger.info("Graceful shutdown initiated (press Ctrl+C again to force)")
            self._spawn(self.initiate_shutdown(force=False))
        else:
            logger.warning("Forced shutdown initiated...")
            self._spawn(self.initiate_shutdown(force=True))

    def _spawn(self, coro):
        """Schedule *coro* on the running loop and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task
|
97
|
+
|
98
|
+
|
99
|
+
# Models
|
100
|
+
class EventMessage(BaseModel):
    """One server-sent event as placed on a client's event queue.

    Fields:
        id: Unique identifier of this event.
        event: Event type; the SSE endpoint emits it as the ``event:`` line.
        task_id: Task the event belongs to, or ``None`` when not task-scoped.
        data: Free-form event payload.
        timestamp: Creation time; the broadcaster fills it with ``isoformat()``.
    """

    id: str
    event: str
    task_id: str | None = None
    data: dict[str, Any]
    timestamp: str

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
110
|
+
|
111
|
+
|
112
|
+
class UserValidationRequest(BaseModel):
    """A question that needs an explicit yes/no answer from the user."""

    question: str
    validation_id: Optional[str] = None  # identifier of the pending request, if known

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
119
|
+
|
120
|
+
|
121
|
+
class UserValidationResponse(BaseModel):
    """Body of ``POST /validate_response/{validation_id}``: the user's answer."""

    response: bool  # True when the user agrees, False otherwise

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
127
|
+
|
128
|
+
|
129
|
+
class TaskSubmission(BaseModel):
    """Body of ``POST /tasks``.

    ``model_name`` defaults to the package-wide ``MODEL_NAME``, and
    ``max_iterations`` (default 30) is forwarded to ``agent.solve_task``.
    """

    task: str
    model_name: str | None = MODEL_NAME
    max_iterations: int | None = 30

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
137
|
+
|
138
|
+
|
139
|
+
class TaskStatus(BaseModel):
    """Snapshot of a submitted task, as returned by the ``/tasks`` endpoints.

    Only ``task_id``, ``status`` and ``created_at`` are always present. The
    remaining fields stay ``None`` until the task reaches the stage that
    fills them in.
    """

    # No extra="forbid" here: task records also carry a "request" key that
    # must be ignored when they are splatted into this model.
    task_id: str
    status: str  # one of "pending", "running", "completed", "failed"
    created_at: str
    started_at: str | None = None
    completed_at: str | None = None
    result: str | None = None
    error: str | None = None
    total_tokens: int | None = None
    model_name: str | None = None
|
151
|
+
|
152
|
+
|
153
|
+
class AgentState:
    """Manages the shared agent, per-client SSE event queues and submitted tasks.

    Event queues are ``queue.Queue`` objects guarded by a ``threading.Lock``.
    Events may be broadcast from the worker thread that runs
    ``agent.solve_task`` (through the agent's event emitter) as well as from
    the event loop.
    """

    def __init__(self):
        """Initialize the agent state."""
        self.agent = None
        # client_id -> {task_id or "global": queue of EventMessage}
        self.event_queues: Dict[str, Dict[str, Queue]] = {}
        # client_id -> {task_id: bookkeeping dict} for task-scoped subscriptions
        self.active_agents: Dict[str, Dict[str, Any]] = {}
        # Not re-entrant: never call a method that acquires it (such as
        # broadcast_event) while already holding it.
        self.queue_lock = Lock()
        self.client_counter = 0
        self.console = Console()
        self.validation_requests: Dict[str, Dict[str, Any]] = {}
        self.validation_responses: Dict[str, asyncio.Queue] = {}
        self.tasks: Dict[str, Dict[str, Any]] = {}
        # NOTE(review): holds asyncio.Queue (from submit_task) or queue.Queue
        # (from get_task_event_queue); nothing in this module consumes them.
        self.task_queues: Dict[str, Any] = {}

    def add_client(self, task_id: Optional[str] = None) -> str:
        """Register a new SSE client and return its generated ID.

        With ``task_id`` the client gets one queue for that task. Otherwise it
        gets a ``"global"`` queue.

        Raises:
            ValueError: If an agent entry already exists for this client/task pair.
        """
        with self.queue_lock:
            client_id = f"client_{self.client_counter}"
            self.client_counter += 1

            client_queues = self.event_queues.setdefault(client_id, {})
            client_agents = self.active_agents.setdefault(client_id, {})

            if task_id:
                # Prevent multiple agents for the same client-task combination
                if task_id in client_agents:
                    raise ValueError(f"An agent already exists for client {client_id} and task {task_id}")
                client_queues[task_id] = Queue()
                client_agents[task_id] = {
                    "created_at": datetime.utcnow().isoformat(),
                    "status": "active",
                }
            else:
                client_queues["global"] = Queue()

            return client_id

    def remove_client(self, client_id: str, task_id: Optional[str] = None):
        """Remove a client's event queue, optionally only the one for ``task_id``.

        Once a client has no queues left, its entries are dropped completely,
        so short-lived task subscriptions do not leave empty dicts behind.
        """
        with self.queue_lock:
            client_queues = self.event_queues.get(client_id)
            if client_queues is None:
                self.active_agents.pop(client_id, None)
                return

            if task_id and task_id in client_queues:
                del client_queues[task_id]
                client_agents = self.active_agents.get(client_id)
                if client_agents is not None:
                    client_agents.pop(task_id, None)
                if client_queues:
                    # The client still has other subscriptions; keep its entry.
                    return

            # Remove the entire client entry and all of its active agents.
            del self.event_queues[client_id]
            self.active_agents.pop(client_id, None)

    def broadcast_event(
        self, event_type: str, data: Dict[str, Any], task_id: Optional[str] = None, client_id: Optional[str] = None
    ):
        """Push an event onto the matching client queues.

        - ``client_id``, if given, restricts delivery to that single client.
        - ``task_id``, if given, delivers the event to clients subscribed to
          that task. Clients subscribed to another task never see it.
        - Clients holding a ``"global"`` ("all events") queue receive every
          event, whether or not it is tied to a task.
        """
        event = EventMessage(
            id=str(uuid.uuid4()), event=event_type, task_id=task_id, data=data, timestamp=datetime.utcnow().isoformat()
        )

        with self.queue_lock:
            for curr_client_id, client_queues in self.event_queues.items():
                # Skip if specific client_id is provided and doesn't match
                if client_id and curr_client_id != client_id:
                    continue
                if task_id and task_id in client_queues:
                    client_queues[task_id].put(event)
                if "global" in client_queues:
                    client_queues["global"].put(event)

    def initialize_agent_with_sse_validation(self, model_name: str = MODEL_NAME):
        """Create the agent, forward its events to SSE clients and hook validation.

        Raises:
            Exception: Whatever ``create_agent`` or handler registration raises,
                after it has been logged.
        """
        try:
            self.agent = create_agent(model_name)

            # Comprehensive list of agent events to track
            agent_events = [
                "session_start",
                "session_end",
                "session_add_message",
                "task_solve_start",
                "task_solve_end",
                "task_think_start",
                "task_think_end",
                "task_complete",
                "tool_execution_start",
                "tool_execution_end",
                "tool_execute_validation_start",
                "tool_execute_validation_end",
                "memory_full",
                "memory_compacted",
                "memory_summary",
                "error_max_iterations_reached",
                "error_tool_execution",
                "error_model_response",
            ]

            # Setup event handlers. ``event=event`` binds the current name, which
            # avoids the late-binding closure pitfall.
            for event in agent_events:
                self.agent.event_emitter.on(event, lambda e, d, event=event: self._handle_event(event, d))

            # Override ask_for_user_validation with SSE-based method
            self.agent.ask_for_user_validation = self.sse_ask_for_user_validation

            logger.info(f"Agent initialized with model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize agent: {e}", exc_info=True)
            raise

    async def sse_ask_for_user_validation(self, question: str = "Do you want to continue?") -> bool:
        """Ask connected clients a yes/no question over SSE and await the answer.

        The answer arrives through ``POST /validate_response/{validation_id}``.
        Returns ``False`` if nobody answers within ``VALIDATION_TIMEOUT`` seconds.

        NOTE(review): this is a coroutine function. If the agent calls
        ``ask_for_user_validation`` synchronously (for example from the executor
        thread running ``solve_task``), it gets an un-awaited coroutine object,
        which is truthy. Confirm how agent.py invokes it.
        """
        validation_id = str(uuid.uuid4())
        response_queue = asyncio.Queue()

        # Store validation request and response queue
        self.validation_requests[validation_id] = {"question": question, "timestamp": datetime.now().isoformat()}
        self.validation_responses[validation_id] = response_queue

        # Broadcast validation request
        self.broadcast_event("user_validation_request", {"validation_id": validation_id, "question": question})

        try:
            # Wait for response with timeout
            async with asyncio.timeout(VALIDATION_TIMEOUT):
                response = await response_queue.get()
                return response
        except TimeoutError:
            logger.warning(f"Validation request timed out: {validation_id}")
            return False
        finally:
            # Cleanup
            self.validation_requests.pop(validation_id, None)
            self.validation_responses.pop(validation_id, None)

    def _handle_event(self, event_type: str, data: Dict[str, Any]):
        """Print an agent event to the server console, log it and broadcast it.

        Errors are logged rather than raised, so a failing handler never
        interrupts the agent that emitted the event.
        """
        try:
            # Print events to server console
            console_print_events(event_type, data)

            # Log event details
            logger.info(f"Agent Event: {event_type}")
            logger.debug(f"Event Data: {data}")

            # Broadcast to clients
            self.broadcast_event(event_type, data)

        except Exception as e:
            logger.error(f"Error in event handling: {e}", exc_info=True)

    def get_current_model_name(self) -> str:
        """Return the active agent's model name, or ``MODEL_NAME`` if there is none."""
        if self.agent and self.agent.model:
            return self.agent.model.model
        return MODEL_NAME

    async def cleanup(self):
        """Clean up resources during shutdown.

        Broadcasts ``server_shutdown``, then drops every client queue, pending
        validation and the agent. If a forced exit was requested, graceful
        cleanup is skipped and the process exits with status 1.
        """
        try:
            logger.info("Cleaning up resources...")
            if server_state.force_exit:
                logger.warning("Forced cleanup - skipping graceful shutdown")
                return

            async with asyncio.timeout(SHUTDOWN_TIMEOUT):
                # Notify all clients *before* taking queue_lock. broadcast_event
                # acquires the same non-reentrant Lock, so calling it while
                # holding the lock would deadlock the event loop.
                self.broadcast_event("server_shutdown", {"message": "Server is shutting down"})
                with self.queue_lock:
                    # Clear queues
                    self.event_queues.clear()
                    self.active_agents.clear()
                    self.validation_requests.clear()
                    self.validation_responses.clear()
                # Clear agent
                self.agent = None
                logger.info("Cleanup completed")
        except TimeoutError:
            logger.warning(f"Cleanup timed out after {SHUTDOWN_TIMEOUT} seconds")
        except Exception as e:
            logger.error(f"Error during cleanup: {e}", exc_info=True)
        finally:
            self.agent = None
            if server_state.force_exit:
                sys.exit(1)

    async def submit_task(self, task_request: TaskSubmission) -> str:
        """Record a new pending task and return its generated ID."""
        task_id = str(uuid.uuid4())
        self.tasks[task_id] = {
            "status": "pending",
            "created_at": datetime.now().isoformat(),
            # model_dump() is the pydantic v2 spelling of the deprecated .dict().
            "request": task_request.model_dump(),
        }
        self.task_queues[task_id] = asyncio.Queue()
        return task_id

    async def execute_task(self, task_id: str):
        """Run a submitted task on a worker thread and record its outcome.

        The status moves pending -> running -> completed or failed. A
        ``task_complete`` or ``task_error`` event is broadcast with the task's
        ID, so it reaches both that task's subscribers and global subscribers.
        """
        task = self.tasks.get(task_id)
        if task is None:
            # Look the task up outside the try block so the error handler below
            # never touches an unbound ``task``.
            logger.error(f"Cannot execute unknown task: {task_id}")
            return

        try:
            task["status"] = "running"
            task["started_at"] = datetime.now().isoformat()

            # NOTE(review): one agent is shared by all tasks, so a task's own
            # model_name only matters for the task that first creates it, and
            # concurrent tasks call solve_task on the same agent from
            # different threads — confirm the agent tolerates this.
            if not self.agent:
                self.initialize_agent_with_sse_validation(task["request"]["model_name"])

            # solve_task is blocking, so run it on the default thread-pool executor.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                functools.partial(
                    self.agent.solve_task, task["request"]["task"], max_iterations=task["request"]["max_iterations"]
                ),
            )

            # Update task status
            task["status"] = "completed"
            task["completed_at"] = datetime.now().isoformat()
            task["result"] = result
            task["total_tokens"] = self.agent.total_tokens
            task["model_name"] = self.get_current_model_name()

            self.broadcast_event(
                "task_complete",
                {
                    "task_id": task_id,
                    "result": result,
                    "total_tokens": task["total_tokens"],
                    "model_name": task["model_name"],
                },
                task_id=task_id,
            )

        except Exception as e:
            logger.error(f"Task execution failed: {e}", exc_info=True)
            task["status"] = "failed"
            task["completed_at"] = datetime.now().isoformat()
            task["error"] = str(e)

            self.broadcast_event("task_error", {"task_id": task_id, "error": str(e)}, task_id=task_id)

    async def get_task_event_queue(self, task_id: str) -> Queue:
        """Return the task's event queue, creating a ``queue.Queue`` if missing."""
        with self.queue_lock:
            if task_id not in self.task_queues:
                self.task_queues[task_id] = Queue()
            return self.task_queues[task_id]

    def remove_task_event_queue(self, task_id: str):
        """Remove a task-specific event queue, if one exists."""
        with self.queue_lock:
            if task_id in self.task_queues:
                del self.task_queues[task_id]
                logger.info(f"Removed event queue for task_id: {task_id}")
|
433
|
+
|
434
|
+
|
435
|
+
# Process-wide singletons shared by the lifespan handler and the routes below.
agent_state = AgentState()  # agent, SSE client queues and task records
server_state = ServerState()  # interrupt / shutdown bookkeeping
|
438
|
+
|
439
|
+
|
440
|
+
# Initialize FastAPI app
|
441
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle manager for the FastAPI app.

    On startup, routes SIGTERM and SIGINT to ``handle_shutdown`` through the
    running event loop, where the platform supports it. On shutdown, flags the
    server as shutting down, runs ``agent_state.cleanup()`` and finally sets
    ``server_state.shutdown_complete``.
    """
    # Signal-handler tasks are held here so they are not garbage-collected mid-run.
    signal_tasks: set = set()

    def _on_signal(sig):
        task = asyncio.create_task(handle_shutdown(sig))
        signal_tasks.add(task)
        task.add_done_callback(signal_tasks.discard)

    try:
        # Setup signal handlers
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            try:
                loop.add_signal_handler(sig, _on_signal, sig)
            except NotImplementedError:
                # add_signal_handler only exists on Unix event loops.
                logger.warning(f"Cannot install event-loop handler for {sig!r} on this platform")
        yield
    finally:
        logger.info("Shutting down server gracefully...")
        # initiate_shutdown() may wait for shutdown_complete, which is set only
        # below, so awaiting it directly here could block forever. Run it as a
        # task, yield once so it flags the shutdown, and collect it at the end.
        shutdown_task = asyncio.create_task(server_state.initiate_shutdown())
        await asyncio.sleep(0)
        try:
            await agent_state.cleanup()
        finally:
            server_state.shutdown_complete.set()
        await shutdown_task
        logger.info("Server shutdown complete")
|
456
|
+
|
457
|
+
|
458
|
+
async def handle_shutdown(sig):
    """Route a SIGINT/SIGTERM received through the event loop.

    A SIGINT that arrives after at least one earlier interrupt forces an
    immediate exit. Every other signal is counted by
    ``server_state.handle_interrupt``.
    """
    repeated_ctrl_c = sig == signal.SIGINT and server_state.interrupt_count >= 1
    if not repeated_ctrl_c:
        server_state.handle_interrupt()
        return
    # Repeated Ctrl+C: skip the graceful path entirely.
    await server_state.initiate_shutdown(force=True)
|
465
|
+
|
466
|
+
|
467
|
+
app = FastAPI(
    title="QuantaLogic API",
    description="AI Agent Server for QuantaLogic",
    version="0.1.0",
    lifespan=lifespan,
)

# Add CORS middleware.
# NOTE(review): any origin may call every endpoint (including POST /tasks);
# consider restricting allow_origins if the server is reachable beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static assets and templates ship next to this module. Resolving them from
# __file__ rather than the current working directory lets the server start
# from any directory, including when installed from the wheel.
_SERVER_DIR = Path(__file__).resolve().parent

# Mount static files
app.mount("/static", StaticFiles(directory=str(_SERVER_DIR / "static")), name="static")

# Configure Jinja2 templates
templates = Jinja2Templates(directory=str(_SERVER_DIR / "templates"))
|
488
|
+
|
489
|
+
|
490
|
+
# Middleware to log requests
|
491
|
+
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log the path, method, duration and status of every HTTP request at DEBUG level.

    The duration uses ``time.perf_counter`` (monotonic, high resolution), so
    system clock adjustments cannot make it negative or wrong.
    """
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time

    logger.debug(
        f"Path: {request.url.path} "
        f"Method: {request.method} "
        f"Time: {process_time:.3f}s "
        f"Status: {response.status_code}"
    )

    return response
|
506
|
+
|
507
|
+
|
508
|
+
@app.post("/validate_response/{validation_id}")
async def submit_validation_response(validation_id: str, response: UserValidationResponse):
    """Deliver the user's answer to a pending validation request.

    Raises:
        HTTPException: 404 if no request with ``validation_id`` is pending,
            500 if the answer could not be queued.
    """
    pending_queue = agent_state.validation_responses.get(validation_id)
    if pending_queue is None:
        raise HTTPException(status_code=404, detail="Validation request not found")

    try:
        await pending_queue.put(response.response)
    except Exception as e:
        logger.error(f"Error processing validation response: {e}")
        raise HTTPException(status_code=500, detail="Failed to process validation response")
    return JSONResponse(content={"status": "success"})
|
521
|
+
|
522
|
+
|
523
|
+
@app.get("/events")
async def event_stream(request: Request, task_id: Optional[str] = None) -> StreamingResponse:
    """SSE endpoint for streaming agent events.

    Without ``task_id`` the client reads its ``"global"`` queue. With it, the
    client reads the queue for that task. An empty queue is polled every
    ``poll_interval`` seconds. A keepalive comment goes out immediately and
    then after every ``keepalive_interval`` seconds of silence. A final
    ``shutdown`` event is sent once the server starts shutting down.
    """
    poll_interval = 0.1  # seconds between polls of an empty queue
    keepalive_interval = 15.0  # seconds of silence before a keepalive comment

    async def event_generator() -> AsyncGenerator[str, None]:
        # Ensure unique client-task combination
        client_id = agent_state.add_client(task_id)
        logger.info(f"Client {client_id} subscribed to {'task_id: ' + task_id if task_id else 'all events'}")

        queue_key = task_id if task_id else "global"
        # -inf makes the first idle poll send a keepalive straight away.
        last_sent = float("-inf")

        try:
            while not server_state.is_shutting_down:
                if await request.is_disconnected():
                    break

                try:
                    event = agent_state.event_queues[client_id][queue_key].get_nowait()
                except Empty:
                    # Keep the connection alive without flooding the client on
                    # every poll.
                    now = time.monotonic()
                    if now - last_sent >= keepalive_interval:
                        yield ": keepalive\n\n"
                        last_sent = now
                    await asyncio.sleep(poll_interval)
                    continue
                except KeyError:
                    # The client's queue vanished (e.g. cleared during server cleanup).
                    break

                # The payload is Dict[str, Any] coming from the agent. default=str keeps a
                # single non-JSON value from raising and ending the whole stream.
                payload = json.dumps(event.model_dump(), default=str)
                yield f"event: {event.event}\ndata: {payload}\n\n"
                last_sent = time.monotonic()

            if server_state.is_shutting_down:
                yield 'event: shutdown\ndata: {"message": "Server shutting down"}\n\n'

        finally:
            # Clean up the client's event queue
            agent_state.remove_client(client_id, task_id)
            logger.info(f"Client {client_id} {'unsubscribed from task_id: ' + task_id if task_id else 'disconnected'}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        },
    )
|
571
|
+
|
572
|
+
|
573
|
+
@app.get("/")
async def get_index(request: Request) -> HTMLResponse:
    """Render ``index.html`` with headers that disable browser and proxy caching."""
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    page = templates.TemplateResponse("index.html", {"request": request})
    for header_name, header_value in no_cache_headers.items():
        page.headers[header_name] = header_value
    return page
|
581
|
+
|
582
|
+
|
583
|
+
# Strong references to running task executions. asyncio keeps only weak
# references to tasks, so an otherwise unreferenced task may be
# garbage-collected before it completes.
_running_task_executions: set = set()


@app.post("/tasks")
async def submit_task(request: TaskSubmission) -> Dict[str, str]:
    """Register a task, start executing it in the background and return its ID."""
    task_id = await agent_state.submit_task(request)
    # Start task execution in background, keeping the task referenced until it ends.
    execution = asyncio.create_task(agent_state.execute_task(task_id))
    _running_task_executions.add(execution)
    execution.add_done_callback(_running_task_executions.discard)
    return {"task_id": task_id}
|
590
|
+
|
591
|
+
|
592
|
+
@app.get("/tasks/{task_id}")
async def get_task_status(task_id: str) -> TaskStatus:
    """Return the current status record of one task.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown.
    """
    record = agent_state.tasks.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return TaskStatus(task_id=task_id, **record)
|
600
|
+
|
601
|
+
|
602
|
+
@app.get("/tasks")
async def list_tasks(status: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[TaskStatus]:
    """List tasks in submission order, optionally filtered by status.

    Args:
        status: Only include tasks whose status equals this value.
        limit: Maximum number of tasks to return; must be >= 0.
        offset: Number of matching tasks to skip; must be >= 0.

    Raises:
        HTTPException: 422 if ``limit`` or ``offset`` is negative.
    """
    if limit < 0 or offset < 0:
        # Negative values would otherwise hit Python's negative slicing and
        # silently return an unrelated window of tasks.
        raise HTTPException(status_code=422, detail="limit and offset must be non-negative")

    matching = [
        (task_id, task)
        for task_id, task in agent_state.tasks.items()
        if status is None or task["status"] == status
    ]
    # Only the requested page is converted into TaskStatus models.
    page = matching[offset : offset + limit]
    return [TaskStatus(task_id=task_id, **task) for task_id, task in page]
|
611
|
+
|
612
|
+
|
613
|
+
# Make AgentState.initialize_agent use SSE-based user validation by default.
AgentState.initialize_agent = AgentState.initialize_agent_with_sse_validation

if __name__ == "__main__":
    # Pass the app object rather than an import string. When run as a script
    # this module is ``__main__``, and an import string would make uvicorn load
    # a second copy of the module with its own server_state/agent_state. The
    # ``server_state.server`` assignment below would then be invisible to the
    # running app. (This module lives at quantalogic/server/agent_server.py.)
    # ``reload`` is not set: only ``uvicorn.run`` honours it, and
    # ``uvicorn.Server.run`` ignores it.
    # NOTE(review): host="0.0.0.0" exposes the agent on every network interface.
    config = uvicorn.Config(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
        timeout_keep_alive=5,
        access_log=True,
        timeout_graceful_shutdown=5,
    )
    server = uvicorn.Server(config)
    server_state.server = server
    try:
        server.run()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
        sys.exit(1)
|
@@ -0,0 +1,60 @@
|
|
1
|
+
"""Pydantic models for the QuantaLogic API."""
|
2
|
+
|
3
|
+
from typing import Any, Dict, Optional
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from quantalogic.agent_config import MODEL_NAME
|
8
|
+
|
9
|
+
|
10
|
+
class EventMessage(BaseModel):
    """One server-sent event.

    Fields:
        id: Unique identifier of this event.
        event: Event type name.
        task_id: Task the event belongs to, or ``None`` when not task-scoped.
        data: Free-form event payload.
        timestamp: Creation time as a string.
    """

    id: str
    event: str
    task_id: str | None = None
    data: dict[str, Any]
    timestamp: str

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
20
|
+
|
21
|
+
|
22
|
+
class UserValidationRequest(BaseModel):
    """A question that needs an explicit yes/no answer from the user."""

    question: str
    validation_id: Optional[str] = None  # identifier of the pending request, if known

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
29
|
+
|
30
|
+
|
31
|
+
class UserValidationResponse(BaseModel):
    """The user's yes/no answer to a validation request."""

    response: bool  # the boolean answer supplied by the user

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
37
|
+
|
38
|
+
|
39
|
+
class TaskSubmission(BaseModel):
    """Request body for submitting a task.

    ``model_name`` defaults to the package-wide ``MODEL_NAME`` and
    ``max_iterations`` defaults to 30.
    """

    task: str
    model_name: str | None = MODEL_NAME
    max_iterations: int | None = 30

    # Unknown fields are rejected instead of being silently dropped.
    model_config = {"extra": "forbid"}
|
47
|
+
|
48
|
+
|
49
|
+
class TaskStatus(BaseModel):
    """Snapshot of a submitted task.

    Only ``task_id``, ``status`` and ``created_at`` are required. The other
    fields default to ``None`` until they become known.
    """

    task_id: str
    status: str  # one of "pending", "running", "completed", "failed"
    created_at: str
    started_at: str | None = None
    completed_at: str | None = None
    result: str | None = None
    error: str | None = None
    total_tokens: int | None = None
    model_name: str | None = None
|