tarang 4.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tarang/__init__.py +23 -0
- tarang/cli.py +1168 -0
- tarang/client/__init__.py +19 -0
- tarang/client/api_client.py +701 -0
- tarang/client/auth.py +178 -0
- tarang/context/__init__.py +41 -0
- tarang/context/bm25.py +218 -0
- tarang/context/chunker.py +984 -0
- tarang/context/graph.py +464 -0
- tarang/context/indexer.py +514 -0
- tarang/context/retriever.py +270 -0
- tarang/context/skeleton.py +282 -0
- tarang/context_collector.py +449 -0
- tarang/executor/__init__.py +6 -0
- tarang/executor/diff_apply.py +246 -0
- tarang/executor/linter.py +184 -0
- tarang/stream.py +1346 -0
- tarang/ui/__init__.py +7 -0
- tarang/ui/console.py +407 -0
- tarang/ui/diff_viewer.py +146 -0
- tarang/ui/formatter.py +1151 -0
- tarang/ui/keyboard.py +197 -0
- tarang/ws/__init__.py +14 -0
- tarang/ws/client.py +464 -0
- tarang/ws/executor.py +638 -0
- tarang/ws/handlers.py +590 -0
- tarang-4.4.0.dist-info/METADATA +102 -0
- tarang-4.4.0.dist-info/RECORD +31 -0
- tarang-4.4.0.dist-info/WHEEL +5 -0
- tarang-4.4.0.dist-info/entry_points.txt +2 -0
- tarang-4.4.0.dist-info/top_level.txt +1 -0
tarang/ui/keyboard.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Keyboard handler for interactive controls during execution.
|
|
3
|
+
|
|
4
|
+
ESC - Terminate current execution
|
|
5
|
+
SPACE - Pause and add extra instruction
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import sys
|
|
9
|
+
import threading
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Optional, Callable
|
|
13
|
+
import select
|
|
14
|
+
|
|
15
|
+
# Try to import platform-specific modules
|
|
16
|
+
try:
|
|
17
|
+
import termios
|
|
18
|
+
import tty
|
|
19
|
+
HAS_TERMIOS = True
|
|
20
|
+
except ImportError:
|
|
21
|
+
HAS_TERMIOS = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class KeyAction(Enum):
    """Actions that a key press can request from the running execution."""

    # No key action is pending.
    NONE = "none"
    # ESC pressed.
    CANCEL = "cancel"
    # SPACE pressed.
    PAUSE = "pause"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class KeyboardState:
    """State shared between the key-listener thread and its consumer.

    ``action`` and ``extra_instruction`` are read and written under
    ``_lock`` by the methods below. The remaining underscore-prefixed
    fields are bookkeeping slots (run flag, worker thread, saved terminal
    attributes) that this class itself never touches.
    """

    action: KeyAction = KeyAction.NONE
    extra_instruction: Optional[str] = None
    _lock: threading.Lock = field(default_factory=threading.Lock)
    _running: bool = False
    _thread: Optional[threading.Thread] = None
    _original_settings: Optional[list] = None

    def reset(self):
        """Clear the pending action and instruction before a new execution."""
        with self._lock:
            self.action, self.extra_instruction = KeyAction.NONE, None

    def set_cancel(self):
        """Record that cancellation was requested."""
        with self._lock:
            self.action = KeyAction.CANCEL

    def set_pause(self, instruction: str = None):
        """Record a pause request, optionally carrying an instruction."""
        with self._lock:
            self.action, self.extra_instruction = KeyAction.PAUSE, instruction

    def get_action(self) -> KeyAction:
        """Return the pending action without clearing it (thread-safe)."""
        with self._lock:
            pending = self.action
        return pending

    def consume_action(self) -> KeyAction:
        """Return the pending action and reset it to NONE (thread-safe)."""
        with self._lock:
            pending, self.action = self.action, KeyAction.NONE
        return pending
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class KeyboardMonitor:
    """
    Monitor the keyboard for ESC and SPACE during execution.

    A daemon thread puts stdin into cbreak mode and records key presses on
    ``self.state``; the caller polls that state from its own loop.

    Usage:
        monitor = KeyboardMonitor(console)
        monitor.start()

        while executing:
            action = monitor.state.consume_action()
            if action == KeyAction.CANCEL:
                break
            elif action == KeyAction.PAUSE:
                extra = monitor.prompt_extra_instruction()
                ...

        monitor.stop()
    """

    # Seconds to wait after an ESC byte for the remainder of an escape
    # sequence (arrow keys, function keys, Alt+key all start with ESC).
    _ESC_SEQUENCE_WAIT = 0.05

    def __init__(self, console=None, on_status: Optional[Callable[[str], None]] = None):
        """
        Initialize keyboard monitor.

        Args:
            console: Console object used to print the instruction prompt
                (optional; only its ``print`` method is called).
            on_status: Callback for status messages; defaults to a no-op.
        """
        self.console = console
        self.on_status = on_status or (lambda x: None)
        self.state = KeyboardState()
        self._stop_event = threading.Event()

    def start(self):
        """Start keyboard monitoring.

        Does nothing when termios is unavailable or when the attributes of
        stdin cannot be read (for example because stdin is not a terminal).
        Calling it while monitoring is already active only resets the
        pending action.
        """
        if not HAS_TERMIOS:
            # termios could not be imported - skip keyboard monitoring
            return

        self.state.reset()
        self._stop_event.clear()

        worker = self.state._thread
        if (
            worker is not None
            and worker.is_alive()
            and worker is not threading.current_thread()
        ):
            if self.state._running:
                # Already monitoring. Starting again would save the cbreak
                # attributes as the "original" ones and spawn a second
                # reader competing for the same bytes.
                return
            # A previous stop() is still winding down; give it a moment.
            worker.join(timeout=0.5)

        # Save terminal settings so stop() can restore them.
        try:
            self.state._original_settings = termios.tcgetattr(sys.stdin)
        except Exception:
            # Deliberate best effort: without a usable terminal, monitoring
            # is simply unavailable.
            return

        # Start monitor thread
        self.state._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.state._running = True
        self.state._thread.start()

    def stop(self):
        """Stop keyboard monitoring and restore terminal."""
        self.state._running = False
        self._stop_event.set()

        # Wait for the reader first so the restore below is the last change
        # made to the terminal attributes.
        worker = self.state._thread
        if (
            worker is not None
            and worker.is_alive()
            and worker is not threading.current_thread()
        ):
            worker.join(timeout=0.5)

        # Restore terminal settings
        if HAS_TERMIOS and self.state._original_settings:
            try:
                termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self.state._original_settings)
            except Exception:
                pass  # best effort: the terminal may already be gone

    @staticmethod
    def _read_pending(fd: int, wait: float = _ESC_SEQUENCE_WAIT) -> bytes:
        """Return bytes that arrive on *fd* within *wait* seconds.

        After the first chunk, only bytes that are already available are
        read, so the call never blocks longer than *wait*.
        """
        import os  # local import: the module's import block is left untouched

        chunks = []
        timeout = wait
        while select.select([fd], [], [], timeout)[0]:
            chunk = os.read(fd, 32)
            if not chunk:  # EOF
                break
            chunks.append(chunk)
            timeout = 0
        return b"".join(chunks)

    def _monitor_loop(self):
        """Background thread to monitor keyboard input."""
        if not HAS_TERMIOS:
            return

        import os  # local import: the module's import block is left untouched

        try:
            fd = sys.stdin.fileno()
            # cbreak mode delivers single characters without waiting for Enter
            tty.setcbreak(fd)

            while self.state._running and not self._stop_event.is_set():
                # Check if input is available (with timeout)
                if not select.select([fd], [], [], 0.1)[0]:
                    continue

                # Read straight from the descriptor: a buffered
                # sys.stdin.read(1) may pull extra bytes into Python's
                # buffer, where select() can no longer see them.
                key = os.read(fd, 1)
                if not key:
                    # EOF: select() would report "readable" forever.
                    break

                if key == b"\x1b":  # ESC, or the start of an escape sequence
                    tail = self._read_pending(fd)
                    if tail.strip(b"\x1b"):
                        # Other bytes followed: this was an arrow/function
                        # key sequence, not a bare ESC press.
                        continue
                    self.state.set_cancel()
                    self.on_status("[yellow]ESC pressed - cancelling...[/yellow]")

                elif key == b" ":  # SPACE
                    self.state.set_pause()
                    self.on_status("[cyan]SPACE pressed - pausing for instruction...[/cyan]")

        except Exception:
            # Monitoring is best effort; keep the reason discoverable
            # instead of discarding it silently.
            import logging

            logging.getLogger(__name__).debug("Keyboard monitor stopped", exc_info=True)
        finally:
            # Restore terminal
            if self.state._original_settings:
                try:
                    termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self.state._original_settings)
                except Exception:
                    pass

    def prompt_extra_instruction(self) -> Optional[str]:
        """
        Prompt user for extra instruction after SPACE.

        Monitoring is stopped while the line is read and started again
        afterwards.

        Returns:
            Extra instruction string, or None if the input was empty or
            was interrupted (Ctrl-C / EOF).
        """
        # Temporarily stop monitoring to get clean input
        self.stop()

        try:
            if self.console:
                self.console.print("\n[bold cyan]Add instruction:[/bold cyan] ", end="")

            instruction = input().strip()
            return instruction if instruction else None

        except (KeyboardInterrupt, EOFError):
            return None
        finally:
            # Resume monitoring
            self.start()
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def create_keyboard_hints() -> str:
    """Return the dim-styled hint line listing the ESC and SPACE shortcuts."""
    shortcuts = "ESC=cancel SPACE=add instruction"
    return "[dim]" + shortcuts + "[/dim]"
