codemaster-cli 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codemaster_cli-2.2.0.dist-info/METADATA +645 -0
- codemaster_cli-2.2.0.dist-info/RECORD +170 -0
- codemaster_cli-2.2.0.dist-info/WHEEL +4 -0
- codemaster_cli-2.2.0.dist-info/entry_points.txt +3 -0
- vibe/__init__.py +6 -0
- vibe/acp/__init__.py +0 -0
- vibe/acp/acp_agent_loop.py +746 -0
- vibe/acp/entrypoint.py +81 -0
- vibe/acp/tools/__init__.py +0 -0
- vibe/acp/tools/base.py +100 -0
- vibe/acp/tools/builtins/bash.py +134 -0
- vibe/acp/tools/builtins/read_file.py +54 -0
- vibe/acp/tools/builtins/search_replace.py +129 -0
- vibe/acp/tools/builtins/todo.py +65 -0
- vibe/acp/tools/builtins/write_file.py +98 -0
- vibe/acp/tools/session_update.py +118 -0
- vibe/acp/utils.py +213 -0
- vibe/cli/__init__.py +0 -0
- vibe/cli/autocompletion/__init__.py +0 -0
- vibe/cli/autocompletion/base.py +22 -0
- vibe/cli/autocompletion/path_completion.py +177 -0
- vibe/cli/autocompletion/slash_command.py +99 -0
- vibe/cli/cli.py +188 -0
- vibe/cli/clipboard.py +69 -0
- vibe/cli/commands.py +116 -0
- vibe/cli/entrypoint.py +163 -0
- vibe/cli/history_manager.py +91 -0
- vibe/cli/plan_offer/adapters/http_whoami_gateway.py +67 -0
- vibe/cli/plan_offer/decide_plan_offer.py +87 -0
- vibe/cli/plan_offer/ports/whoami_gateway.py +23 -0
- vibe/cli/terminal_setup.py +323 -0
- vibe/cli/textual_ui/__init__.py +0 -0
- vibe/cli/textual_ui/ansi_markdown.py +58 -0
- vibe/cli/textual_ui/app.py +1546 -0
- vibe/cli/textual_ui/app.tcss +1020 -0
- vibe/cli/textual_ui/external_editor.py +32 -0
- vibe/cli/textual_ui/handlers/__init__.py +5 -0
- vibe/cli/textual_ui/handlers/event_handler.py +147 -0
- vibe/cli/textual_ui/widgets/__init__.py +0 -0
- vibe/cli/textual_ui/widgets/approval_app.py +192 -0
- vibe/cli/textual_ui/widgets/banner/banner.py +85 -0
- vibe/cli/textual_ui/widgets/banner/petit_chat.py +195 -0
- vibe/cli/textual_ui/widgets/braille_renderer.py +58 -0
- vibe/cli/textual_ui/widgets/chat_input/__init__.py +7 -0
- vibe/cli/textual_ui/widgets/chat_input/body.py +214 -0
- vibe/cli/textual_ui/widgets/chat_input/completion_manager.py +58 -0
- vibe/cli/textual_ui/widgets/chat_input/completion_popup.py +43 -0
- vibe/cli/textual_ui/widgets/chat_input/container.py +195 -0
- vibe/cli/textual_ui/widgets/chat_input/text_area.py +365 -0
- vibe/cli/textual_ui/widgets/compact.py +41 -0
- vibe/cli/textual_ui/widgets/config_app.py +171 -0
- vibe/cli/textual_ui/widgets/context_progress.py +30 -0
- vibe/cli/textual_ui/widgets/load_more.py +43 -0
- vibe/cli/textual_ui/widgets/loading.py +201 -0
- vibe/cli/textual_ui/widgets/messages.py +277 -0
- vibe/cli/textual_ui/widgets/no_markup_static.py +11 -0
- vibe/cli/textual_ui/widgets/path_display.py +28 -0
- vibe/cli/textual_ui/widgets/proxy_setup_app.py +127 -0
- vibe/cli/textual_ui/widgets/question_app.py +496 -0
- vibe/cli/textual_ui/widgets/spinner.py +194 -0
- vibe/cli/textual_ui/widgets/status_message.py +76 -0
- vibe/cli/textual_ui/widgets/teleport_message.py +31 -0
- vibe/cli/textual_ui/widgets/tool_widgets.py +371 -0
- vibe/cli/textual_ui/widgets/tools.py +201 -0
- vibe/cli/textual_ui/windowing/__init__.py +29 -0
- vibe/cli/textual_ui/windowing/history.py +105 -0
- vibe/cli/textual_ui/windowing/history_windowing.py +71 -0
- vibe/cli/textual_ui/windowing/state.py +105 -0
- vibe/cli/update_notifier/__init__.py +47 -0
- vibe/cli/update_notifier/adapters/filesystem_update_cache_repository.py +59 -0
- vibe/cli/update_notifier/adapters/github_update_gateway.py +101 -0
- vibe/cli/update_notifier/adapters/pypi_update_gateway.py +107 -0
- vibe/cli/update_notifier/ports/update_cache_repository.py +16 -0
- vibe/cli/update_notifier/ports/update_gateway.py +53 -0
- vibe/cli/update_notifier/update.py +139 -0
- vibe/cli/update_notifier/whats_new.py +49 -0
- vibe/core/__init__.py +5 -0
- vibe/core/agent_loop.py +1075 -0
- vibe/core/agents/__init__.py +31 -0
- vibe/core/agents/manager.py +165 -0
- vibe/core/agents/models.py +122 -0
- vibe/core/auth/__init__.py +6 -0
- vibe/core/auth/crypto.py +137 -0
- vibe/core/auth/github.py +178 -0
- vibe/core/autocompletion/__init__.py +0 -0
- vibe/core/autocompletion/completers.py +257 -0
- vibe/core/autocompletion/file_indexer/__init__.py +10 -0
- vibe/core/autocompletion/file_indexer/ignore_rules.py +156 -0
- vibe/core/autocompletion/file_indexer/indexer.py +179 -0
- vibe/core/autocompletion/file_indexer/store.py +169 -0
- vibe/core/autocompletion/file_indexer/watcher.py +71 -0
- vibe/core/autocompletion/fuzzy.py +189 -0
- vibe/core/autocompletion/path_prompt.py +108 -0
- vibe/core/autocompletion/path_prompt_adapter.py +149 -0
- vibe/core/config.py +673 -0
- vibe/core/config_PATCH_INSTRUCTIONS.md +77 -0
- vibe/core/llm/__init__.py +0 -0
- vibe/core/llm/backend/anthropic.py +630 -0
- vibe/core/llm/backend/base.py +38 -0
- vibe/core/llm/backend/factory.py +7 -0
- vibe/core/llm/backend/generic.py +425 -0
- vibe/core/llm/backend/mistral.py +381 -0
- vibe/core/llm/backend/vertex.py +115 -0
- vibe/core/llm/exceptions.py +195 -0
- vibe/core/llm/format.py +184 -0
- vibe/core/llm/message_utils.py +24 -0
- vibe/core/llm/types.py +120 -0
- vibe/core/middleware.py +209 -0
- vibe/core/output_formatters.py +85 -0
- vibe/core/paths/__init__.py +0 -0
- vibe/core/paths/config_paths.py +68 -0
- vibe/core/paths/global_paths.py +40 -0
- vibe/core/programmatic.py +56 -0
- vibe/core/prompts/__init__.py +32 -0
- vibe/core/prompts/cli.md +111 -0
- vibe/core/prompts/compact.md +48 -0
- vibe/core/prompts/dangerous_directory.md +5 -0
- vibe/core/prompts/explore.md +50 -0
- vibe/core/prompts/project_context.md +8 -0
- vibe/core/prompts/tests.md +1 -0
- vibe/core/proxy_setup.py +65 -0
- vibe/core/session/session_loader.py +222 -0
- vibe/core/session/session_logger.py +318 -0
- vibe/core/session/session_migration.py +41 -0
- vibe/core/skills/__init__.py +7 -0
- vibe/core/skills/manager.py +132 -0
- vibe/core/skills/models.py +92 -0
- vibe/core/skills/parser.py +39 -0
- vibe/core/system_prompt.py +466 -0
- vibe/core/telemetry/__init__.py +0 -0
- vibe/core/telemetry/send.py +185 -0
- vibe/core/teleport/errors.py +9 -0
- vibe/core/teleport/git.py +196 -0
- vibe/core/teleport/nuage.py +180 -0
- vibe/core/teleport/teleport.py +208 -0
- vibe/core/teleport/types.py +54 -0
- vibe/core/tools/base.py +336 -0
- vibe/core/tools/builtins/ask_user_question.py +134 -0
- vibe/core/tools/builtins/bash.py +357 -0
- vibe/core/tools/builtins/grep.py +310 -0
- vibe/core/tools/builtins/prompts/__init__.py +0 -0
- vibe/core/tools/builtins/prompts/ask_user_question.md +84 -0
- vibe/core/tools/builtins/prompts/bash.md +73 -0
- vibe/core/tools/builtins/prompts/grep.md +4 -0
- vibe/core/tools/builtins/prompts/read_file.md +13 -0
- vibe/core/tools/builtins/prompts/search_replace.md +43 -0
- vibe/core/tools/builtins/prompts/task.md +24 -0
- vibe/core/tools/builtins/prompts/todo.md +199 -0
- vibe/core/tools/builtins/prompts/write_file.md +42 -0
- vibe/core/tools/builtins/read_file.py +222 -0
- vibe/core/tools/builtins/search_replace.py +456 -0
- vibe/core/tools/builtins/task.py +154 -0
- vibe/core/tools/builtins/todo.py +134 -0
- vibe/core/tools/builtins/write_file.py +160 -0
- vibe/core/tools/manager.py +341 -0
- vibe/core/tools/mcp.py +397 -0
- vibe/core/tools/ui.py +68 -0
- vibe/core/trusted_folders.py +86 -0
- vibe/core/types.py +405 -0
- vibe/core/utils.py +396 -0
- vibe/setup/onboarding/__init__.py +39 -0
- vibe/setup/onboarding/base.py +14 -0
- vibe/setup/onboarding/onboarding.tcss +134 -0
- vibe/setup/onboarding/screens/__init__.py +5 -0
- vibe/setup/onboarding/screens/api_key.py +200 -0
- vibe/setup/onboarding/screens/provider_selection.py +87 -0
- vibe/setup/onboarding/screens/welcome.py +136 -0
- vibe/setup/trusted_folders/trust_folder_dialog.py +180 -0
- vibe/setup/trusted_folders/trust_folder_dialog.tcss +83 -0
- vibe/whats_new.md +5 -0
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import types
|
|
7
|
+
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
from vibe.core.llm.backend.anthropic import AnthropicAdapter
|
|
12
|
+
from vibe.core.llm.backend.base import APIAdapter, PreparedRequest
|
|
13
|
+
from vibe.core.llm.backend.vertex import VertexAnthropicAdapter
|
|
14
|
+
from vibe.core.llm.exceptions import BackendErrorBuilder
|
|
15
|
+
from vibe.core.llm.message_utils import merge_consecutive_user_messages
|
|
16
|
+
from vibe.core.types import (
|
|
17
|
+
AvailableTool,
|
|
18
|
+
LLMChunk,
|
|
19
|
+
LLMMessage,
|
|
20
|
+
LLMUsage,
|
|
21
|
+
Role,
|
|
22
|
+
StrToolChoice,
|
|
23
|
+
)
|
|
24
|
+
from vibe.core.utils import async_generator_retry, async_retry
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from vibe.core.config import ModelConfig, ProviderConfig
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class OpenAIAdapter(APIAdapter):
    """Adapter implementing the OpenAI ``/chat/completions`` wire format."""

    endpoint: ClassVar[str] = "/chat/completions"

    def build_payload(
        self,
        model_name: str,
        converted_messages: list[dict[str, Any]],
        temperature: float,
        tools: list[AvailableTool] | None,
        max_tokens: int | None,
        tool_choice: StrToolChoice | AvailableTool | None,
    ) -> dict[str, Any]:
        """Assemble the JSON request body for a chat-completions call."""
        body: dict[str, Any] = {
            "model": model_name,
            "messages": converted_messages,
            "temperature": temperature,
        }

        if tools:
            body["tools"] = [t.model_dump(exclude_none=True) for t in tools]
        if tool_choice:
            if isinstance(tool_choice, str):
                body["tool_choice"] = tool_choice
            else:
                body["tool_choice"] = tool_choice.model_dump()
        if max_tokens is not None:
            body["max_tokens"] = max_tokens

        return body

    def build_headers(self, api_key: str | None = None) -> dict[str, str]:
        """Return JSON content headers, adding bearer auth when a key is given."""
        if api_key:
            return {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
        return {"Content-Type": "application/json"}

    def _reasoning_to_api(
        self, msg_dict: dict[str, Any], field_name: str
    ) -> dict[str, Any]:
        """Rename our internal ``reasoning_content`` key to the provider's field (in place)."""
        if field_name != "reasoning_content" and "reasoning_content" in msg_dict:
            msg_dict[field_name] = msg_dict.pop("reasoning_content")
        return msg_dict

    def _reasoning_from_api(
        self, msg_dict: dict[str, Any], field_name: str
    ) -> dict[str, Any]:
        """Rename the provider's reasoning field back to ``reasoning_content`` (in place)."""
        if field_name != "reasoning_content" and field_name in msg_dict:
            msg_dict["reasoning_content"] = msg_dict.pop(field_name)
        return msg_dict

    def prepare_request(  # noqa: PLR0913
        self,
        *,
        model_name: str,
        messages: list[LLMMessage],
        temperature: float,
        tools: list[AvailableTool] | None,
        max_tokens: int | None,
        tool_choice: StrToolChoice | AvailableTool | None,
        enable_streaming: bool,
        provider: ProviderConfig,
        api_key: str | None = None,
        thinking: str = "off",
    ) -> PreparedRequest:
        """Serialize messages and options into a ready-to-send ``PreparedRequest``.

        Consecutive user messages are merged first, and the internal
        ``reasoning_content`` field is mapped to the provider-specific name.
        """
        history = merge_consecutive_user_messages(messages)
        reasoning_field = provider.reasoning_field_name

        wire_messages = []
        for msg in history:
            dumped = msg.model_dump(exclude_none=True, exclude={"message_id"})
            wire_messages.append(self._reasoning_to_api(dumped, reasoning_field))

        payload = self.build_payload(
            model_name, wire_messages, temperature, tools, max_tokens, tool_choice
        )

        if enable_streaming:
            payload["stream"] = True
            opts: dict[str, Any] = {"include_usage": True}
            # Mistral only streams tool calls when explicitly asked to.
            if provider.name == "mistral":
                opts["stream_tool_calls"] = True
            payload["stream_options"] = opts

        return PreparedRequest(
            self.endpoint,
            self.build_headers(api_key),
            json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        )

    def _parse_message(
        self, data: dict[str, Any], field_name: str
    ) -> LLMMessage | None:
        """Extract an ``LLMMessage`` from a full response or a stream chunk.

        Raises:
            ValueError: when ``choices`` is present but holds neither a
                ``message`` nor a ``delta`` entry.
        """
        choices = data.get("choices")
        if choices:
            choice = choices[0]
            for key in ("message", "delta"):
                if key in choice:
                    return LLMMessage.model_validate(
                        self._reasoning_from_api(choice[key], field_name)
                    )
            raise ValueError("Invalid response data: missing message or delta")

        # Some providers put the message/delta at the top level.
        for key in ("message", "delta"):
            if key in data:
                return LLMMessage.model_validate(
                    self._reasoning_from_api(data[key], field_name)
                )

        return None

    def parse_response(
        self, data: dict[str, Any], provider: ProviderConfig
    ) -> LLMChunk:
        """Turn one raw response (or stream chunk) into an ``LLMChunk``."""
        message = self._parse_message(data, provider.reasoning_field_name)
        if message is None:
            # Usage-only chunks still need a (empty) message to carry them.
            message = LLMMessage(role=Role.assistant, content="")

        usage_data = data.get("usage") or {}
        usage = LLMUsage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
        )

        return LLMChunk(message=message, usage=usage)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Registry mapping a provider's ``api_style`` to the adapter instance that
