opencode-a2a 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,544 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import time
7
+ from collections.abc import AsyncIterator, Mapping, Sequence
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ import httpx
12
+
13
+ from .config import Settings
14
+ from .parts.text import extract_text_from_parts
15
+
16
# Sentinel meaning "argument not supplied". Optional parameters such as
# ``json_body`` / ``timeout`` default to it so that an explicit ``None`` can still
# be forwarded to httpx as a real value.
_UNSET = object()
logger = logging.getLogger(__name__)
18
+
19
+
20
class UpstreamContractError(RuntimeError):
    """OpenCode answered with a status code or payload shape outside its documented contract.

    Because it derives from ``RuntimeError``, ``except RuntimeError`` handlers also catch it.
    """
22
+
23
+
24
@dataclass(frozen=True)
class OpencodeMessage:
    """Immutable reply to a synchronous message, as built by ``OpencodeUpstreamClient.send_message``."""

    text: str  # text extracted from the response's ``parts``
    session_id: str  # session the message was posted to
    message_id: str | None  # ``info.id`` from the response, when available
    raw: dict[str, Any]  # full decoded JSON response body
30
+
31
+
32
@dataclass(frozen=True)
class InterruptRequestBinding:
    """Registry entry linking a pending interrupt request to the session that raised it."""

    request_id: str  # stripped, non-empty interrupt request ID
    session_id: str  # stripped, non-empty session ID the reply must be routed to
    interrupt_type: str  # "permission" or "question" (the only kinds the registry accepts)
    identity: str | None  # optional caller identity; meaning is defined by callers (TODO confirm)
    task_id: str | None  # optional task ID recorded by the caller, presumably the A2A task
    context_id: str | None  # optional context ID recorded by the caller, presumably the A2A context
    expires_at: float  # deadline on the client's monotonic clock
41
+
42
+
43
@dataclass(frozen=True)
class InterruptRequestTombstone:
    """Marker kept after a binding expires so lookups report "expired" rather than "missing"."""

    request_id: str  # ID of the expired interrupt request
    expires_at: float  # monotonic-clock time after which the tombstone itself is pruned
