mycode-sdk 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,387 @@
1
+ """Shared provider adapter interfaces.
2
+
3
+ The agent loop talks to providers through a small normalized contract:
4
+
5
+ - input: `ProviderRequest`
6
+ - output: streamed `ProviderStreamEvent` objects
7
+
8
+ Concrete adapters are free to use the official SDK or protocol that best matches
9
+ their upstream provider. Each adapter is also responsible for projecting the
10
+ canonical session transcript into a provider-safe replay history before a new
11
+ request is sent upstream.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import html
17
+ import os
18
+ from abc import ABC, abstractmethod
19
+ from collections.abc import AsyncIterator
20
+ from dataclasses import dataclass, field
21
+ from typing import Any
22
+
23
+ from mycode.messages import ConversationMessage, build_message, text_block, tool_result_block
24
+
25
# Default timeout for provider requests.
# NOTE(review): presumably seconds — confirm where this constant is consumed.
DEFAULT_REQUEST_TIMEOUT = 300.0
26
+
27
+
28
@dataclass(frozen=True)
class ProviderRequest:
    """Immutable, normalized input handed to a provider adapter for one call.

    Adapters receive this object in `stream_turn()` and `prepare_messages()`.
    """

    # Provider and model selection.
    # NOTE(review): `provider` presumably matches an adapter's `provider_id` —
    # confirm against the code that builds requests.
    provider: str
    model: str
    session_id: str | None
    # Canonical session transcript; `prepare_messages()` repairs and projects
    # it before it is replayed upstream.
    messages: list[ConversationMessage]
    system: str
    tools: list[dict[str, Any]]
    max_tokens: int
    # Optional credential / endpoint overrides.
    # NOTE(review): presumably fed to `require_api_key()` and
    # `resolve_base_url()` by concrete adapters — verify in the adapters.
    api_key: str | None
    api_base: str | None
    # Shared reasoning-effort knob; `None` when unset.
    reasoning_effort: str | None = None
    # Capability flags read by `prepare_messages()`: when False, image /
    # document blocks are replaced by a text placeholder during replay repair.
    supports_image_input: bool = True
    supports_pdf_input: bool = True
42
+
43
+
44
@dataclass
class ProviderStreamEvent:
    """One event yielded by an adapter while streaming an assistant turn."""

    # Event kind discriminator; the set of valid values is not defined in this
    # module.
    type: str
    # Event payload; every instance gets its own fresh empty dict by default.
    data: dict[str, Any] = field(default_factory=dict)
48
+
49
+
50
def dump_model(value: Any) -> Any:
    """Turn SDK model objects into plain Python data.

    Objects exposing a ``model_dump`` attribute are replaced by the result of
    calling it.  Lists are converted element by element (recursively).  Any
    other value, including ``None``, is handed back unchanged.
    """

    if value is None:
        return None

    if hasattr(value, "model_dump"):
        # Checked before the list case, so a list-like model is dumped whole.
        plain = value.model_dump()
    elif isinstance(value, list):
        plain = list(map(dump_model, value))
    else:
        plain = value
    return plain
60
+
61
+
62
def get_native_meta(block: dict[str, Any]) -> dict[str, Any]:
    """Look up ``block["meta"]["native"]``.

    Returns the stored dict itself (not a copy) when both levels are dicts;
    otherwise returns a new empty dict.
    """

    meta = block.get("meta")
    native = meta.get("native") if isinstance(meta, dict) else None
    return native if isinstance(native, dict) else {}
71
+
72
+
73
class ProviderAdapter(ABC):
    """Base class for provider adapters.

    New adapters usually only need to implement `stream_turn()` and optionally
    override tool-call ID projection.
    """

    provider_id: str
    label: str
    default_base_url: str | None = None
    env_api_key_names: tuple[str, ...] = ()
    # Used only as lightweight defaults during config resolution.
    default_models: tuple[str, ...] = ()
    # Auto-discovery is intentionally limited to first-party built-ins that can
    # run from environment variables alone.
    auto_discoverable: bool = True
    # Whether this adapter accepts the shared `reasoning_effort` knob. Providers
    # that do not support it keep their upstream default behavior unchanged.
    supports_reasoning_effort: bool = False

    @abstractmethod
    def stream_turn(self, request: ProviderRequest) -> AsyncIterator[ProviderStreamEvent]:
        """Stream exactly one assistant turn."""

    def prepare_messages(self, request: ProviderRequest) -> list[ConversationMessage]:
        """Repair canonical history, then project tool IDs for provider replay.

        The transcript is first passed through `repair_messages_for_replay()`.
        Every surviving `tool_use` block then gets its ID rewritten via
        `project_tool_call_id()`, and `tool_result` blocks that refer to a
        rewritten ID are updated to match.  Messages and blocks are shallow
        copies; non-dict blocks are dropped.
        """

        # `getattr` keeps request objects without the capability flags
        # working; a missing flag is treated as "supported".
        supports_image_input = getattr(request, "supports_image_input", True)
        supports_pdf_input = getattr(request, "supports_pdf_input", True)
        repaired_messages = repair_messages_for_replay(
            request.messages,
            supports_image_input=supports_image_input,
            supports_pdf_input=supports_pdf_input,
        )
        prepared_messages: list[ConversationMessage] = []
        # canonical tool-call ID -> projected ID, shared across the transcript.
        tool_id_map: dict[str, str] = {}
        used_tool_call_ids: set[str] = set()

        for message in repaired_messages:
            projected_blocks: list[dict[str, Any]] = []
            for raw_block in message.get("content") or []:
                if not isinstance(raw_block, dict):
                    continue

                block = dict(raw_block)
                block_type = block.get("type")
                if block_type == "tool_use":
                    tool_use_id = str(block.get("id") or "")
                    if tool_use_id:
                        if tool_use_id not in tool_id_map:
                            # The projector sees the IDs handed out so far so
                            # it can avoid collisions.
                            projected_id = self.project_tool_call_id(tool_use_id, used_tool_call_ids)
                            tool_id_map[tool_use_id] = projected_id
                            used_tool_call_ids.add(projected_id)
                        block["id"] = tool_id_map[tool_use_id]
                elif block_type == "tool_result":
                    tool_use_id = str(block.get("tool_use_id") or "")
                    if tool_use_id in tool_id_map:
                        block["tool_use_id"] = tool_id_map[tool_use_id]

                projected_blocks.append(block)

            projected_message = dict(message)
            projected_message["content"] = projected_blocks
            prepared_messages.append(projected_message)

        return prepared_messages

    def project_tool_call_id(self, tool_call_id: str, used_tool_call_ids: set[str]) -> str:
        """Project one canonical tool call ID into a provider-safe ID.

        Most providers accept canonical tool IDs as-is. Adapters can override
        this when the upstream protocol restricts character sets or length, as
        long as the returned ID stays unique within the projected request.
        """
        del used_tool_call_ids
        return tool_call_id

    def api_key_from_env(self) -> str | None:
        """Return the first usable API key found in `env_api_key_names`.

        Names are checked in order.  Values are stripped of surrounding
        whitespace, and unset, empty, or whitespace-only variables are
        skipped so they cannot shadow a later, valid variable.

        Returns:
            The stripped key, or ``None`` when no variable holds one.
        """
        for env_name in self.env_api_key_names:
            value = (os.environ.get(env_name) or "").strip()
            if value:
                return value
        return None

    def require_api_key(self, api_key: str | None) -> str:
        """Resolve the API key to use, preferring the explicit argument.

        Args:
            api_key: Explicit key; ignored when ``None``, empty, or
                whitespace-only.

        Returns:
            The stripped explicit key, else the key from the environment.

        Raises:
            ValueError: When neither source yields a key.
        """
        resolved = (api_key or "").strip() or self.api_key_from_env()
        if resolved:
            return resolved

        checked = ", ".join(self.env_api_key_names) or "<api key env>"
        raise ValueError(f"missing API key for provider {self.provider_id}; checked: {checked}")

    def resolve_base_url(self, api_base: str | None) -> str | None:
        """Return the base URL to use, without trailing slashes.

        Args:
            api_base: Explicit base URL; when ``None``, empty, or
                whitespace-only, `default_base_url` is used instead.

        Returns:
            The stripped URL with trailing ``/`` removed, or ``None`` when
            nothing usable remains.
        """
        # Strip before falling back so a blank override does not hide the
        # adapter default (mirrors `require_api_key`).
        base = (api_base or "").strip() or (self.default_base_url or "").strip()
        return base.rstrip("/") or None
166
+
167
+
168
def repair_messages_for_replay(
    source_messages: list[ConversationMessage],
    *,
    supports_image_input: bool,
    supports_pdf_input: bool,
) -> list[ConversationMessage]:
    """Return a minimal replay-safe transcript from canonical session history.

    This keeps only replayable blocks, removes duplicate or orphaned tool
    records, and inserts synthetic error tool results when a tool call was left
    open by an interrupted turn.

    Args:
        source_messages: Transcript to repair.  Only messages whose role is
            "assistant" or "user" are considered; the input list and its
            blocks are not mutated (kept blocks are shallow-copied).
        supports_image_input: When False, user "image" blocks become a text
            placeholder and "image" items are removed from tool-result content.
        supports_pdf_input: When False, user "document" blocks become a text
            placeholder.

    Returns:
        A new list of messages suitable for replay.
    """

    replay_messages: list[ConversationMessage] = []
    # IDs of tool_use blocks already copied into the replay transcript.
    emitted_tool_use_ids: set[str] = set()
    # IDs that already have a result (real or synthetic) in the transcript.
    emitted_tool_result_ids: set[str] = set()
    # Tool calls from the last kept assistant message still waiting for a result.
    open_tool_use_ids: list[str] = []

    for message in source_messages:
        role = str(message.get("role") or "")

        if role == "assistant":
            # A new assistant turn arrived while calls were still open: close
            # them with synthetic error results before anything else.
            if open_tool_use_ids:
                replay_messages.append(_interrupted_tool_result_message(open_tool_use_ids))
                emitted_tool_result_ids.update(open_tool_use_ids)
                open_tool_use_ids = []

            raw_meta = message.get("meta")
            stop_reason = str(raw_meta.get("stop_reason") or "") if isinstance(raw_meta, dict) else ""
            # Turns that ended abnormally are dropped from the replay entirely.
            if stop_reason in {"error", "aborted", "cancelled"}:
                continue

            content: list[dict[str, Any]] = []
            current_tool_use_ids: list[str] = []
            for raw_block in message.get("content") or []:
                if not isinstance(raw_block, dict):
                    continue
                block_type = raw_block.get("type")
                if block_type in {"text", "thinking"}:
                    # Keep only text/thinking blocks that carry non-empty text.
                    text = str(raw_block.get("text") or "")
                    if text:
                        content.append(dict(raw_block))
                    continue

                # Any other non-tool_use block type is not replayed.
                if block_type != "tool_use":
                    continue

                # Skip tool calls without an ID and repeats of an earlier ID.
                tool_use_id = str(raw_block.get("id") or "")
                if not tool_use_id or tool_use_id in emitted_tool_use_ids:
                    continue

                emitted_tool_use_ids.add(tool_use_id)
                current_tool_use_ids.append(tool_use_id)
                content.append(dict(raw_block))

            # Nothing replayable survived: omit the assistant message.
            if not content:
                continue

            replay_message = dict(message)
            replay_message["content"] = content
            if isinstance(raw_meta, dict):
                # Copy meta so the replay message does not share it with the source.
                replay_message["meta"] = dict(raw_meta)
            replay_messages.append(replay_message)
            open_tool_use_ids = current_tool_use_ids
            continue

        # Roles other than assistant/user are ignored.
        if role != "user":
            continue

        content = []
        # Tool calls answered by this particular user message.
        resolved_tool_use_ids: set[str] = set()
        # True once the message carries real user input (text or attachment),
        # as opposed to tool results only.
        has_user_input = False

        for raw_block in message.get("content") or []:
            if not isinstance(raw_block, dict):
                continue

            block_type = raw_block.get("type")
            if block_type == "text":
                text = str(raw_block.get("text") or "")
                if text:
                    has_user_input = True
                    content.append(dict(raw_block))
                continue

            if block_type in {"image", "document"}:
                supported = supports_image_input if block_type == "image" else supports_pdf_input
                # Counts as user input even when it is replaced by a placeholder.
                has_user_input = True
                if supported:
                    content.append(dict(raw_block))
                else:
                    # Substitute a text placeholder; name and MIME type are
                    # HTML-escaped because they are embedded in attributes.
                    default_mime = "image" if block_type == "image" else "application/pdf"
                    label = "image input" if block_type == "image" else "PDF input"
                    name = html.escape(str(raw_block.get("name") or f"attached-{block_type}"), quote=True)
                    mime = html.escape(str(raw_block.get("mime_type") or default_mime), quote=True)
                    content.append(
                        {
                            "type": "text",
                            "text": f'<file name="{name}" media_type="{mime}" kind="{block_type}">Current model does not support {label}.</file>',
                            "meta": {"attachment": True},
                        }
                    )
                continue

            if block_type != "tool_result":
                continue

            # Drop results without an ID, results whose tool call was never
            # emitted (orphans), and results for an already-answered call.
            tool_use_id = str(raw_block.get("tool_use_id") or "")
            if not tool_use_id or tool_use_id not in emitted_tool_use_ids or tool_use_id in emitted_tool_result_ids:
                continue

            block = dict(raw_block)
            raw_content = block.get("content")
            if not supports_image_input and isinstance(raw_content, list):
                # Remove image items (and non-dict items) from structured
                # tool-result content; drop the key if nothing is left.
                filtered_content = [
                    dict(item) for item in raw_content if isinstance(item, dict) and item.get("type") != "image"
                ]
                if filtered_content:
                    block["content"] = filtered_content
                else:
                    block.pop("content", None)

            content.append(block)
            resolved_tool_use_ids.add(tool_use_id)
            emitted_tool_result_ids.add(tool_use_id)

        # NOTE(review): nesting here was reconstructed from a flattened
        # rendering; the `elif` is taken to pair with the `if` directly below —
        # verify against the packaged source.
        if has_user_input and open_tool_use_ids:
            # The user moved on with new input: any call still unanswered gets
            # a synthetic error result, inserted before this message.
            missing_tool_use_ids = [
                tool_use_id for tool_use_id in open_tool_use_ids if tool_use_id not in resolved_tool_use_ids
            ]
            if missing_tool_use_ids:
                replay_messages.append(_interrupted_tool_result_message(missing_tool_use_ids))
                emitted_tool_result_ids.update(missing_tool_use_ids)
            open_tool_use_ids = []

        elif open_tool_use_ids:
            # Tool-results-only message: calls it did not answer stay open.
            open_tool_use_ids = [
                tool_use_id for tool_use_id in open_tool_use_ids if tool_use_id not in resolved_tool_use_ids
            ]

        if not content:
            if replay_messages and replay_messages[-1].get("role") == "assistant":
                # Keep a valid replay transcript when a corrupted user turn is
                # reduced to nothing after cleanup.
                replay_messages.append(
                    build_message(
                        "user",
                        [text_block("[User turn omitted during replay]")],
                        meta={"synthetic": True},
                    )
                )
            continue

        replay_message = dict(message)
        replay_message["content"] = content
        if isinstance(message.get("meta"), dict):
            replay_message["meta"] = dict(message["meta"])
        replay_messages.append(replay_message)

    # Close any calls still open at the end of the transcript.
    if open_tool_use_ids:
        replay_messages.append(_interrupted_tool_result_message(open_tool_use_ids))

    return replay_messages
331
+
332
+
333
def _interrupted_tool_result_message(tool_use_ids: list[str]) -> ConversationMessage:
    """Build a synthetic user message closing every given tool call with an error.

    One error tool-result block is produced per ID, in the order given.
    """

    error_results = []
    for pending_id in tool_use_ids:
        error_results.append(
            tool_result_block(
                tool_use_id=pending_id,
                model_text="error: tool call was interrupted",
                display_text="Tool call was interrupted",
                is_error=True,
            )
        )
    return build_message("user", error_results)
348
+
349
+
350
def load_image_block_payload(block: dict[str, Any]) -> tuple[str, str]:
    """Extract ``(mime_type, data)`` from a canonical image block.

    Raises:
        ValueError: If ``mime_type`` or ``data`` is missing, empty, or not a
            string.  ``mime_type`` is validated first.
    """

    mime_type = block.get("mime_type")
    if not (isinstance(mime_type, str) and mime_type):
        raise ValueError("image block is missing mime_type")

    data = block.get("data")
    if isinstance(data, str) and data:
        return mime_type, data
    raise ValueError("image block is missing data")
362
+
363
+
364
def load_document_block_payload(block: dict[str, Any]) -> tuple[str, str, str | None]:
    """Extract ``(mime_type, data, name)`` from a canonical document block.

    ``name`` is ``None`` unless the block holds a non-empty string for it.

    Raises:
        ValueError: If ``mime_type`` or ``data`` is missing, empty, or not a
            string.  ``mime_type`` is validated first.
    """

    mime_type = block.get("mime_type")
    if not (isinstance(mime_type, str) and mime_type):
        raise ValueError("document block is missing mime_type")

    data = block.get("data")
    if not (isinstance(data, str) and data):
        raise ValueError("document block is missing data")

    name = block.get("name")
    if not (isinstance(name, str) and name):
        name = None
    return mime_type, data, name
377
+
378
+
379
def tool_result_content_blocks(block: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the tool result's structured content, or a single text block.

    When ``block["content"]`` is a list, its dict items are shallow-copied and
    returned (non-dict items are ignored).  If that yields nothing, the result
    is one text block built from ``block["model_text"]`` (empty string when
    absent or falsy).
    """

    payload = block.get("content")
    structured: list[dict[str, Any]] = []
    if isinstance(payload, list):
        structured.extend(dict(entry) for entry in payload if isinstance(entry, dict))
    if not structured:
        structured.append(text_block(str(block.get("model_text") or "")))
    return structured