# speaks that wire format; looked up by GenericBackend at request time.
ADAPTERS: dict[str, APIAdapter] = {
    "openai": OpenAIAdapter(),
    "anthropic": AnthropicAdapter(),
    "vertex-anthropic": VertexAnthropicAdapter(),
}
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class GenericBackend:
    """HTTP backend for adapter-mapped chat-completion APIs.

    Owns (or borrows) an ``httpx.AsyncClient`` and delegates payload
    construction and response parsing to the adapter selected by the
    provider's ``api_style``.
    """

    def __init__(
        self,
        *,
        client: httpx.AsyncClient | None = None,
        provider: ProviderConfig,
        timeout: float = 720.0,
    ) -> None:
        """Initialize the backend.

        Args:
            client: Optional httpx client to use. If not provided, one will be created.
            provider: Provider configuration (API base, api style, key env var).
            timeout: Request timeout in seconds for a client we create ourselves.
        """
        self._client = client
        # Only close the client if we created it ourselves.
        self._owns_client = client is None
        self._provider = provider
        self._timeout = timeout

    async def __aenter__(self) -> GenericBackend:
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._timeout),
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
            )
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        if self._owns_client and self._client:
            await self._client.aclose()
            self._client = None

    def _get_client(self) -> httpx.AsyncClient:
        """Return the client, lazily creating one outside the context-manager path."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._timeout),
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
            )
            self._owns_client = True
        return self._client

    async def complete(
        self,
        *,
        model: ModelConfig,
        messages: list[LLMMessage],
        temperature: float = 0.2,
        tools: list[AvailableTool] | None = None,
        max_tokens: int | None = None,
        tool_choice: StrToolChoice | AvailableTool | None = None,
        extra_headers: dict[str, str] | None = None,
    ) -> LLMChunk:
        """Run one non-streaming completion and return the parsed chunk.

        Raises:
            Errors built by ``BackendErrorBuilder`` for HTTP status failures
            and transport-level request errors.
        """
        api_key = (
            os.getenv(self._provider.api_key_env_var)
            if self._provider.api_key_env_var
            else None
        )

        # Pick the adapter matching the provider's wire format.
        api_style = getattr(self._provider, "api_style", "openai")
        adapter = ADAPTERS[api_style]

        req = adapter.prepare_request(
            model_name=model.name,
            messages=messages,
            temperature=temperature,
            tools=tools,
            max_tokens=max_tokens,
            tool_choice=tool_choice,
            enable_streaming=False,
            provider=self._provider,
            api_key=api_key,
            thinking=model.thinking,
        )

        headers = req.headers
        if extra_headers:
            headers.update(extra_headers)

        # Adapters may pin their own base URL; otherwise use the provider's.
        base = req.base_url or self._provider.api_base
        url = f"{base}{req.endpoint}"

        try:
            res_data, _ = await self._make_request(url, req.body, headers)
            return adapter.parse_response(res_data, self._provider)

        except httpx.HTTPStatusError as e:
            raise BackendErrorBuilder.build_http_error(
                provider=self._provider.name,
                endpoint=url,
                response=e.response,
                headers=e.response.headers,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
        except httpx.RequestError as e:
            raise BackendErrorBuilder.build_request_error(
                provider=self._provider.name,
                endpoint=url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e

    async def complete_streaming(
        self,
        *,
        model: ModelConfig,
        messages: list[LLMMessage],
        temperature: float = 0.2,
        tools: list[AvailableTool] | None = None,
        max_tokens: int | None = None,
        tool_choice: StrToolChoice | AvailableTool | None = None,
        extra_headers: dict[str, str] | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Run a streaming completion, yielding one ``LLMChunk`` per SSE event.

        Raises:
            Errors built by ``BackendErrorBuilder`` for HTTP status failures
            and transport-level request errors.
        """
        api_key = (
            os.getenv(self._provider.api_key_env_var)
            if self._provider.api_key_env_var
            else None
        )

        api_style = getattr(self._provider, "api_style", "openai")
        adapter = ADAPTERS[api_style]

        req = adapter.prepare_request(
            model_name=model.name,
            messages=messages,
            temperature=temperature,
            tools=tools,
            max_tokens=max_tokens,
            tool_choice=tool_choice,
            enable_streaming=True,
            provider=self._provider,
            api_key=api_key,
            thinking=model.thinking,
        )

        headers = req.headers
        if extra_headers:
            headers.update(extra_headers)

        base = req.base_url or self._provider.api_base
        url = f"{base}{req.endpoint}"

        try:
            async for res_data in self._make_streaming_request(url, req.body, headers):
                yield adapter.parse_response(res_data, self._provider)

        except httpx.HTTPStatusError as e:
            raise BackendErrorBuilder.build_http_error(
                provider=self._provider.name,
                endpoint=url,
                response=e.response,
                headers=e.response.headers,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
        except httpx.RequestError as e:
            raise BackendErrorBuilder.build_request_error(
                provider=self._provider.name,
                endpoint=url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e

    class HTTPResponse(NamedTuple):
        # Parsed JSON body plus the response headers.
        data: dict[str, Any]
        headers: dict[str, str]

    @async_retry(tries=3)
    async def _make_request(
        self, url: str, data: bytes, headers: dict[str, str]
    ) -> HTTPResponse:
        """POST ``data`` to ``url`` and return the parsed JSON response."""
        client = self._get_client()
        response = await client.post(url, content=data, headers=headers)
        response.raise_for_status()

        response_headers = dict(response.headers.items())
        response_body = response.json()
        return self.HTTPResponse(response_body, response_headers)

    @async_generator_retry(tries=3)
    async def _make_streaming_request(
        self, url: str, data: bytes, headers: dict[str, str]
    ) -> AsyncGenerator[dict[str, Any]]:
        """POST ``data`` and yield each SSE ``data:`` event as parsed JSON.

        Terminates cleanly on the ``[DONE]`` sentinel.
        """
        client = self._get_client()
        async with client.stream(
            method="POST", url=url, content=data, headers=headers
        ) as response:
            if not response.is_success:
                # Body must be read before raise_for_status on a stream.
                await response.aread()
                response.raise_for_status()
            async for line in response.aiter_lines():
                if line.strip() == "":
                    continue

                # SSE frames are `field: value`. Per the SSE spec the space
                # after the colon is optional and a leading `:` marks a
                # comment/keep-alive line, so split on the first colon rather
                # than requiring the literal `": "` (which rejected valid
                # `data:{...}` frames and bare `:` keep-alives).
                delim_index = line.find(":")
                if delim_index == -1:
                    raise ValueError(
                        f"Stream chunk improperly formatted. "
                        f"Expected `key: value`, received `{line}`"
                    )
                key = line[0:delim_index]
                value = line[delim_index + 1 :].removeprefix(" ")

                if key != "data":
                    # This might be the case with openrouter, so we just ignore it
                    continue
                # Compare stripped, matching how the JSON payload is handled.
                if value.strip() == "[DONE]":
                    return
                yield json.loads(value.strip())

    async def count_tokens(
        self,
        *,
        model: ModelConfig,
        messages: list[LLMMessage],
        temperature: float = 0.0,
        tools: list[AvailableTool] | None = None,
        tool_choice: StrToolChoice | AvailableTool | None = None,
        extra_headers: dict[str, str] | None = None,
    ) -> int:
        """Count prompt tokens by issuing a minimal probe completion.

        Raises:
            ValueError: if the provider returned no usage information.
        """
        probe_messages = list(messages)
        # Some providers reject a conversation that does not end with a user turn.
        if not probe_messages or probe_messages[-1].role != Role.user:
            probe_messages.append(LLMMessage(role=Role.user, content=""))

        result = await self.complete(
            model=model,
            messages=probe_messages,
            temperature=temperature,
            tools=tools,
            max_tokens=16,  # Minimal amount for openrouter with openai models
            tool_choice=tool_choice,
            extra_headers=extra_headers,
        )
        if result.usage is None:
            raise ValueError("Missing usage in non streaming completion")

        return result.usage.prompt_tokens

    async def close(self) -> None:
        """Close the client if we own it (mirror of ``__aexit__``)."""
        if self._owns_client and self._client:
            await self._client.aclose()
            self._client = None
|