tarang 4.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tarang/ui/keyboard.py ADDED
@@ -0,0 +1,197 @@
1
+ """
2
+ Keyboard handler for interactive controls during execution.
3
+
4
+ ESC - Terminate current execution
5
+ SPACE - Pause and add extra instruction
6
+ """
7
+
8
+ import sys
9
+ import threading
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum
12
+ from typing import Optional, Callable
13
+ import select
14
+
15
+ # Try to import platform-specific modules
16
+ try:
17
+ import termios
18
+ import tty
19
+ HAS_TERMIOS = True
20
+ except ImportError:
21
+ HAS_TERMIOS = False
22
+
23
+
24
class KeyAction(Enum):
    """Keyboard actions recorded by the monitor thread."""
    NONE = "none"
    CANCEL = "cancel"  # ESC pressed
    PAUSE = "pause"  # SPACE pressed


@dataclass
class KeyboardState:
    """
    Shared state between the keyboard monitor thread and the executor.

    ``action`` and ``extra_instruction`` are read and written under ``_lock``.
    The underscore-prefixed fields are bookkeeping set by ``KeyboardMonitor``.
    """
    action: KeyAction = KeyAction.NONE
    extra_instruction: Optional[str] = None
    # Internal bookkeeping is excluded from repr/eq so that printing or
    # comparing a state reflects only the user-visible action fields
    # (a Lock/Thread compares by identity, which made every state unequal).
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False, compare=False)
    _running: bool = field(default=False, repr=False, compare=False)
    _thread: Optional[threading.Thread] = field(default=None, repr=False, compare=False)
    _original_settings: Optional[list] = field(default=None, repr=False, compare=False)

    def reset(self):
        """Reset state for new execution."""
        with self._lock:
            self.action = KeyAction.NONE
            self.extra_instruction = None

    def set_cancel(self):
        """Set cancel action."""
        with self._lock:
            self.action = KeyAction.CANCEL

    def set_pause(self, instruction: Optional[str] = None):
        """Set pause action with optional instruction."""
        with self._lock:
            self.action = KeyAction.PAUSE
            self.extra_instruction = instruction

    def get_action(self) -> KeyAction:
        """Get current action (thread-safe)."""
        with self._lock:
            return self.action

    def consume_action(self) -> KeyAction:
        """Get and clear action (thread-safe); ``extra_instruction`` is kept."""
        with self._lock:
            action = self.action
            self.action = KeyAction.NONE
            return action
