msapling-cli 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
msapling_cli/api.py ADDED
@@ -0,0 +1,394 @@
1
+ """HTTP client for MSapling backend API with retry logic and offline detection."""
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ import json
6
+ import uuid
7
+ from typing import Any, AsyncGenerator, Dict, List, Optional
8
+
9
+ import httpx
10
+
11
+ from .config import get_settings, get_token
12
+
13
+
14
# System prompt prepended to the message history by MSaplingClient.stream_chat,
# but only when the client was constructed with diff_mode=True. It steers the
# model toward answering code-change requests with unified diffs instead of
# full-file dumps, which keeps terminal output short and machine-applyable.
_DIFF_SYSTEM_PROMPT = (
    "You are a code assistant running in a CLI terminal. "
    "When the user asks you to modify, edit, fix, refactor, or write code, "
    "ALWAYS respond with a unified diff (--- / +++ / @@ format). "
    "Do NOT output the full file — only the diff of what changed. "
    "For non-code questions (explanations, debugging help, general questions), "
    "respond normally but keep answers concise and terminal-friendly."
)
22
+
23
# Retry configuration shared by _retry_request.
# Total number of attempts before a request is abandoned.
MAX_RETRIES = 3
# Base delay in seconds; attempt i sleeps backoff * 2**i (exponential backoff).
RETRY_BACKOFF_BASE = 1.0  # seconds
# Status codes considered transient: rate limiting plus common gateway errors.
RETRYABLE_STATUS_CODES = {500, 502, 503, 504, 429}
27
+
28
+
29
class APIError(Exception):
    """Raised when an API call fails after all retries are exhausted.

    Attributes:
        status_code: HTTP status of the failing response, or 0 when the
            failure was not tied to any particular HTTP status.
    """

    def __init__(self, message: str, status_code: int = 0):
        # Store the message on the base Exception so str(err) yields it.
        super().__init__(message)
        self.status_code = status_code
35
+
36
+
37
class OfflineError(Exception):
    """Raised when the API is unreachable (connection failure or timeout)."""
40
+
41
+
42
async def _retry_request(
    fn,
    *,
    retries: int = MAX_RETRIES,
    backoff: float = RETRY_BACKOFF_BASE,
) -> httpx.Response:
    """Execute an HTTP request with exponential backoff retry on transient errors.

    Args:
        fn: Zero-argument callable returning an awaitable that performs the
            request and yields an ``httpx.Response`` — typically a lambda
            wrapping ``client.get(...)`` / ``client.post(...)``.
        retries: Total number of attempts (not re-tries) before giving up.
        backoff: Base sleep in seconds; attempt ``i`` waits ``backoff * 2**i``.

    Returns:
        The first response whose status code is not in RETRYABLE_STATUS_CODES.
        Note this includes non-retryable error statuses such as 404 — callers
        are expected to invoke ``raise_for_status`` themselves.

    Raises:
        OfflineError: every attempt failed with a connect error or timeout.
        httpx.HTTPStatusError: the final attempt still returned a retryable
            error status (raised via ``raise_for_status`` below).
        APIError: fallback, reachable only if ``retries`` <= 0 so the loop
            body never runs.
    """
    last_exc = None
    for attempt in range(retries):
        try:
            resp = await fn()
            if resp.status_code not in RETRYABLE_STATUS_CODES:
                return resp
            # Retryable status — wait and try again
            if attempt < retries - 1:
                wait = backoff * (2 ** attempt)
                if resp.status_code == 429:
                    # Respect Retry-After header if present
                    # (integer-seconds form only; the HTTP-date form is
                    # deliberately ignored and falls back to backoff).
                    retry_after = resp.headers.get("retry-after")
                    if retry_after and retry_after.isdigit():
                        # Never wait less than the computed backoff.
                        wait = max(wait, int(retry_after))
                await asyncio.sleep(wait)
            else:
                # Out of attempts: surface the retryable status as an
                # httpx.HTTPStatusError (all retryable codes are >= 400).
                resp.raise_for_status()
        except httpx.ConnectError as e:
            # Network-level failure: remember it, back off, and retry; on the
            # last attempt raise it as an OfflineError instead.
            last_exc = OfflineError(f"Cannot connect to API: {e}")
            if attempt < retries - 1:
                await asyncio.sleep(backoff * (2 ** attempt))
            else:
                raise last_exc
        except httpx.TimeoutException as e:
            # Same treatment as connect errors: timeouts count as "offline".
            last_exc = OfflineError(f"API request timed out: {e}")
            if attempt < retries - 1:
                await asyncio.sleep(backoff * (2 ** attempt))
            else:
                raise last_exc
        except httpx.HTTPStatusError:
            # From raise_for_status above — propagate unchanged.
            raise
    # Only reachable when retries <= 0 (the for-loop body never executed).
    raise last_exc or APIError("Request failed after retries")