|
tarang/ws/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket module for hybrid agent architecture.
|
|
3
|
+
|
|
4
|
+
This module provides:
|
|
5
|
+
- TarangWSClient: WebSocket client for bidirectional communication
|
|
6
|
+
- ToolExecutor: Local tool execution (file ops, shell)
|
|
7
|
+
- MessageHandlers: Handle different message types from backend
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from tarang.ws.client import TarangWSClient
|
|
11
|
+
from tarang.ws.executor import ToolExecutor
|
|
12
|
+
from tarang.ws.handlers import MessageHandlers
|
|
13
|
+
|
|
14
|
+
__all__ = ["TarangWSClient", "ToolExecutor", "MessageHandlers"]
|
tarang/ws/client.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket Client for Hybrid Agent Architecture.
|
|
3
|
+
|
|
4
|
+
Manages bidirectional WebSocket communication with the Tarang backend:
|
|
5
|
+
- Receives tool requests from backend
|
|
6
|
+
- Executes tools locally
|
|
7
|
+
- Sends results back to backend
|
|
8
|
+
- Handles progress events and UI updates
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from enum import Enum
|
|
17
|
+
from typing import Any, AsyncIterator, Callable, Dict, Optional
|
|
18
|
+
|
|
19
|
+
import websockets
|
|
20
|
+
from websockets.client import WebSocketClientProtocol
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class EventType(str, Enum):
    """Message types carried in the ``type`` field of a WebSocket event.

    Members are also ``str`` instances, so they compare equal to the raw
    wire value.
    """

    # Connection handshake
    CONNECTED = "connected"

    # Work in progress
    THINKING = "thinking"
    TOOL_REQUEST = "tool_request"
    APPROVAL_REQUEST = "approval_request"
    PHASE_START = "phase_start"
    MILESTONE_UPDATE = "milestone_update"
    PROGRESS = "progress"

    # Outcome of an execution
    COMPLETE = "complete"
    ERROR = "error"
    PAUSED = "paused"

    # Keep-alive traffic
    HEARTBEAT = "heartbeat"
    PONG = "pong"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class WSEvent:
    """A WebSocket event from backend."""

    type: EventType
    data: Dict[str, Any] = field(default_factory=dict)
    request_id: Optional[str] = None

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "WSEvent":
        """Create event from a decoded JSON message.

        Args:
            data: The decoded message. ``type`` selects the event type,
                ``data`` holds the payload and ``request_id`` is optional.

        Returns:
            The event. A missing or unrecognised ``type`` maps to
            ``EventType.ERROR``. When the message has no ``data`` payload
            (key absent or explicitly null) the whole message is used as
            the payload, so ``event.data`` is never ``None``.
        """
        try:
            etype = EventType(data.get("type", ""))
        except ValueError:
            etype = EventType.ERROR

        payload = data.get("data")
        if payload is None:
            # Covers both a missing key and an explicit JSON null.
            payload = data

        return cls(
            type=etype,
            data=payload,
            request_id=data.get("request_id"),
        )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Type alias for tool executor callback: takes a str and a dict and may