69
+
70
+
71
class KeyboardMonitor:
    """
    Monitor keyboard for ESC and SPACE during execution.

    Usage:
        monitor = KeyboardMonitor(console)
        monitor.start()

        while executing:
            action = monitor.state.consume_action()
            if action == KeyAction.CANCEL:
                break
            elif action == KeyAction.PAUSE:
                extra = monitor.prompt_extra_instruction()
                ...

        monitor.stop()

    Monitoring is a no-op when ``termios`` is unavailable or when reading the
    terminal attributes of stdin fails (e.g. stdin is not a terminal).
    """

    # Seconds to wait after an ESC byte for the rest of an escape sequence.
    # Terminals encode arrow/function/Alt keys as ESC immediately followed by
    # more bytes; a lone ESC keypress is followed by nothing.
    ESCAPE_SEQUENCE_TIMEOUT = 0.05

    def __init__(self, console=None, on_status: Optional[Callable[[str], None]] = None):
        """
        Initialize keyboard monitor.

        Args:
            console: Rich console for output (optional)
            on_status: Callback for status messages; called from the
                monitor thread with Rich-markup strings.
        """
        self.console = console
        self.on_status = on_status or (lambda x: None)
        self.state = KeyboardState()
        self._stop_event = threading.Event()

    def start(self):
        """Start keyboard monitoring (no-op if unsupported or already running)."""
        if not HAS_TERMIOS:
            # Windows or non-terminal - skip keyboard monitoring
            return

        self.state.reset()

        # Starting again while the thread runs would save the *cbreak*
        # attributes as the "original" ones, so stop() would leave the
        # terminal in cbreak mode. Keep the existing monitor instead.
        thread = self.state._thread
        if thread is not None and thread.is_alive():
            return

        self._stop_event.clear()

        # Save terminal settings
        try:
            self.state._original_settings = termios.tcgetattr(sys.stdin)
        except Exception:
            return

        # Start monitor thread
        self.state._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.state._running = True
        self.state._thread.start()

    def stop(self):
        """Stop keyboard monitoring and restore terminal."""
        self.state._running = False
        self._stop_event.set()

        # Join before restoring so the monitor thread cannot still be putting
        # the terminal into cbreak mode after we restore it. Skip the join
        # when called from the monitor thread itself (e.g. from on_status).
        thread = self.state._thread
        if thread is not None and thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout=0.5)

        self._restore_terminal()

    def _restore_terminal(self):
        """Restore the terminal attributes saved by ``start()``, if any."""
        if HAS_TERMIOS and self.state._original_settings:
            try:
                termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self.state._original_settings)
            except Exception:
                pass

    @staticmethod
    def _read_key(fd: int, escape_timeout: float = ESCAPE_SEQUENCE_TIMEOUT) -> Optional[bytes]:
        """
        Read one keypress from file descriptor ``fd``.

        Reads the fd directly (not ``sys.stdin``) so no bytes are hidden in
        Python's text buffer, where ``select()`` could not see them.

        Returns:
            The key byte; ``b""`` at end-of-file; or ``None`` when the byte
            started a multi-byte escape sequence (arrow keys, function keys,
            Alt+key), which is consumed and ignored.
        """
        import os

        char = os.read(fd, 1)
        if char == b"\x1b" and select.select([fd], [], [], escape_timeout)[0]:
            # More bytes arrived right after ESC: an escape sequence, not a
            # bare ESC press. Drain its tail so it isn't read as keystrokes.
            while select.select([fd], [], [], 0)[0]:
                if not os.read(fd, 64):
                    break  # EOF
            return None
        return char

    def _monitor_loop(self):
        """Background thread: record ESC (cancel) and SPACE (pause) presses."""
        if not HAS_TERMIOS:
            return

        try:
            fd = sys.stdin.fileno()
            # cbreak mode delivers each keypress immediately (no line buffering)
            tty.setcbreak(fd)

            while self.state._running and not self._stop_event.is_set():
                # Poll with a timeout so the stop flags are re-checked regularly
                if not select.select([fd], [], [], 0.1)[0]:
                    continue

                key = self._read_key(fd, self.ESCAPE_SEQUENCE_TIMEOUT)
                if key == b"":
                    break  # stdin reached EOF; avoid spinning on select()

                if key == b"\x1b":  # bare ESC
                    self.state.set_cancel()
                    self.on_status("[yellow]ESC pressed - cancelling...[/yellow]")

                elif key == b" ":  # SPACE
                    self.state.set_pause()
                    self.on_status("[cyan]SPACE pressed - pausing for instruction...[/cyan]")

        except Exception:
            # Keyboard control is best-effort; never let it crash execution.
            pass
        finally:
            self._restore_terminal()

    def prompt_extra_instruction(self) -> Optional[str]:
        """
        Prompt user for extra instruction after SPACE.

        Monitoring is stopped (terminal restored) while reading the line and
        restarted afterwards.

        Returns:
            Extra instruction string, or None if empty or cancelled
        """
        # Temporarily stop monitoring to get clean input
        self.stop()

        try:
            if self.console:
                self.console.print("\n[bold cyan]Add instruction:[/bold cyan] ", end="")

            instruction = input().strip()
            return instruction if instruction else None

        except (KeyboardInterrupt, EOFError):
            return None
        finally:
            # Resume monitoring
            self.start()
193
+
194
+
195
def create_keyboard_hints() -> str:
    """
    Build the dimmed Rich-markup hint line for the keyboard shortcuts.

    Returns:
        Markup naming the ESC (cancel) and SPACE (add instruction) keys.
    """
    hint_markup = "[dim]ESC=cancel SPACE=add instruction[/dim]"
    return hint_markup