47
+
48
+
49
class OpencodeUpstreamClient:
    """Async client for the OpenCode server HTTP API.

    Every request goes through one shared ``httpx.AsyncClient`` rooted at
    ``settings.opencode_base_url``.

    The client also keeps an in-memory registry that maps pending interrupt request
    IDs (``permission`` / ``question``) to the session that raised them. Entries expire
    by TTL. Expired entries leave short-lived tombstones, so a lookup can tell an
    expired ID from an unknown one.
    """

    def __init__(self, settings: Settings) -> None:
        self._settings = settings
        self._base_url = settings.opencode_base_url.rstrip("/")
        self._directory = settings.opencode_workspace_root
        self._agent = settings.opencode_agent
        self._system = settings.opencode_system
        self._variant = settings.opencode_variant
        self._stream_timeout = settings.opencode_timeout_stream
        self._log_payloads = settings.a2a_log_payloads
        self._interrupt_request_ttl_seconds = float(settings.a2a_interrupt_request_ttl_seconds)
        self._interrupt_request_tombstone_ttl_seconds = float(
            settings.a2a_interrupt_request_tombstone_ttl_seconds
        )
        # Monotonic time keeps TTLs immune to wall-clock changes. It is stored as an
        # attribute, presumably so tests can substitute a fake clock.
        self._interrupt_request_clock = time.monotonic
        self._interrupt_requests: dict[str, InterruptRequestBinding] = {}
        self._interrupt_request_tombstones: dict[str, InterruptRequestTombstone] = {}
        self._client = self._build_http_client(self._base_url)

    def _build_http_client(self, base_url: str) -> httpx.AsyncClient:
        """Create the shared HTTP client (settings timeout, JSON ``Accept`` header)."""
        return httpx.AsyncClient(
            base_url=base_url,
            timeout=self._settings.opencode_timeout,
            headers={"Accept": "application/json"},
        )

    async def close(self) -> None:
        """Close the shared HTTP client."""
        await self._client.aclose()

    @staticmethod
    def _path_segment(value: str) -> str:
        """Percent-encode *value* so it occupies exactly one URL path segment.

        Session and request IDs may be caller-supplied. A raw ``/``, ``?`` or ``#``
        inside an ID would otherwise change which endpoint is requested. Ordinary IDs
        made of letters, digits, ``-``, ``_``, ``.`` and ``~`` pass through unchanged.
        """
        from urllib.parse import quote

        encoded = quote(str(value), safe="")
        if encoded in {".", ".."}:
            # Bare dot segments are relative-path syntax (RFC 3986), so escape them too.
            encoded = encoded.replace(".", "%2E")
        return encoded

    @staticmethod
    def _response_body_preview(response: httpx.Response, *, limit: int = 200) -> str:
        """Return the body with whitespace collapsed, truncated to *limit* characters.

        Intended for error messages. An empty body renders as ``"<empty>"``.
        """
        body = response.text.strip()
        if not body:
            return "<empty>"
        compact = " ".join(body.split())
        if len(compact) <= limit:
            return compact
        return f"{compact[: limit - 3]}..."

    def _decode_json_response(self, response: httpx.Response, *, endpoint: str) -> Any:
        """Decode *response* as JSON.

        Raises:
            UpstreamContractError: if the body is not valid JSON or not decodable text.
        """
        try:
            return response.json()
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # Invalid byte sequences raise UnicodeDecodeError, not JSONDecodeError.
            content_type = response.headers.get("content-type", "").split(";", 1)[0].strip()
            normalized_content_type = content_type or "unknown"
            body_preview = self._response_body_preview(response)
            raise UpstreamContractError(
                f"OpenCode {endpoint} returned non-JSON response "
                f"(status={response.status_code}, content-type={normalized_content_type}, "
                f"body={body_preview})"
            ) from exc

    @staticmethod
    def _require_boolean_response(*, endpoint: str, payload: Any) -> bool:
        """Return *payload* when it is a JSON boolean.

        Raises:
            UpstreamContractError: for any other payload type. This is a
                ``RuntimeError`` subclass, so existing handlers still match.
        """
        if isinstance(payload, bool):
            return payload
        raise UpstreamContractError(
            f"OpenCode {endpoint} response must be boolean; got {type(payload).__name__}"
        )

    async def _get_json(
        self,
        path: str,
        *,
        endpoint: str,
        params: Mapping[str, Any] | None = None,
    ) -> Any:
        """GET *path*, raise ``httpx.HTTPStatusError`` on error statuses, return decoded JSON."""
        response = await self._client.get(path, params=params)
        response.raise_for_status()
        return self._decode_json_response(response, endpoint=endpoint)

    async def _post_json(
        self,
        path: str,
        *,
        endpoint: str,
        params: Mapping[str, Any] | None = None,
        json_body: Any = _UNSET,
        timeout: float | None | object = _UNSET,
    ) -> Any:
        """POST to *path* and return the decoded JSON response.

        ``json_body`` and ``timeout`` are forwarded only when supplied. An explicit
        ``timeout=None`` is therefore passed through to httpx rather than dropped.
        """
        request_kwargs: dict[str, Any] = {}
        if json_body is not _UNSET:
            request_kwargs["json"] = json_body
        if timeout is not _UNSET:
            request_kwargs["timeout"] = timeout
        response = await self._client.post(
            path,
            params=params,
            **request_kwargs,
        )
        response.raise_for_status()
        return self._decode_json_response(response, endpoint=endpoint)

    async def _post_boolean(
        self,
        path: str,
        *,
        endpoint: str,
        params: Mapping[str, Any] | None = None,
        json_body: Any = _UNSET,
        timeout: float | None | object = _UNSET,
    ) -> bool:
        """POST like ``_post_json`` and require the response to be a JSON boolean."""
        data = await self._post_json(
            path,
            endpoint=endpoint,
            params=params,
            json_body=json_body,
            timeout=timeout,
        )
        return self._require_boolean_response(endpoint=endpoint, payload=data)

    def _prune_interrupt_requests(self, *, now: float) -> None:
        """Drop bindings whose deadline has passed, leaving a tombstone for each."""
        expired = [
            request_id
            for request_id, binding in self._interrupt_requests.items()
            if binding.expires_at <= now
        ]
        for request_id in expired:
            self._interrupt_requests.pop(request_id, None)
            self._remember_interrupt_request_tombstone(request_id, now=now)

    def _prune_interrupt_request_tombstones(self, *, now: float) -> None:
        """Drop tombstones whose own deadline has passed."""
        expired = [
            request_id
            for request_id, tombstone in self._interrupt_request_tombstones.items()
            if tombstone.expires_at <= now
        ]
        for request_id in expired:
            self._interrupt_request_tombstones.pop(request_id, None)

    def _remember_interrupt_request_tombstone(self, request_id: str, *, now: float) -> None:
        """Record that *request_id* expired; a non-positive tombstone TTL disables this."""
        ttl = self._interrupt_request_tombstone_ttl_seconds
        if ttl <= 0:
            self._interrupt_request_tombstones.pop(request_id, None)
            return
        self._interrupt_request_tombstones[request_id] = InterruptRequestTombstone(
            request_id=request_id,
            expires_at=now + ttl,
        )

    def remember_interrupt_request(
        self,
        *,
        request_id: str,
        session_id: str,
        interrupt_type: str,
        identity: str | None = None,
        task_id: str | None = None,
        context_id: str | None = None,
        ttl_seconds: float | None = None,
    ) -> None:
        """Register which session raised interrupt *request_id*.

        Blank IDs and interrupt types other than ``"permission"`` / ``"question"`` are
        ignored silently. The binding lives for *ttl_seconds*; when that is omitted,
        the configured TTL is used, and negative values clamp to 0. Optional string
        fields are stripped, and blank values are stored as ``None``. Re-registering
        an ID clears any tombstone left for it.
        """
        request = request_id.strip()
        session = session_id.strip()
        kind = interrupt_type.strip()
        if not request or not session or kind not in {"permission", "question"}:
            return
        now = self._interrupt_request_clock()
        # Prune bindings before tombstones: the fresh tombstones outlive this pass.
        self._prune_interrupt_requests(now=now)
        self._prune_interrupt_request_tombstones(now=now)
        ttl = self._interrupt_request_ttl_seconds if ttl_seconds is None else ttl_seconds
        expires_at = now + max(0.0, float(ttl))
        self._interrupt_requests[request] = InterruptRequestBinding(
            request_id=request,
            session_id=session,
            interrupt_type=kind,
            identity=identity.strip() if isinstance(identity, str) and identity.strip() else None,
            task_id=task_id.strip() if isinstance(task_id, str) and task_id.strip() else None,
            context_id=(
                context_id.strip() if isinstance(context_id, str) and context_id.strip() else None
            ),
            expires_at=expires_at,
        )
        self._interrupt_request_tombstones.pop(request, None)

    def resolve_interrupt_request(
        self,
        request_id: str,
    ) -> tuple[str, InterruptRequestBinding | None]:
        """Look up an interrupt request.

        Returns:
            ``("active", binding)`` while the binding is live. ``("expired", None)``
            when the binding has lapsed or only a tombstone remains.
            ``("missing", None)`` for blank or unknown IDs.
        """
        request = request_id.strip()
        if not request:
            return "missing", None
        now = self._interrupt_request_clock()
        self._prune_interrupt_request_tombstones(now=now)
        binding = self._interrupt_requests.get(request)
        if binding is None:
            if request in self._interrupt_request_tombstones:
                return "expired", None
            return "missing", None
        if binding.expires_at <= now:
            self._interrupt_requests.pop(request, None)
            self._prune_interrupt_requests(now=now)
            self._remember_interrupt_request_tombstone(request, now=now)
            return "expired", None
        self._prune_interrupt_requests(now=now)
        return "active", binding

    def resolve_interrupt_session(self, request_id: str) -> str | None:
        """Return the session ID of an active interrupt request, else ``None``."""
        status, binding = self.resolve_interrupt_request(request_id)
        if status != "active" or binding is None:
            return None
        return binding.session_id

    def discard_interrupt_request(self, request_id: str) -> None:
        """Forget *request_id* entirely: binding and tombstone alike."""
        request = request_id.strip()
        if not request:
            return
        self._interrupt_requests.pop(request, None)
        self._interrupt_request_tombstones.pop(request, None)

    @property
    def stream_timeout(self) -> float | None:
        return self._stream_timeout

    @property
    def directory(self) -> str | None:
        return self._directory

    @property
    def settings(self) -> Settings:
        return self._settings

    @staticmethod
    def _normalize_model_ref(value: Mapping[str, Any] | None) -> dict[str, str] | None:
        """Return ``{"providerID", "modelID"}`` with stripped values.

        Returns ``None`` when either key is missing, not a string, or blank.
        """
        if value is None:
            return None
        provider = value.get("providerID")
        model = value.get("modelID")
        if not isinstance(provider, str) or not isinstance(model, str):
            return None
        provider_id = provider.strip()
        model_id = model.strip()
        if not provider_id or not model_id:
            return None
        return {
            "providerID": provider_id,
            "modelID": model_id,
        }

    def _query_params(self, directory: str | None = None) -> dict[str, str]:
        """Return ``{"directory": ...}``; the override wins over the configured root."""
        d = directory or self._directory
        if not d:
            return {}
        return {"directory": d}

    def _merge_params(
        self, extra: dict[str, Any] | None, *, directory: str | None = None
    ) -> dict[str, Any]:
        """Combine the directory query param with caller-supplied *extra* params.

        ``None`` values are dropped. A ``directory`` key in *extra* is ignored; use the
        explicit *directory* argument instead.
        """
        params: dict[str, Any] = dict(self._query_params(directory=directory))
        if not extra:
            return params
        for key, value in extra.items():
            if value is None:
                continue
            # "directory" is server-controlled. Client overrides are handled via explicit parameter.
            if key == "directory":
                continue
            # Query values are sent as strings: keep str as-is and coerce other primitives.
            # NOTE(review): str(True) yields "True", whereas httpx itself would render
            # "true". Confirm which form the OpenCode server expects.
            params[key] = value if isinstance(value, str) else str(value)
        return params

    async def stream_events(
        self, stop_event: asyncio.Event | None = None, *, directory: str | None = None
    ) -> AsyncIterator[dict[str, Any]]:
        """Yield JSON-object events from OpenCode's ``/event`` server-sent-event stream.

        Only ``data:`` fields are read. Multi-line data is joined with newlines, and
        comment lines and other fields are ignored. Payloads that are not JSON
        objects are dropped. *stop_event* is checked as each line arrives, so an
        idle stream only notices it when the next line comes in.
        """
        params = self._query_params(directory=directory)
        # NOTE(review): no timeout is applied here; self._stream_timeout is only exposed
        # through the ``stream_timeout`` property, presumably for callers to enforce.
        async with self._client.stream(
            "GET",
            "/event",
            params=params,
            timeout=None,
            headers={"Accept": "text/event-stream"},
        ) as response:
            response.raise_for_status()
            data_lines: list[str] = []
            async for line in response.aiter_lines():
                if stop_event and stop_event.is_set():
                    break
                if line.startswith(":"):
                    continue
                if line == "":
                    # A blank line terminates the current event.
                    if not data_lines:
                        continue
                    payload = "\n".join(data_lines).strip()
                    data_lines.clear()
                    if not payload:
                        continue
                    try:
                        event = json.loads(payload)
                    except json.JSONDecodeError:
                        continue
                    if isinstance(event, dict):
                        yield event
                    continue
                if line.startswith("data:"):
                    data_lines.append(line[5:].lstrip())
                    continue

    async def create_session(
        self, title: str | None = None, *, directory: str | None = None
    ) -> str:
        """Create a session and return its ID.

        Raises:
            UpstreamContractError: if the response is not an object carrying a
                non-empty string ``id``. This is a ``RuntimeError`` subclass.
        """
        payload: dict[str, Any] = {}
        if title:
            payload["title"] = title
        data = await self._post_json(
            "/session",
            endpoint="/session",
            params=self._query_params(directory=directory),
            json_body=payload,
        )
        # Check the shape first: calling .get() on a list/scalar body would raise AttributeError.
        session_id = data.get("id") if isinstance(data, dict) else None
        if not isinstance(session_id, str) or not session_id:
            raise UpstreamContractError("OpenCode session response missing id")
        return session_id

    async def abort_session(self, session_id: str, *, directory: str | None = None) -> bool:
        """Abort the running work of *session_id*; returns the server's boolean answer."""
        return await self._post_boolean(
            f"/session/{self._path_segment(session_id)}/abort",
            endpoint="/session/{sessionID}/abort",
            params=self._query_params(directory=directory),
        )

    async def list_sessions(
        self,
        *,
        params: dict[str, Any] | None = None,
        directory: str | None = None,
    ) -> Any:
        """List sessions from OpenCode.

        Args:
            params: Extra query parameters. ``None`` values and any ``directory`` key
                are dropped.
            directory: Workspace directory override. Defaults to the configured root.
        """
        return await self._get_json(
            "/session",
            endpoint="/session",
            params=self._merge_params(params, directory=directory),
        )

    async def list_messages(
        self,
        session_id: str,
        *,
        params: dict[str, Any] | None = None,
        directory: str | None = None,
    ) -> Any:
        """List messages for a session from OpenCode.

        Args:
            session_id: Session whose messages are listed.
            params: Extra query parameters. ``None`` values and any ``directory`` key
                are dropped.
            directory: Workspace directory override. Defaults to the configured root.
        """
        return await self._get_json(
            f"/session/{self._path_segment(session_id)}/message",
            endpoint="/session/{sessionID}/message",
            params=self._merge_params(params, directory=directory),
        )

    async def session_prompt_async(
        self,
        session_id: str,
        request: dict[str, Any],
        *,
        directory: str | None = None,
    ) -> None:
        """Submit a prompt without waiting for the reply.

        Raises:
            UpstreamContractError: if the server answers with a success status other than 204.
        """
        response = await self._client.post(
            f"/session/{self._path_segment(session_id)}/prompt_async",
            params=self._query_params(directory=directory),
            json=request,
        )
        response.raise_for_status()
        if response.status_code != 204:
            raise UpstreamContractError(
                "OpenCode /session/{sessionID}/prompt_async must return 204; "
                f"got {response.status_code}"
            )

    async def session_command(
        self,
        session_id: str,
        request: dict[str, Any],
        *,
        directory: str | None = None,
    ) -> Any:
        """Run a command in *session_id* and return the decoded JSON response."""
        return await self._post_json(
            f"/session/{self._path_segment(session_id)}/command",
            endpoint="/session/{sessionID}/command",
            params=self._query_params(directory=directory),
            json_body=request,
        )

    async def session_shell(
        self,
        session_id: str,
        request: dict[str, Any],
        *,
        directory: str | None = None,
    ) -> Any:
        """Run a shell request in *session_id* and return the decoded JSON response."""
        return await self._post_json(
            f"/session/{self._path_segment(session_id)}/shell",
            endpoint="/session/{sessionID}/shell",
            params=self._query_params(directory=directory),
            json_body=request,
        )

    async def list_provider_catalog(self, *, directory: str | None = None) -> Any:
        """Return the decoded ``/provider`` catalog response."""
        return await self._get_json(
            "/provider",
            endpoint="/provider",
            params=self._query_params(directory=directory),
        )

    async def send_message(
        self,
        session_id: str,
        text: str | None = None,
        *,
        parts: Sequence[Mapping[str, Any]] | None = None,
        directory: str | None = None,
        model_override: Mapping[str, Any] | None = None,
        timeout_override: float | None | object = _UNSET,
    ) -> OpencodeMessage:
        """Post a message to *session_id* and wait for OpenCode's reply.

        *parts* takes precedence over *text*. Configured agent/system/variant values and
        a valid *model_override* are added to the payload. *timeout_override*, when
        supplied (``None`` included), is forwarded as the per-request timeout.

        Raises:
            ValueError: if neither *text* nor *parts* is given, or *parts* is empty.
            UpstreamContractError: if the response body is not a JSON object.
        """
        payload_parts: list[dict[str, Any]]
        if parts is not None:
            payload_parts = [dict(part) for part in parts]
        elif isinstance(text, str):
            payload_parts = [
                {
                    "type": "text",
                    "text": text,
                }
            ]
        else:
            raise ValueError("send_message requires either text or parts")

        if not payload_parts:
            raise ValueError("send_message parts must not be empty")

        payload: dict[str, Any] = {"parts": payload_parts}
        if self._agent:
            payload["agent"] = self._agent
        if self._system:
            payload["system"] = self._system
        if self._variant:
            payload["variant"] = self._variant
        normalized_model = self._normalize_model_ref(model_override)
        if normalized_model is not None:
            payload["model"] = normalized_model

        if self._log_payloads:
            logger.debug("OpenCode request payload=%s", payload)

        data = await self._post_json(
            f"/session/{self._path_segment(session_id)}/message",
            endpoint="/session/{sessionID}/message",
            params=self._query_params(directory=directory),
            json_body=payload,
            timeout=timeout_override,
        )
        if self._log_payloads:
            logger.debug("OpenCode response payload=%s", data)
        # Validate the shape before using .get(); a list/scalar body would raise AttributeError.
        if not isinstance(data, dict):
            raise UpstreamContractError(
                "OpenCode /session/{sessionID}/message response must be an object; "
                f"got {type(data).__name__}"
            )
        text_content = extract_text_from_parts(data.get("parts", []))
        info = data.get("info")
        raw_message_id = info.get("id") if isinstance(info, dict) else None
        # Keep only string IDs so the value matches OpencodeMessage.message_id's type.
        message_id = raw_message_id if isinstance(raw_message_id, str) else None
        return OpencodeMessage(
            text=text_content,
            session_id=session_id,
            message_id=message_id,
            raw=data,
        )

    async def permission_reply(
        self,
        request_id: str,
        *,
        reply: str,
        message: str | None = None,
        directory: str | None = None,
    ) -> bool:
        """Answer a permission request. *message* is included only when non-empty."""
        payload: dict[str, Any] = {"reply": reply}
        if message:
            payload["message"] = message
        return await self._post_boolean(
            f"/permission/{self._path_segment(request_id)}/reply",
            endpoint="/permission/{requestID}/reply",
            params=self._query_params(directory=directory),
            json_body=payload,
        )

    async def question_reply(
        self,
        request_id: str,
        *,
        answers: list[list[str]],
        directory: str | None = None,
    ) -> bool:
        """Answer a question request with one list of selected answers per question."""
        return await self._post_boolean(
            f"/question/{self._path_segment(request_id)}/reply",
            endpoint="/question/{requestID}/reply",
            params=self._query_params(directory=directory),
            json_body={"answers": answers},
        )

    async def question_reject(
        self,
        request_id: str,
        *,
        directory: str | None = None,
    ) -> bool:
        """Reject a question request."""
        return await self._post_boolean(
            f"/question/{self._path_segment(request_id)}/reject",
            endpoint="/question/{requestID}/reject",
            params=self._query_params(directory=directory),
        )