# return any value.
# NOTE(review): presumably (tool name, tool arguments) - confirm against callers.
ToolExecutorCallback = Callable[[str, Dict[str, Any]], Any]

# Type alias for approval callback: same argument types as above
# (returns True if approved)
ApprovalCallback = Callable[[str, Dict[str, Any]], bool]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class TarangWSClient:
    """
    WebSocket client for hybrid agent communication.

    Usage:
        async with TarangWSClient(base_url, token, openrouter_key) as client:
            async for event in client.execute(instruction, cwd):
                # Handle event
                if event.type == EventType.TOOL_REQUEST:
                    result = execute_tool(event.data)
                    await client.send_tool_result(event.request_id, result)
    """

    DEFAULT_BASE_URL = "wss://tarang-backend-intl-web-app-production.up.railway.app"

    def __init__(
        self,
        base_url: Optional[str] = None,
        token: Optional[str] = None,
        openrouter_key: Optional[str] = None,
        reconnect_attempts: int = 3,
        reconnect_delay: float = 2.0,
        auto_reconnect: bool = True,
    ):
        """
        Args:
            base_url: Backend URL; ``http(s)://`` is rewritten to ``ws(s)://``.
                Defaults to ``DEFAULT_BASE_URL``.
            token: Auth token sent as the ``token`` query parameter.
            openrouter_key: Key sent as the ``openrouter_key`` query parameter.
            reconnect_attempts: Number of connection attempts per
                ``connect()`` call (and per ``reconnect()`` call).
            reconnect_delay: Base delay in seconds between attempts.
            auto_reconnect: Whether ``execute()`` tries to reconnect and
                resume when the connection drops.
        """
        self.base_url = (base_url or self.DEFAULT_BASE_URL).replace("https://", "wss://").replace("http://", "ws://")
        self.token = token
        self.openrouter_key = openrouter_key
        self.reconnect_attempts = reconnect_attempts
        self.reconnect_delay = reconnect_delay
        self.auto_reconnect = auto_reconnect

        self._ws: Optional[WebSocketClientProtocol] = None
        self._session_id: Optional[str] = None
        self._connected = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._current_job_id: Optional[str] = None
        self._reconnect_callback: Optional[Callable[[], None]] = None

    @property
    def session_id(self) -> Optional[str]:
        """Session ID reported by the backend, or None before connecting."""
        return self._session_id

    @property
    def current_job_id(self) -> Optional[str]:
        """Job ID of the execution in progress, if one is known."""
        return self._current_job_id

    @property
    def is_connected(self) -> bool:
        """True when the handshake completed and a socket is held."""
        return self._connected and self._ws is not None

    def set_reconnect_callback(self, callback: Callable[[], None]):
        """Set callback to be called on reconnection."""
        self._reconnect_callback = callback

    def _build_ws_url(self) -> str:
        """Return the agent endpoint URL with URL-encoded credentials."""
        from urllib.parse import urlencode  # local: import block left untouched

        # Encoding keeps characters such as '&', '=', '+' or '#' inside a
        # credential from corrupting the query string.
        query = urlencode({"token": self.token, "openrouter_key": self.openrouter_key})
        return f"{self.base_url}/v2/ws/agent?{query}"

    async def _close_connection(self) -> None:
        """Cancel the heartbeat and close the socket, ignoring close errors."""
        self._connected = False

        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            self._heartbeat_task = None

        ws, self._ws = self._ws, None
        if ws is not None:
            try:
                await ws.close()
            except Exception:
                pass  # best effort: the peer may already be gone

    async def connect(self) -> str:
        """
        Connect to the WebSocket endpoint.

        Makes up to ``reconnect_attempts`` attempts, sleeping
        ``reconnect_delay`` seconds between them. Any connection still held
        from an earlier call is closed first.

        Returns:
            Session ID from backend

        Raises:
            ValueError: If the token or the OpenRouter key is missing.
            ConnectionError: If every attempt failed.
        """
        if not self.token:
            raise ValueError("Token is required")

        if not self.openrouter_key:
            raise ValueError("OpenRouter key is required")

        # Build WebSocket URL
        ws_url = self._build_ws_url()

        # Credentials are deliberately left out of the log line.
        logger.debug("Connecting to %s/v2/ws/agent", self.base_url)

        # Never stack a new socket/heartbeat on top of an old one.
        await self._close_connection()

        # Connect with retry
        last_error = None
        for attempt in range(self.reconnect_attempts):
            try:
                self._ws = await websockets.connect(
                    ws_url,
                    ping_interval=30,
                    ping_timeout=60,  # Increased for slow LLM responses
                    close_timeout=10,
                )

                # Wait for connected event
                raw = await asyncio.wait_for(self._ws.recv(), timeout=10.0)
                data = json.loads(raw)

                if data.get("type") != "connected":
                    raise ConnectionError(f"Unexpected response: {data}")

                self._session_id = data.get("data", {}).get("session_id")
                self._connected = True
                logger.info("Connected to Tarang (session: %s)", self._session_id)

                # Start heartbeat
                self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

                return self._session_id

            except Exception as e:
                last_error = e
                logger.warning("Connection attempt %s failed: %s", attempt + 1, e)
                # Release a socket that was opened but failed the handshake.
                await self._close_connection()
                if attempt < self.reconnect_attempts - 1:
                    await asyncio.sleep(self.reconnect_delay)

        raise ConnectionError(
            f"Failed to connect after {self.reconnect_attempts} attempts: {last_error}"
        ) from last_error

    async def disconnect(self):
        """Disconnect from WebSocket."""
        await self._close_connection()
        logger.info("Disconnected from Tarang")

    async def reconnect(self) -> bool:
        """
        Attempt to reconnect after a disconnection.

        Each attempt calls ``connect()``, which performs its own retries;
        the delay between attempts grows linearly with the attempt number.
        An exception raised by the reconnect callback counts as a failed
        attempt.

        Returns:
            True if reconnection successful, False otherwise
        """
        logger.info("Attempting to reconnect...")

        # Clean up old connection
        await self._close_connection()

        # Try to reconnect
        for attempt in range(self.reconnect_attempts):
            try:
                await self.connect()
                logger.info("Reconnected (attempt %s)", attempt + 1)

                if self._reconnect_callback:
                    self._reconnect_callback()

                return True

            except Exception as e:
                logger.warning("Reconnect attempt %s failed: %s", attempt + 1, e)
                if attempt < self.reconnect_attempts - 1:
                    await asyncio.sleep(self.reconnect_delay * (attempt + 1))

        logger.error("Failed to reconnect after all attempts")
        return False

    async def __aenter__(self) -> "TarangWSClient":
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()

    async def execute(
        self,
        instruction: str,
        cwd: str,
        job_id: Optional[str] = None,
    ) -> "AsyncIterator[WSEvent]":
        """
        Execute an instruction and yield events.

        This is the main entry point for hybrid execution:
        1. Send execute message to backend
        2. Yield events as they arrive
        3. Caller handles tool requests and sends results

        Iteration ends after a COMPLETE, ERROR or PAUSED event. When the
        connection drops and ``auto_reconnect`` is set, up to two
        reconnect-and-resume cycles are attempted; progress and failure
        are reported as synthetic PROGRESS / ERROR events.

        Args:
            instruction: User instruction
            cwd: Current working directory
            job_id: Optional job ID to resume

        Yields:
            WSEvent objects for each backend message

        Raises:
            ConnectionError: If called while not connected.
        """
        if not self._ws or not self._connected:
            raise ConnectionError("Not connected")

        # Track current job for potential resume
        self._current_job_id = job_id

        # Send execute request
        await self._ws.send(json.dumps({
            "type": "execute",
            "instruction": instruction,
            "cwd": cwd,
            "job_id": job_id,
        }))

        logger.debug("Sent execute request: %s...", instruction[:50])

        # Yield events until complete or error
        reconnect_attempts = 0
        max_reconnect_during_exec = 2

        try:
            while self._connected or (self.auto_reconnect and reconnect_attempts < max_reconnect_during_exec):
                if self._ws is None:
                    # The socket was torn down while we were iterating
                    # (e.g. disconnect() ran concurrently); report it
                    # instead of failing on a None dereference.
                    yield WSEvent(type=EventType.ERROR, data={"message": "Not connected"})
                    break

                try:
                    raw = await asyncio.wait_for(self._ws.recv(), timeout=60.0)
                    data = json.loads(raw)
                    event = WSEvent.from_json(data)

                    # Track job_id from connected/progress events
                    payload = event.data
                    if isinstance(payload, dict) and payload.get("job_id"):
                        self._current_job_id = payload["job_id"]

                    yield event

                    # Stop on complete or error
                    if event.type in (EventType.COMPLETE, EventType.ERROR, EventType.PAUSED):
                        self._current_job_id = None
                        break

                except asyncio.TimeoutError:
                    # No message in 60s - might be waiting for tool result
                    continue

                except websockets.ConnectionClosed as e:
                    logger.warning("Connection closed during execution: %s", e)
                    self._connected = False

                    if not self.auto_reconnect:
                        yield WSEvent(type=EventType.ERROR, data={"message": f"Connection closed: {e}"})
                        break

                    # Try to reconnect and resume
                    reconnect_attempts += 1
                    yield WSEvent(
                        type=EventType.PROGRESS,
                        data={"message": f"Connection lost. Reconnecting (attempt {reconnect_attempts})..."}
                    )

                    if not await self.reconnect():
                        yield WSEvent(
                            type=EventType.ERROR,
                            data={"message": "Failed to reconnect after multiple attempts"}
                        )
                        break

                    # Resume the job if we have a job_id
                    if not self._current_job_id:
                        yield WSEvent(
                            type=EventType.ERROR,
                            data={"message": "Reconnected but no job to resume"}
                        )
                        break

                    logger.info("Resuming job %s", self._current_job_id)
                    await self._ws.send(json.dumps({
                        "type": "resume",
                        "job_id": self._current_job_id,
                        "cwd": cwd,
                    }))
                    yield WSEvent(
                        type=EventType.PROGRESS,
                        data={"message": "Reconnected. Resuming..."}
                    )

        except websockets.ConnectionClosed as e:
            # Raised outside recv(), e.g. while sending the resume message.
            logger.warning("Connection closed: %s", e)
            self._connected = False
            yield WSEvent(type=EventType.ERROR, data={"message": f"Connection closed: {e}"})

    async def send_tool_result(
        self,
        request_id: str,
        result: Dict[str, Any],
    ):
        """Send tool execution result to backend.

        Raises:
            ConnectionError: If called while not connected.
        """
        if not self._ws or not self._connected:
            raise ConnectionError("Not connected")

        await self._ws.send(json.dumps({
            "type": "tool_result",
            "request_id": request_id,
            "result": result,
        }))

        logger.debug("Sent tool result for %s", request_id)

    async def send_tool_error(
        self,
        request_id: str,
        error: str,
    ):
        """Send tool error to backend.

        Raises:
            ConnectionError: If called while not connected.
        """
        if not self._ws or not self._connected:
            raise ConnectionError("Not connected")

        await self._ws.send(json.dumps({
            "type": "tool_error",
            "request_id": request_id,
            "error": error,
        }))

        logger.debug("Sent tool error for %s: %s", request_id, error)

    async def send_approval(
        self,
        request_id: str,
        approved: bool,
    ):
        """Send approval response to backend.

        Raises:
            ConnectionError: If called while not connected.
        """
        if not self._ws or not self._connected:
            raise ConnectionError("Not connected")

        await self._ws.send(json.dumps({
            "type": "approval",
            "request_id": request_id,
            "approved": approved,
        }))

        logger.debug("Sent approval for %s: %s", request_id, approved)

    async def cancel(self):
        """Cancel current execution (silently does nothing when not connected)."""
        if not self._ws or not self._connected:
            return

        await self._ws.send(json.dumps({"type": "cancel"}))
        logger.info("Sent cancel request")

    async def _heartbeat_loop(self, interval: float = 25.0):
        """Send a ``ping`` message every *interval* seconds while connected.

        The loop ends on the first send failure or when the task is
        cancelled.
        """
        try:
            while self._connected and self._ws:
                await asyncio.sleep(interval)
                try:
                    await self._ws.send(json.dumps({"type": "ping"}))
                except Exception:
                    break
        except asyncio.CancelledError:
            pass
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
class WSClientPool:
    """
    Connection pool for WebSocket clients.

    For future use with multiple concurrent connections.

    Clients are cached per (base URL, token, OpenRouter key).
    ``max_connections`` is stored but not currently enforced.
    """

    def __init__(self, max_connections: int = 5):
        self.max_connections = max_connections
        self._clients: Dict[str, TarangWSClient] = {}

    @staticmethod
    def _client_key(base_url: str, token: str, openrouter_key: str) -> str:
        """Return the cache key for one set of credentials.

        Both secrets are hashed in full, so two tokens that merely share a
        prefix can never be handed the same client, and the raw secrets do
        not appear in the key.
        """
        import hashlib  # local: the module's import block is left untouched

        secret = f"{token}\x00{openrouter_key}".encode("utf-8")
        return f"{base_url}:{hashlib.sha256(secret).hexdigest()}"

    async def get_client(
        self,
        base_url: str,
        token: str,
        openrouter_key: str,
    ) -> "TarangWSClient":
        """Get or create a client for the given credentials.

        A cached client is reused only while it is connected; a stale one
        is disconnected and replaced.

        Raises:
            ValueError, ConnectionError: Propagated from
                ``TarangWSClient.connect()``.
        """
        key = self._client_key(base_url, token, openrouter_key)

        cached = self._clients.get(key)
        if cached is not None:
            if cached.is_connected:
                return cached
            # Release whatever the stale client still holds before replacing it.
            del self._clients[key]
            await cached.disconnect()

        client = TarangWSClient(
            base_url=base_url,
            token=token,
            openrouter_key=openrouter_key,
        )

        await client.connect()
        self._clients[key] = client

        return client

    async def close_all(self):
        """Close all connections and empty the pool."""
        # Detach first so the pool ends up empty even if a disconnect raises.
        clients = list(self._clients.values())
        self._clients.clear()
        for client in clients:
            await client.disconnect()
|