81
+
82
+
83
class MSaplingClient:
    """Async HTTP client for MSapling backend with retry and offline support.

    Wraps ``httpx.AsyncClient`` with:
      * cookie-based auth (``msaplingauth`` cookie) plus CSRF bootstrap via
        the ``/health`` endpoint,
      * transparent retry with exponential backoff (see ``_retry_request``),
      * offline detection (``health_check`` returning False instead of raising).

    Args:
        api_url: Backend base URL; falls back to the configured setting.
            A trailing slash is stripped so relative paths join cleanly.
        token: Auth token; falls back to the stored token from config.
        diff_mode: When True, ``stream_chat`` prepends a system prompt
            instructing the model to answer code requests with unified diffs.
    """

    def __init__(self, api_url: Optional[str] = None, token: Optional[str] = None, diff_mode: bool = False):
        settings = get_settings()
        self.api_url = (api_url or settings.api_url).rstrip("/")
        self.token = token or get_token()
        self.diff_mode = diff_mode  # Only inject diff system prompt when explicitly enabled
        self._client: Optional[httpx.AsyncClient] = None  # created lazily

    async def _get_client(self) -> httpx.AsyncClient:
        """Return a ready client, (re)creating it and bootstrapping CSRF on demand."""
        if self._client is None or self._client.is_closed:
            headers = {"Content-Type": "application/json"}
            cookies = {}
            if self.token:
                cookies["msaplingauth"] = self.token
            self._client = httpx.AsyncClient(
                base_url=self.api_url,
                headers=headers,
                cookies=cookies,
                timeout=60.0,
            )
            # Bootstrap CSRF token from /health endpoint
            try:
                health = await self._client.get("/health")
                csrf = health.cookies.get("msapling_csrftoken") or self._client.cookies.get("msapling_csrftoken")
                if csrf:
                    self._client.headers["X-CSRF-Token"] = csrf
            except Exception:
                pass  # Offline — CSRF will be missing but local commands still work
        return self._client

    async def close(self):
        """Close the underlying HTTP client and release its connection pool."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()

    async def health_check(self) -> bool:
        """Check if API is reachable. Returns False if offline."""
        try:
            client = await self._get_client()
            # Hard 5s cap regardless of the client's own 60s timeout.
            resp = await asyncio.wait_for(client.get("/health"), timeout=5.0)
            return resp.status_code == 200
        except Exception:
            return False

    # --- Auth ---

    async def login(self, email: str, password: str) -> Dict[str, Any]:
        """Authenticate with email/password; stores the auth token on success.

        Returns the JSON body of the login response.
        Raises httpx.HTTPStatusError on a non-2xx login response.
        """
        client = await self._get_client()
        # Step 1: GET /health to obtain CSRF cookie
        health_resp = await client.get("/health")
        csrf_token = health_resp.cookies.get("msapling_csrftoken") or client.cookies.get("msapling_csrftoken")
        # Step 2: POST /auth/login with CSRF header
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        if csrf_token:
            headers["X-CSRF-Token"] = csrf_token
        resp = await _retry_request(lambda: client.post(
            "/auth/login",
            data={"username": email, "password": password},
            headers=headers,
        ))
        resp.raise_for_status()
        # Read the body before the client may be closed below.
        data = resp.json()
        token = resp.cookies.get("msaplingauth") or client.cookies.get("msaplingauth")
        if token:
            self.token = token
            # Bug fix: close the old client before discarding it — previously
            # self._client was set to None without aclose(), leaking its
            # connection pool. A fresh client with the new token is created
            # lazily on the next request.
            await self.close()
            self._client = None
        return data

    async def me(self) -> Dict[str, Any]:
        """Return the authenticated user's profile from /users/me."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.get("/users/me"))
        resp.raise_for_status()
        return resp.json()

    # --- Chat ---

    async def stream_chat(
        self,
        chat_id: str,
        prompt: str,
        model: str,
        *,
        project_name: Optional[str] = None,
        history: Optional[list] = None,
        agent_mode: bool = False,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream chat response as NDJSON chunks.

        Yields one dict per complete JSON line received from the server.
        Malformed lines are skipped silently. When ``diff_mode`` is set on
        the client, the diff system prompt is prepended to ``history``.
        """
        client = await self._get_client()
        messages = list(history or [])
        if self.diff_mode:
            messages = [{"role": "system", "content": _DIFF_SYSTEM_PROMPT}] + messages
        payload = {
            "chat_id": chat_id,
            "prompt": prompt,
            "model": model,
            "project_name": project_name or "",
            "messages": messages,
            "agent_mode": agent_mode,
        }
        async with client.stream("POST", "/api/chat/message", json=payload) as resp:
            resp.raise_for_status()
            # NDJSON framing: accumulate text and split on newlines, since a
            # network chunk may contain partial or multiple JSON lines.
            buffer = ""
            async for chunk in resp.aiter_text():
                buffer += chunk
                while "\n" in buffer:
                    line, buffer = buffer.split("\n", 1)
                    line = line.strip()
                    if line:
                        try:
                            yield json.loads(line)
                        except json.JSONDecodeError:
                            continue
            # Flush a final line that arrived without a trailing newline.
            if buffer.strip():
                try:
                    yield json.loads(buffer.strip())
                except json.JSONDecodeError:
                    pass

    async def get_models(self) -> list:
        """List available models; tolerates list, {"models": ...} or {"data": ...} shapes."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.get("/api/chat/models"))
        resp.raise_for_status()
        data = resp.json()
        return data if isinstance(data, list) else data.get("models", data.get("data", []))

    # --- Projects ---

    async def list_projects(self) -> list:
        """List the user's projects; tolerates list or {"projects": ...} shapes."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.get("/api/projects"))
        resp.raise_for_status()
        data = resp.json()
        return data if isinstance(data, list) else data.get("projects", [])

    async def projects_overview(self) -> Dict[str, Any]:
        """Get projects with chat counts and limits. Used by CLI startup."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.get("/api/projects/overview"))
        resp.raise_for_status()
        return resp.json()

    async def create_project(self, name: str) -> Dict[str, Any]:
        """Create a new project named ``name`` and return the server response."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.post("/api/projects", json={
            "project_name": name,
        }))
        resp.raise_for_status()
        return resp.json()

    async def create_chat(
        self, project_name: str, title: str = "CLI Chat",
        model: str = "", client_type: str = "cli",
    ) -> Dict[str, Any]:
        """Create a chat slot in ``project_name``; empty ``model`` uses the configured default."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.post("/api/projects/chat/new", json={
            "project_name": project_name,
            "slot_label": title,
            "model": model or get_settings().default_model,
            "client_type": client_type,
        }))
        resp.raise_for_status()
        return resp.json()

    async def list_project_chats(self, project_name: str) -> List[Dict[str, Any]]:
        """Get chats for a specific project.

        Fetches the projects listing and extracts the named project's
        ``chats`` list; returns [] if the project or key is missing.
        """
        client = await self._get_client()
        resp = await _retry_request(lambda: client.get("/api/projects", params={
            "limit": 100,
        }))
        resp.raise_for_status()
        data = resp.json()
        # The endpoint may return either a mapping of projects directly or a
        # wrapper dict with a "projects" key; normalize to the mapping.
        projects = data if isinstance(data, dict) else {}
        if isinstance(data, dict) and "projects" in data:
            projects = data["projects"]
        proj_data = projects.get(project_name, {})
        return proj_data.get("chats", [])

    async def delete_chat(self, chat_id: str) -> Dict[str, Any]:
        """Delete the chat with ``chat_id`` and return the server response."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.request("DELETE", f"/api/projects/chat/{chat_id}"))
        resp.raise_for_status()
        return resp.json()

    # --- Multi-Chat / Swarm ---

    async def chat_once(self, prompt: str, model: str, history: Optional[list] = None) -> str:
        """Non-streaming single-shot chat. Returns full response text.

        Uses a throwaway random chat_id and concatenates all streamed
        "content" fields into one string.
        """
        parts = []
        async for chunk in self.stream_chat(
            chat_id=str(uuid.uuid4()),
            prompt=prompt,
            model=model,
            history=history,
        ):
            content = chunk.get("content", "")
            if content:
                parts.append(content)
        return "".join(parts)

    async def multi_chat(
        self,
        prompt: str,
        models: List[str],
        history: Optional[list] = None,
    ) -> List[Dict[str, Any]]:
        """Send the same prompt to multiple models in parallel.

        Returns one dict per model: {"model", "response", "status"} with
        status "ok", or status "error" plus an "error" message — a failing
        model never aborts the others.
        """

        async def _single(model: str) -> Dict[str, Any]:
            # Isolate each model's failure so gather never raises.
            try:
                text = await self.chat_once(prompt, model, history)
                return {"model": model, "response": text, "status": "ok"}
            except Exception as e:
                return {"model": model, "response": "", "status": "error", "error": str(e)}

        results = await asyncio.gather(*[_single(m) for m in models])
        return list(results)

    async def swarm(
        self,
        prompt: str,
        models: Optional[List[str]] = None,
        synthesize_model: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run swarm: multiple models answer in parallel, then a judge synthesizes.

        Args:
            prompt: Question sent verbatim to every model.
            models: Models to fan out to; defaults to a small built-in trio.
            synthesize_model: Judge model; defaults to the first fan-out model.

        Returns a dict with "agent_responses", "synthesis", "judge_model"
        and "models_used".
        """
        default_models = [
            "google/gemini-flash-1.5",
            "anthropic/claude-3-haiku",
            "openai/gpt-4o-mini",
        ]
        use_models = models or default_models
        judge = synthesize_model or use_models[0]

        responses = await self.multi_chat(prompt, use_models)

        agent_outputs = []
        for r in responses:
            status = "OK" if r["status"] == "ok" else "FAILED"
            # Cap each answer at 2000 chars so the judge prompt stays bounded.
            agent_outputs.append(f"[{r['model']} - {status}]:\n{r['response'][:2000]}")

        synthesis_prompt = (
            f"You are a synthesis judge. Multiple AI models were asked the same question.\n\n"
            f"Original question: {prompt}\n\n"
            f"Their responses:\n\n" + "\n\n---\n\n".join(agent_outputs) + "\n\n"
            f"Synthesize the best answer. Combine strengths, note disagreements, give a final answer."
        )

        synthesis = await self.chat_once(synthesis_prompt, judge)
        return {
            "agent_responses": responses,
            "synthesis": synthesis,
            "judge_model": judge,
            "models_used": use_models,
        }

    # --- MLineage ---

    async def generate_diff(self, old_content: str, new_content: str, filename: str = "file") -> Dict[str, Any]:
        """Ask the backend to produce a unified diff between two file versions."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.post("/api/mlineage/generate-diff", json={
            "old_content": old_content,
            "new_content": new_content,
            "filename": filename,
        }))
        resp.raise_for_status()
        return resp.json()

    async def apply_diff(self, original: str, diff_text: str) -> Dict[str, Any]:
        """Ask the backend to apply ``diff_text`` to ``original`` content."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.post("/api/mlineage/apply-diff", json={
            "original_content": original,
            "diff_text": diff_text,
        }))
        resp.raise_for_status()
        return resp.json()

    # --- MDrive ---

    async def list_files(self, project_id: str, path: str = "/") -> list:
        """List files in a project directory; tolerates list or {"files": ...} shapes."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.get("/api/mdrive/list", params={"project_id": project_id, "path": path}))
        resp.raise_for_status()
        data = resp.json()
        return data if isinstance(data, list) else data.get("files", [])

    async def read_file(self, project_id: str, file_path: str) -> str:
        """Read a file's text content; returns "" when the key is absent."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.post("/api/mdrive/read", json={"project_id": project_id, "file_path": file_path}))
        resp.raise_for_status()
        return resp.json().get("content", "")

    async def write_file(self, project_id: str, file_path: str, content: str) -> Dict[str, Any]:
        """Write ``content`` to a project file and return the server response."""
        client = await self._get_client()
        resp = await _retry_request(lambda: client.post("/api/mdrive/write", json={
            "project_id": project_id,
            "file_path": file_path,
            "content": content,
        }))
        resp.raise_for_status()
        return resp.json()

    # --- Load Project Context ---

    async def load_project(self, chat_id: str, project_path: str = "", patterns: Optional[list] = None) -> Dict[str, Any]:
        """Attach local project files as context for ``chat_id``.

        Defaults to common source/docs patterns and caps the upload at 30
        files server-side.
        """
        client = await self._get_client()
        resp = await _retry_request(lambda: client.post("/api/chat/load-project", json={
            "chat_id": chat_id,
            "project_path": project_path,
            "file_patterns": patterns or ["*.py", "*.ts", "*.tsx", "*.js", "*.md"],
            "max_files": 30,
        }))
        resp.raise_for_status()
        return resp.json()