tarang/ws/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """
2
+ WebSocket module for hybrid agent architecture.
3
+
4
+ This module provides:
5
+ - TarangWSClient: WebSocket client for bidirectional communication
6
+ - ToolExecutor: Local tool execution (file ops, shell)
7
+ - MessageHandlers: Handle different message types from backend
8
+ """
9
+
10
+ from tarang.ws.client import TarangWSClient
11
+ from tarang.ws.executor import ToolExecutor
12
+ from tarang.ws.handlers import MessageHandlers
13
+
14
+ __all__ = ["TarangWSClient", "ToolExecutor", "MessageHandlers"]
tarang/ws/client.py ADDED
@@ -0,0 +1,464 @@
1
+ """
2
+ WebSocket Client for Hybrid Agent Architecture.
3
+
4
+ Manages bidirectional WebSocket communication with the Tarang backend:
5
+ - Receives tool requests from backend
6
+ - Executes tools locally
7
+ - Sends results back to backend
8
+ - Handles progress events and UI updates
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import json
14
+ import logging
15
+ from dataclasses import dataclass, field
16
+ from enum import Enum
17
+ from typing import Any, AsyncIterator, Callable, Dict, Optional
18
+
19
+ import websockets
20
+ from websockets.client import WebSocketClientProtocol
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class EventType(str, Enum):
    """WebSocket event types."""
    CONNECTED = "connected"
    THINKING = "thinking"
    TOOL_REQUEST = "tool_request"
    APPROVAL_REQUEST = "approval_request"
    PHASE_START = "phase_start"
    MILESTONE_UPDATE = "milestone_update"
    PROGRESS = "progress"
    COMPLETE = "complete"
    ERROR = "error"
    PAUSED = "paused"
    HEARTBEAT = "heartbeat"
    PONG = "pong"


@dataclass
class WSEvent:
    """A WebSocket event from backend."""
    type: EventType
    data: Dict[str, Any] = field(default_factory=dict)
    request_id: Optional[str] = None

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "WSEvent":
        """
        Create event from a decoded JSON message.

        Args:
            data: Decoded message, normally ``{"type": ..., "data": {...},
                "request_id": ...}``.

        Returns:
            The event. An unknown ``type`` maps to ``EventType.ERROR``.
            ``event.data`` is always a dict: the message's ``"data"`` object,
            or the whole message when that key is missing or not an object.
            A message that is not a JSON object yields an ERROR event.
        """
        if not isinstance(data, dict):
            # e.g. a bare JSON list/string; callers expect dict access below.
            return cls(
                type=EventType.ERROR,
                data={"message": f"Malformed message from backend: {data!r}"},
            )

        event_type = data.get("type", "")
        try:
            etype = EventType(event_type)
        except ValueError:
            etype = EventType.ERROR

        payload = data.get("data", data)
        if not isinstance(payload, dict):
            # e.g. "data": null -- callers call event.data.get(...), so keep a dict.
            payload = data

        return cls(
            type=etype,
            data=payload,
            request_id=data.get("request_id"),
        )
62
+
63
+
64
# Argument mapping passed (alongside a str) to both callback kinds below.
_CallbackArgs = Dict[str, Any]

# Tool executor callback: called with (str, argument mapping); result type is unconstrained.
ToolExecutorCallback = Callable[[str, _CallbackArgs], Any]

# Approval callback: same inputs, returns True if approved.
ApprovalCallback = Callable[[str, _CallbackArgs], bool]
69
+
70
+
71
class TarangWSClient:
    """
    WebSocket client for hybrid agent communication.

    Usage:
        async with TarangWSClient(base_url, token, openrouter_key) as client:
            async for event in client.execute(instruction, cwd):
                # Handle event
                if event.type == EventType.TOOL_REQUEST:
                    result = execute_tool(event.data)
                    await client.send_tool_result(event.request_id, result)
    """

    DEFAULT_BASE_URL = "wss://tarang-backend-intl-web-app-production.up.railway.app"

    def __init__(
        self,
        base_url: Optional[str] = None,
        token: Optional[str] = None,
        openrouter_key: Optional[str] = None,
        reconnect_attempts: int = 3,
        reconnect_delay: float = 2.0,
        auto_reconnect: bool = True,
    ):
        """
        Args:
            base_url: Backend URL; ``https://``/``http://`` are rewritten to
                ``wss://``/``ws://``. Defaults to ``DEFAULT_BASE_URL``.
            token: Auth token (required by ``connect``).
            openrouter_key: OpenRouter key (required by ``connect``).
            reconnect_attempts: Attempts per ``connect``/``reconnect`` call.
            reconnect_delay: Base delay in seconds between attempts.
            auto_reconnect: Whether ``execute`` reconnects and resumes after
                the connection drops.
        """
        self.base_url = (base_url or self.DEFAULT_BASE_URL).replace("https://", "wss://").replace("http://", "ws://")
        self.token = token
        self.openrouter_key = openrouter_key
        self.reconnect_attempts = reconnect_attempts
        self.reconnect_delay = reconnect_delay
        self.auto_reconnect = auto_reconnect

        self._ws: Optional[WebSocketClientProtocol] = None
        self._session_id: Optional[str] = None
        self._connected = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._current_job_id: Optional[str] = None
        self._reconnect_callback: Optional[Callable[[], None]] = None

    @property
    def session_id(self) -> Optional[str]:
        return self._session_id

    @property
    def current_job_id(self) -> Optional[str]:
        return self._current_job_id

    @property
    def is_connected(self) -> bool:
        return self._connected and self._ws is not None

    def set_reconnect_callback(self, callback: Callable[[], None]):
        """Set callback to be called on reconnection."""
        self._reconnect_callback = callback

    def _build_ws_url(self) -> str:
        """
        Return the agent endpoint URL with the credentials as query parameters.

        Values are percent-encoded so a token or key containing reserved
        characters (``&``, ``=``, ``+``, ``#`` ...) cannot corrupt the query.
        """
        from urllib.parse import urlencode

        query = urlencode({"token": self.token, "openrouter_key": self.openrouter_key})
        return f"{self.base_url}/v2/ws/agent?{query}"

    def _redact(self, text: str) -> str:
        """Mask the token and OpenRouter key in ``text`` (errors may echo the URL)."""
        from urllib.parse import quote_plus

        for secret in (self.token, self.openrouter_key):
            if secret:
                text = text.replace(quote_plus(secret), "***").replace(secret, "***")
        return text

    async def _close_transport(self):
        """Mark disconnected, cancel the heartbeat and close the socket (best-effort)."""
        self._connected = False

        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            self._heartbeat_task = None

        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                pass
            self._ws = None

    async def connect(self) -> str:
        """
        Connect to the WebSocket endpoint.

        Tries up to ``reconnect_attempts`` times (at least once), sleeping
        ``reconnect_delay`` seconds between attempts.

        Returns:
            Session ID from backend

        Raises:
            ValueError: If the token or OpenRouter key is missing.
            ConnectionError: If every attempt fails.
        """
        if not self.token:
            raise ValueError("Token is required")

        if not self.openrouter_key:
            raise ValueError("OpenRouter key is required")

        # Replace any previous connection instead of leaking its socket/heartbeat.
        await self._close_transport()

        ws_url = self._build_ws_url()
        # Never log ws_url itself: its query string carries the credentials.
        logger.debug(f"Connecting to {self.base_url}/v2/ws/agent")

        attempts = max(1, self.reconnect_attempts)
        last_error = None
        for attempt in range(attempts):
            try:
                self._ws = await websockets.connect(
                    ws_url,
                    ping_interval=30,
                    ping_timeout=60,  # Increased for slow LLM responses
                    close_timeout=10,
                )

                # The first frame must be the backend's "connected" event.
                raw = await asyncio.wait_for(self._ws.recv(), timeout=10.0)
                data = json.loads(raw)
                if data.get("type") != "connected":
                    raise ConnectionError(f"Unexpected response: {data}")

                self._session_id = (data.get("data") or {}).get("session_id")
                self._connected = True
                logger.info(f"Connected to Tarang (session: {self._session_id})")

                # Start heartbeat
                self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

                return self._session_id

            except Exception as e:
                last_error = e
                logger.warning(f"Connection attempt {attempt + 1} failed: {self._redact(str(e))}")
                # Close a socket that opened but failed the handshake.
                await self._close_transport()
                if attempt < attempts - 1:
                    await asyncio.sleep(self.reconnect_delay)

        raise ConnectionError(
            f"Failed to connect after {attempts} attempts: {self._redact(str(last_error))}"
        ) from last_error

    async def disconnect(self):
        """Disconnect from WebSocket."""
        await self._close_transport()
        logger.info("Disconnected from Tarang")

    async def reconnect(self) -> bool:
        """
        Attempt to reconnect after a disconnection.

        Each attempt calls ``connect()``, which applies its own retries.
        The reconnect callback, if set, runs once after success.

        Returns:
            True if reconnection successful, False otherwise
        """
        logger.info("Attempting to reconnect...")

        # Clean up old connection
        await self._close_transport()

        for attempt in range(self.reconnect_attempts):
            try:
                await self.connect()
            except Exception as e:
                logger.warning(f"Reconnect attempt {attempt + 1} failed: {e}")
                if attempt < self.reconnect_attempts - 1:
                    await asyncio.sleep(self.reconnect_delay * (attempt + 1))
                continue

            logger.info(f"Reconnected (attempt {attempt + 1})")

            # Outside the retry try-block: a failing callback must not
            # trigger another connect() over the live connection.
            if self._reconnect_callback:
                try:
                    self._reconnect_callback()
                except Exception:
                    logger.exception("Reconnect callback failed")

            return True

        logger.error("Failed to reconnect after all attempts")
        return False

    async def __aenter__(self) -> "TarangWSClient":
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()

    async def execute(
        self,
        instruction: str,
        cwd: str,
        job_id: Optional[str] = None,
    ) -> AsyncIterator[WSEvent]:
        """
        Execute an instruction and yield events.

        This is the main entry point for hybrid execution:
        1. Send execute message to backend
        2. Yield events as they arrive
        3. Caller handles tool requests and sends results

        Iteration stops after a COMPLETE, ERROR or PAUSED event. With
        ``auto_reconnect``, a dropped connection is re-established and the
        job resumed at most twice per call; non-JSON frames are skipped.

        Args:
            instruction: User instruction
            cwd: Current working directory
            job_id: Optional job ID to resume

        Yields:
            WSEvent objects for each backend message

        Raises:
            ConnectionError: If not connected when called.
        """
        if not self._ws or not self._connected:
            raise ConnectionError("Not connected")

        # Track current job for potential resume
        self._current_job_id = job_id

        # Send execute request
        await self._ws.send(json.dumps({
            "type": "execute",
            "instruction": instruction,
            "cwd": cwd,
            "job_id": job_id,
        }))

        logger.debug(f"Sent execute request: {instruction[:50]}...")

        reconnect_attempts = 0
        max_reconnect_during_exec = 2

        try:
            # Every connection loss below either reconnects or breaks, so this
            # only ends early if disconnect() is called elsewhere (previously
            # the loop then called recv() on None).
            while self.is_connected:
                try:
                    raw = await asyncio.wait_for(self._ws.recv(), timeout=60.0)
                    try:
                        data = json.loads(raw)
                    except json.JSONDecodeError:
                        logger.warning("Ignoring malformed (non-JSON) message from backend")
                        continue
                    event = WSEvent.from_json(data)

                    # Track job_id from connected/progress events
                    if event.data.get("job_id"):
                        self._current_job_id = event.data["job_id"]

                    yield event

                    # Stop on complete or error
                    if event.type in (EventType.COMPLETE, EventType.ERROR, EventType.PAUSED):
                        self._current_job_id = None
                        break

                except asyncio.TimeoutError:
                    # No message in 60s - might be waiting for tool result
                    continue

                except websockets.ConnectionClosed as e:
                    logger.warning(f"Connection closed during execution: {e}")
                    self._connected = False

                    if not self.auto_reconnect:
                        yield WSEvent(type=EventType.ERROR, data={"message": f"Connection closed: {e}"})
                        break

                    # Enforce the per-execution limit here: the old loop
                    # condition was never evaluated while disconnected, so
                    # reconnects were effectively unbounded.
                    if reconnect_attempts >= max_reconnect_during_exec:
                        yield WSEvent(
                            type=EventType.ERROR,
                            data={"message": f"Connection closed: {e} (gave up after {max_reconnect_during_exec} reconnects)"},
                        )
                        break

                    # Try to reconnect and resume
                    reconnect_attempts += 1
                    yield WSEvent(
                        type=EventType.PROGRESS,
                        data={"message": f"Connection lost. Reconnecting (attempt {reconnect_attempts})..."}
                    )

                    if not await self.reconnect():
                        yield WSEvent(
                            type=EventType.ERROR,
                            data={"message": "Failed to reconnect after multiple attempts"}
                        )
                        break

                    if not self._current_job_id:
                        yield WSEvent(
                            type=EventType.ERROR,
                            data={"message": "Reconnected but no job to resume"}
                        )
                        break

                    logger.info(f"Resuming job {self._current_job_id}")
                    await self._ws.send(json.dumps({
                        "type": "resume",
                        "job_id": self._current_job_id,
                        "cwd": cwd,
                    }))
                    yield WSEvent(
                        type=EventType.PROGRESS,
                        data={"message": "Reconnected. Resuming..."}
                    )

        except websockets.ConnectionClosed as e:
            # e.g. the connection dropped while sending the resume request
            logger.warning(f"Connection closed: {e}")
            self._connected = False
            yield WSEvent(type=EventType.ERROR, data={"message": f"Connection closed: {e}"})

    async def _send(self, message: Dict[str, Any]) -> None:
        """Send ``message`` as JSON; raise ConnectionError if not connected."""
        if not self._ws or not self._connected:
            raise ConnectionError("Not connected")
        await self._ws.send(json.dumps(message))

    async def send_tool_result(
        self,
        request_id: str,
        result: Dict[str, Any],
    ):
        """Send tool execution result to backend."""
        await self._send({
            "type": "tool_result",
            "request_id": request_id,
            "result": result,
        })
        logger.debug(f"Sent tool result for {request_id}")

    async def send_tool_error(
        self,
        request_id: str,
        error: str,
    ):
        """Send tool error to backend."""
        await self._send({
            "type": "tool_error",
            "request_id": request_id,
            "error": error,
        })
        logger.debug(f"Sent tool error for {request_id}: {error}")

    async def send_approval(
        self,
        request_id: str,
        approved: bool,
    ):
        """Send approval response to backend."""
        await self._send({
            "type": "approval",
            "request_id": request_id,
            "approved": approved,
        })
        logger.debug(f"Sent approval for {request_id}: {approved}")

    async def cancel(self):
        """Cancel current execution (silently does nothing if not connected)."""
        if not self._ws or not self._connected:
            return

        await self._send({"type": "cancel"})
        logger.info("Sent cancel request")

    async def _heartbeat_loop(self, interval: float = 25.0):
        """Send a JSON ping every ``interval`` seconds until disconnected or a send fails."""
        try:
            while self._connected and self._ws:
                await asyncio.sleep(interval)
                try:
                    await self._ws.send(json.dumps({"type": "ping"}))
                except Exception:
                    break
        except asyncio.CancelledError:
            pass
424
+
425
+
426
class WSClientPool:
    """
    Connection pool for WebSocket clients.

    For future use with multiple concurrent connections.
    Note: ``max_connections`` is stored but not currently enforced.
    """

    def __init__(self, max_connections: int = 5):
        self.max_connections = max_connections
        self._clients: Dict[str, TarangWSClient] = {}

    @staticmethod
    def _client_key(base_url: str, token: str, openrouter_key: str) -> str:
        """
        Return a pool key unique to the complete credential set.

        The old key used only ``token[:8]``. Tokens sharing a prefix (e.g.
        JWTs, which all begin with ``eyJ``) collided and could be handed
        another user's authenticated client. Hashing avoids keeping the raw
        secrets as dict keys.
        """
        import hashlib

        material = "\0".join((base_url, token, openrouter_key))
        return hashlib.sha256(material.encode("utf-8")).hexdigest()

    async def get_client(
        self,
        base_url: str,
        token: str,
        openrouter_key: str,
    ) -> TarangWSClient:
        """Get or create a connected client for the given credentials."""
        key = self._client_key(base_url, token, openrouter_key)

        existing = self._clients.get(key)
        if existing is not None and existing.is_connected:
            return existing

        client = TarangWSClient(
            base_url=base_url,
            token=token,
            openrouter_key=openrouter_key,
        )

        await client.connect()
        self._clients[key] = client

        return client

    async def close_all(self):
        """Close all connections and empty the pool."""
        try:
            for client in list(self._clients.values()):
                await client.disconnect()
        finally:
            self._clients.clear()