@@ -0,0 +1 @@
1
+ """A2A input/output part normalization helpers."""
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Literal, TypedDict
4
+
5
+
6
class UnsupportedA2AInputError(ValueError):
    """An incoming A2A part has no OpenCode input equivalent.

    Raised for DataParts, unrecognised part kinds, and FileParts without usable content.
    """
8
+
9
+
10
class OpencodeTextInputPart(TypedDict):
    """OpenCode prompt part carrying plain text; both keys are required."""

    type: Literal["text"]  # discriminator, always "text"
    text: str  # text copied from the A2A TextPart by map_a2a_parts_to_opencode_parts
13
+
14
+
15
class OpencodeFileInputPart(TypedDict, total=False):
    """OpenCode prompt part referencing a file; every key is optional (``total=False``)."""

    type: Literal["file"]  # discriminator, always "file"
    url: str  # ``data:<mime>;base64,...`` URL for inline bytes, otherwise the FilePart URI
    mime: str  # MIME type; _map_file_part falls back to "application/octet-stream"
    filename: str  # set only when the FilePart carries a non-blank name
20
+
21
+
22
# Every part shape map_a2a_parts_to_opencode_parts can emit. The `|` between classes
# is evaluated at import time, which requires Python 3.10+.
OpencodeInputPart = OpencodeTextInputPart | OpencodeFileInputPart
23
+
24
+
25
def extract_text_from_a2a_parts(parts: Any) -> str:
    """Join the text of every A2A TextPart in *parts* with newlines.

    Parts may be wrapped (exposing ``.root``) or bare. Non-text parts, and text parts
    whose ``text`` is not a string, are skipped. The joined result is stripped, and
    ``""`` is returned when *parts* is not a list.
    """
    if not isinstance(parts, list):
        return ""

    collected: list[str] = []
    for root in map(_unwrap_part_root, parts):
        if getattr(root, "kind", None) == "text":
            candidate = getattr(root, "text", None)
            if isinstance(candidate, str):
                collected.append(candidate)
    return "\n".join(collected).strip()
38
+
39
+
40
def summarize_a2a_parts(parts: Any) -> str | None:
    """Build a short human-readable label for *parts*.

    Prefers the joined text content, capped at 80 characters. Otherwise falls back to
    file names, with nameless files shown as ``"file"``:
    - a single name is returned as-is;
    - several names are joined, showing at most the first three, and capped at 80 characters.

    Returns ``None`` when there is no text and no file part.
    """
    summary_text = extract_text_from_a2a_parts(parts)
    if summary_text:
        return summary_text[:80]
    if not isinstance(parts, list):
        return None

    labels: list[str] = []
    for part in parts:
        root = _unwrap_part_root(part)
        if getattr(root, "kind", None) != "file":
            continue
        raw_name = getattr(getattr(root, "file", None), "name", None)
        label = raw_name.strip() if isinstance(raw_name, str) else ""
        labels.append(label or "file")

    if not labels:
        return None
    return labels[0] if len(labels) == 1 else ", ".join(labels[:3])[:80]
65
+
66
+
67
def map_a2a_parts_to_opencode_parts(parts: Any) -> list[OpencodeInputPart]:
    """Translate A2A message parts into OpenCode prompt input parts.

    TextParts become ``{"type": "text", ...}``; a TextPart whose ``text`` is not a
    string is skipped. FileParts are converted by ``_map_file_part``. Returns ``[]``
    when *parts* is not a list.

    Raises:
        UnsupportedA2AInputError: on a DataPart, any other part kind, or a malformed
            FilePart.
    """
    if not isinstance(parts, list):
        return []

    result: list[OpencodeInputPart] = []
    for position, part in enumerate(parts):
        root = _unwrap_part_root(part)
        kind = getattr(root, "kind", None)

        if kind == "text":
            text = getattr(root, "text", None)
            if isinstance(text, str):
                result.append({"type": "text", "text": text})
        elif kind == "file":
            result.append(_map_file_part(root, index=position))
        elif kind == "data":
            raise UnsupportedA2AInputError(
                f"request.parts[{position}] DataPart input is not supported; use TextPart or FilePart."
            )
        else:
            raise UnsupportedA2AInputError(
                f"request.parts[{position}] is not supported; only TextPart and FilePart are accepted."
            )
    return result
96
+
97
+
98
def _map_file_part(part: Any, *, index: int) -> OpencodeFileInputPart:
    """Convert one A2A FilePart into an OpenCode ``file`` input part.

    Inline ``bytes`` content becomes a ``data:<mime>;base64,<bytes>`` URL. The value is
    presumably already base64 text; confirm against the A2A schema. Without bytes,
    ``uri`` is passed through unchanged, and ``bytes`` wins when both are present.
    The MIME type is read from ``mime_type`` or ``mimeType`` and defaults to
    ``application/octet-stream``. *index* is used only in error messages.

    Raises:
        UnsupportedA2AInputError: if the file payload is missing or has neither
            bytes nor uri.
    """
    file_value = getattr(part, "file", None)
    if file_value is None:
        raise UnsupportedA2AInputError(
            f"request.parts[{index}] FilePart is missing the file payload."
        )

    declared_mime = getattr(file_value, "mime_type", None) or getattr(file_value, "mimeType", None)
    mime = _normalize_string(declared_mime) or "application/octet-stream"
    name = _normalize_string(getattr(file_value, "name", None))

    encoded = _normalize_string(getattr(file_value, "bytes", None))
    if encoded:
        url = f"data:{mime};base64,{encoded}"
    else:
        url = _normalize_string(getattr(file_value, "uri", None))
        if not url:
            raise UnsupportedA2AInputError(
                f"request.parts[{index}] FilePart must contain either bytes or uri."
            )

    mapped: OpencodeFileInputPart = {"type": "file", "url": url, "mime": mime}
    if name:
        mapped["filename"] = name
    return mapped
138
+
139
+
140
+ def _unwrap_part_root(part: Any) -> Any:
141
+ root = getattr(part, "root", None)
142
+ if root is not None:
143
+ return root
144
+ return part
145
+
146
+
147
+ def _normalize_string(value: Any) -> str | None:
148
+ if not isinstance(value, str):
149
+ return None
150
+ normalized = value.strip()
151
+ return normalized if normalized else None