tracia 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracia/__init__.py +152 -3
- tracia/_client.py +1100 -0
- tracia/_constants.py +39 -0
- tracia/_errors.py +87 -0
- tracia/_http.py +362 -0
- tracia/_llm.py +898 -0
- tracia/_session.py +244 -0
- tracia/_streaming.py +135 -0
- tracia/_types.py +564 -0
- tracia/_utils.py +116 -0
- tracia/py.typed +0 -0
- tracia/resources/__init__.py +6 -0
- tracia/resources/prompts.py +273 -0
- tracia/resources/spans.py +227 -0
- tracia-0.1.1.dist-info/METADATA +277 -0
- tracia-0.1.1.dist-info/RECORD +18 -0
- tracia-0.0.1.dist-info/METADATA +0 -52
- tracia-0.0.1.dist-info/RECORD +0 -5
- {tracia-0.0.1.dist-info → tracia-0.1.1.dist-info}/WHEEL +0 -0
- {tracia-0.0.1.dist-info → tracia-0.1.1.dist-info}/licenses/LICENSE +0 -0
tracia/_llm.py
ADDED
@@ -0,0 +1,898 @@
"""LiteLLM wrapper for unified LLM access."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator
|
|
9
|
+
|
|
10
|
+
from ._constants import ENV_VAR_MAP
|
|
11
|
+
from ._errors import TraciaError, TraciaErrorCode, sanitize_error_message
|
|
12
|
+
from ._types import (
|
|
13
|
+
ContentPart,
|
|
14
|
+
FinishReason,
|
|
15
|
+
LLMProvider,
|
|
16
|
+
LocalPromptMessage,
|
|
17
|
+
TextPart,
|
|
18
|
+
ToolCall,
|
|
19
|
+
ToolCallPart,
|
|
20
|
+
ToolChoice,
|
|
21
|
+
ToolDefinition,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from litellm import ModelResponse
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
|
|
29
|
+
class CompletionResult:
|
|
30
|
+
"""Result from an LLM completion."""
|
|
31
|
+
|
|
32
|
+
text: str
|
|
33
|
+
input_tokens: int
|
|
34
|
+
output_tokens: int
|
|
35
|
+
total_tokens: int
|
|
36
|
+
tool_calls: list[ToolCall]
|
|
37
|
+
finish_reason: FinishReason
|
|
38
|
+
provider: LLMProvider
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Model to provider mapping for common models
|
|
42
|
+
_MODEL_PROVIDER_MAP: dict[str, LLMProvider] = {
|
|
43
|
+
# OpenAI
|
|
44
|
+
"gpt-3.5-turbo": LLMProvider.OPENAI,
|
|
45
|
+
"gpt-4": LLMProvider.OPENAI,
|
|
46
|
+
"gpt-4-turbo": LLMProvider.OPENAI,
|
|
47
|
+
"gpt-4o": LLMProvider.OPENAI,
|
|
48
|
+
"gpt-4o-mini": LLMProvider.OPENAI,
|
|
49
|
+
"gpt-4.1": LLMProvider.OPENAI,
|
|
50
|
+
"gpt-4.1-mini": LLMProvider.OPENAI,
|
|
51
|
+
"gpt-4.1-nano": LLMProvider.OPENAI,
|
|
52
|
+
"gpt-4.5-preview": LLMProvider.OPENAI,
|
|
53
|
+
"gpt-5": LLMProvider.OPENAI,
|
|
54
|
+
"o1": LLMProvider.OPENAI,
|
|
55
|
+
"o1-mini": LLMProvider.OPENAI,
|
|
56
|
+
"o1-preview": LLMProvider.OPENAI,
|
|
57
|
+
"o3": LLMProvider.OPENAI,
|
|
58
|
+
"o3-mini": LLMProvider.OPENAI,
|
|
59
|
+
"o4-mini": LLMProvider.OPENAI,
|
|
60
|
+
# Anthropic
|
|
61
|
+
"claude-3-haiku-20240307": LLMProvider.ANTHROPIC,
|
|
62
|
+
"claude-3-sonnet-20240229": LLMProvider.ANTHROPIC,
|
|
63
|
+
"claude-3-opus-20240229": LLMProvider.ANTHROPIC,
|
|
64
|
+
"claude-3-5-haiku-20241022": LLMProvider.ANTHROPIC,
|
|
65
|
+
"claude-3-5-sonnet-20241022": LLMProvider.ANTHROPIC,
|
|
66
|
+
"claude-sonnet-4-20250514": LLMProvider.ANTHROPIC,
|
|
67
|
+
"claude-opus-4-20250514": LLMProvider.ANTHROPIC,
|
|
68
|
+
# Google
|
|
69
|
+
"gemini-2.0-flash": LLMProvider.GOOGLE,
|
|
70
|
+
"gemini-2.0-flash-lite": LLMProvider.GOOGLE,
|
|
71
|
+
"gemini-2.5-pro": LLMProvider.GOOGLE,
|
|
72
|
+
"gemini-2.5-flash": LLMProvider.GOOGLE,
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def resolve_provider(model: str, explicit_provider: LLMProvider | None) -> LLMProvider:
    """Resolve the provider for a model.

    Args:
        model: The model name.
        explicit_provider: Explicitly specified provider.

    Returns:
        The resolved provider.

    Raises:
        TraciaError: If the provider cannot be determined.
    """
    if explicit_provider is not None:
        return explicit_provider

    # Check the model registry
    if model in _MODEL_PROVIDER_MAP:
        return _MODEL_PROVIDER_MAP[model]

    # Try prefix-based detection
    if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3") or model.startswith("o4"):
        return LLMProvider.OPENAI
    if model.startswith("claude-"):
        return LLMProvider.ANTHROPIC
    if model.startswith("gemini-"):
        return LLMProvider.GOOGLE

    raise TraciaError(
        code=TraciaErrorCode.UNSUPPORTED_MODEL,
        message=f"Cannot determine provider for model '{model}'. Please specify the provider explicitly.",
    )

def get_litellm_model(model: str, provider: LLMProvider) -> str:
    """Get the litellm-compatible model name.

    LiteLLM requires a ``gemini/`` prefix to route Google AI Studio models
    correctly. Without it, litellm defaults to the Vertex AI path which
    requires Application Default Credentials instead of an API key.

    Args:
        model: The user-facing model name (e.g. ``gemini-2.0-flash``).
        provider: The resolved provider.

    Returns:
        The model string suitable for ``litellm.completion()``.
    """
    if provider == LLMProvider.GOOGLE and not model.startswith("gemini/"):
        return f"gemini/{model}"
    return model

def get_provider_api_key(
    provider: LLMProvider,
    provider_api_key: str | None = None,
) -> str:
    """Get the API key for a provider.

    Args:
        provider: The LLM provider.
        provider_api_key: Explicitly provided API key.

    Returns:
        The API key.

    Raises:
        TraciaError: If no API key is found.
    """
    if provider_api_key:
        return provider_api_key

    env_var = ENV_VAR_MAP.get(provider.value)
    if env_var:
        key = os.environ.get(env_var)
        if key:
            return key

    raise TraciaError(
        code=TraciaErrorCode.MISSING_PROVIDER_API_KEY,
        message=f"No API key found for provider '{provider.value}'. "
        f"Set the {ENV_VAR_MAP.get(provider.value, 'PROVIDER_API_KEY')} environment variable "
        "or pass provider_api_key parameter.",
    )

def convert_messages(
    messages: list[LocalPromptMessage],
) -> list[dict[str, Any]]:
    """Convert Tracia messages to LiteLLM/OpenAI format.

    Args:
        messages: The Tracia messages.

    Returns:
        Messages in LiteLLM format.
    """
    result: list[dict[str, Any]] = []

    for msg in messages:
        # Handle tool role
        if msg.role == "tool":
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
            result.append({
                "role": "tool",
                "tool_call_id": msg.tool_call_id,
                "content": content,
            })
            continue

        # Handle developer role (map to system)
        role = "system" if msg.role == "developer" else msg.role

        # Handle string content
        if isinstance(msg.content, str):
            result.append({"role": role, "content": msg.content})
            continue

        # Handle list content (text parts and tool calls)
        content_parts: list[dict[str, Any]] = []
        tool_calls: list[dict[str, Any]] = []

        for part in msg.content:
            if isinstance(part, TextPart) or (isinstance(part, dict) and part.get("type") == "text"):
                text = part.text if isinstance(part, TextPart) else part.get("text", "")
                content_parts.append({"type": "text", "text": text})
            elif isinstance(part, ToolCallPart) or (isinstance(part, dict) and part.get("type") == "tool_call"):
                if isinstance(part, ToolCallPart):
                    tc_id = part.id
                    tc_name = part.name
                    tc_args = part.arguments
                else:
                    tc_id = part.get("id", "")
                    tc_name = part.get("name", "")
                    tc_args = part.get("arguments", {})

                tool_calls.append({
                    "id": tc_id,
                    "type": "function",
                    "function": {
                        "name": tc_name,
                        "arguments": json.dumps(tc_args) if isinstance(tc_args, dict) else tc_args,
                    },
                })

        # Build the message
        msg_dict: dict[str, Any] = {"role": role}

        if content_parts:
            # If only text parts, we can simplify
            if len(content_parts) == 1 and not tool_calls:
                msg_dict["content"] = content_parts[0]["text"]
            else:
                msg_dict["content"] = content_parts
        elif not tool_calls:
            msg_dict["content"] = ""

        if tool_calls:
            msg_dict["tool_calls"] = tool_calls

        result.append(msg_dict)

    return result

def convert_tools(tools: list[ToolDefinition] | None) -> list[dict[str, Any]] | None:
    """Convert tool definitions to LiteLLM format.

    Args:
        tools: The tool definitions.

    Returns:
        Tools in LiteLLM format.
    """
    if not tools:
        return None

    result = []
    for tool in tools:
        result.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters.model_dump(exclude_none=True),
            },
        })
    return result

def convert_tool_choice(tool_choice: ToolChoice | None) -> str | dict[str, Any] | None:
    """Convert tool choice to LiteLLM format.

    Args:
        tool_choice: The tool choice.

    Returns:
        Tool choice in LiteLLM format.
    """
    if tool_choice is None:
        return None

    if isinstance(tool_choice, str):
        return tool_choice

    if isinstance(tool_choice, dict) and "tool" in tool_choice:
        return {"type": "function", "function": {"name": tool_choice["tool"]}}

    return None

def parse_finish_reason(reason: str | None) -> FinishReason:
    """Parse the finish reason from LiteLLM response.

    Args:
        reason: The raw finish reason.

    Returns:
        The normalized finish reason.
    """
    if reason == "tool_calls":
        return "tool_calls"
    if reason == "length":
        return "max_tokens"
    return "stop"

def extract_tool_calls(response: "ModelResponse") -> list[ToolCall]:
    """Extract tool calls from a LiteLLM response.

    Args:
        response: The LiteLLM response.

    Returns:
        The extracted tool calls.
    """
    tool_calls: list[ToolCall] = []

    choices = getattr(response, "choices", [])
    if not choices:
        return tool_calls

    message = getattr(choices[0], "message", None)
    if not message:
        return tool_calls

    raw_tool_calls = getattr(message, "tool_calls", None)
    if not raw_tool_calls:
        return tool_calls

    for tc in raw_tool_calls:
        func = getattr(tc, "function", None)
        if func:
            try:
                args = json.loads(func.arguments) if isinstance(func.arguments, str) else func.arguments
            except json.JSONDecodeError:
                args = {}

            tool_calls.append(ToolCall(
                id=tc.id,
                name=func.name,
                arguments=args,
            ))

    return tool_calls

class LLMClient:
    """Client for making LLM calls via LiteLLM."""

    def complete(
        self,
        model: str,
        messages: list[LocalPromptMessage],
        *,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        stop: list[str] | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        api_key: str | None = None,
        timeout: float | None = None,
    ) -> CompletionResult:
        """Make a synchronous completion request.

        Args:
            model: The model name.
            messages: The messages to send.
            provider: The LLM provider.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.
            top_p: Top-p sampling.
            stop: Stop sequences.
            tools: Tool definitions.
            tool_choice: Tool choice setting.
            api_key: Provider API key.
            timeout: Request timeout in seconds.

        Returns:
            The completion result.

        Raises:
            TraciaError: If the request fails.
        """
        try:
            import litellm
        except ImportError as e:
            raise TraciaError(
                code=TraciaErrorCode.MISSING_PROVIDER_SDK,
                message="litellm is not installed. Install it with: pip install litellm",
            ) from e

        resolved_provider = resolve_provider(model, provider)
        resolved_api_key = get_provider_api_key(resolved_provider, api_key)

        # Build the request
        litellm_messages = convert_messages(messages)
        litellm_tools = convert_tools(tools)
        litellm_tool_choice = convert_tool_choice(tool_choice)

        request_kwargs: dict[str, Any] = {
            "model": get_litellm_model(model, resolved_provider),
            "messages": litellm_messages,
            "api_key": resolved_api_key,
        }

        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if max_tokens is not None:
            request_kwargs["max_tokens"] = max_tokens
        if top_p is not None:
            request_kwargs["top_p"] = top_p
        if stop is not None:
            request_kwargs["stop"] = stop
        if litellm_tools is not None:
            request_kwargs["tools"] = litellm_tools
        if litellm_tool_choice is not None:
            request_kwargs["tool_choice"] = litellm_tool_choice
        if timeout is not None:
            request_kwargs["timeout"] = timeout

        try:
            response = litellm.completion(**request_kwargs)
        except Exception as e:
            error_msg = sanitize_error_message(str(e))
            raise TraciaError(
                code=TraciaErrorCode.PROVIDER_ERROR,
                message=f"LLM provider error: {error_msg}",
            ) from e

        # Extract result
        usage = getattr(response, "usage", None)
        choices = getattr(response, "choices", [])
        message = choices[0].message if choices else None
        content = getattr(message, "content", "") or ""
        finish_reason = choices[0].finish_reason if choices else "stop"

        return CompletionResult(
            text=content,
            input_tokens=getattr(usage, "prompt_tokens", 0) if usage else 0,
            output_tokens=getattr(usage, "completion_tokens", 0) if usage else 0,
            total_tokens=getattr(usage, "total_tokens", 0) if usage else 0,
            tool_calls=extract_tool_calls(response),
            finish_reason=parse_finish_reason(finish_reason),
            provider=resolved_provider,
        )

    async def acomplete(
        self,
        model: str,
        messages: list[LocalPromptMessage],
        *,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        stop: list[str] | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        api_key: str | None = None,
        timeout: float | None = None,
    ) -> CompletionResult:
        """Make an asynchronous completion request.

        Args:
            model: The model name.
            messages: The messages to send.
            provider: The LLM provider.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.
            top_p: Top-p sampling.
            stop: Stop sequences.
            tools: Tool definitions.
            tool_choice: Tool choice setting.
            api_key: Provider API key.
            timeout: Request timeout in seconds.

        Returns:
            The completion result.

        Raises:
            TraciaError: If the request fails.
        """
        try:
            import litellm
        except ImportError as e:
            raise TraciaError(
                code=TraciaErrorCode.MISSING_PROVIDER_SDK,
                message="litellm is not installed. Install it with: pip install litellm",
            ) from e

        resolved_provider = resolve_provider(model, provider)
        resolved_api_key = get_provider_api_key(resolved_provider, api_key)

        # Build the request
        litellm_messages = convert_messages(messages)
        litellm_tools = convert_tools(tools)
        litellm_tool_choice = convert_tool_choice(tool_choice)

        request_kwargs: dict[str, Any] = {
            "model": get_litellm_model(model, resolved_provider),
            "messages": litellm_messages,
            "api_key": resolved_api_key,
        }

        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if max_tokens is not None:
            request_kwargs["max_tokens"] = max_tokens
        if top_p is not None:
            request_kwargs["top_p"] = top_p
        if stop is not None:
            request_kwargs["stop"] = stop
        if litellm_tools is not None:
            request_kwargs["tools"] = litellm_tools
        if litellm_tool_choice is not None:
            request_kwargs["tool_choice"] = litellm_tool_choice
        if timeout is not None:
            request_kwargs["timeout"] = timeout

        try:
            response = await litellm.acompletion(**request_kwargs)
        except Exception as e:
            error_msg = sanitize_error_message(str(e))
            raise TraciaError(
                code=TraciaErrorCode.PROVIDER_ERROR,
                message=f"LLM provider error: {error_msg}",
            ) from e

        # Extract result
        usage = getattr(response, "usage", None)
        choices = getattr(response, "choices", [])
        message = choices[0].message if choices else None
        content = getattr(message, "content", "") or ""
        finish_reason = choices[0].finish_reason if choices else "stop"

        return CompletionResult(
            text=content,
            input_tokens=getattr(usage, "prompt_tokens", 0) if usage else 0,
            output_tokens=getattr(usage, "completion_tokens", 0) if usage else 0,
            total_tokens=getattr(usage, "total_tokens", 0) if usage else 0,
            tool_calls=extract_tool_calls(response),
            finish_reason=parse_finish_reason(finish_reason),
            provider=resolved_provider,
        )

    def stream(
        self,
        model: str,
        messages: list[LocalPromptMessage],
        *,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        stop: list[str] | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        api_key: str | None = None,
        timeout: float | None = None,
    ) -> tuple[Iterator[str], list[CompletionResult], LLMProvider]:
        """Make a streaming completion request.

        Args:
            model: The model name.
            messages: The messages to send.
            provider: The LLM provider.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.
            top_p: Top-p sampling.
            stop: Stop sequences.
            tools: Tool definitions.
            tool_choice: Tool choice setting.
            api_key: Provider API key.
            timeout: Request timeout in seconds.

        Returns:
            A tuple of (chunk iterator, result holder list, provider).

        Raises:
            TraciaError: If the request fails.
        """
        try:
            import litellm
        except ImportError as e:
            raise TraciaError(
                code=TraciaErrorCode.MISSING_PROVIDER_SDK,
                message="litellm is not installed. Install it with: pip install litellm",
            ) from e

        resolved_provider = resolve_provider(model, provider)
        resolved_api_key = get_provider_api_key(resolved_provider, api_key)

        # Build the request
        litellm_messages = convert_messages(messages)
        litellm_tools = convert_tools(tools)
        litellm_tool_choice = convert_tool_choice(tool_choice)

        request_kwargs: dict[str, Any] = {
            "model": get_litellm_model(model, resolved_provider),
            "messages": litellm_messages,
            "api_key": resolved_api_key,
            "stream": True,
        }

        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if max_tokens is not None:
            request_kwargs["max_tokens"] = max_tokens
        if top_p is not None:
            request_kwargs["top_p"] = top_p
        if stop is not None:
            request_kwargs["stop"] = stop
        if litellm_tools is not None:
            request_kwargs["tools"] = litellm_tools
        if litellm_tool_choice is not None:
            request_kwargs["tool_choice"] = litellm_tool_choice
        if timeout is not None:
            request_kwargs["timeout"] = timeout

        result_holder: list[CompletionResult] = []

        def generate_chunks() -> Iterator[str]:
            full_text = ""
            input_tokens = 0
            output_tokens = 0
            total_tokens = 0
            tool_calls: list[ToolCall] = []
            finish_reason: FinishReason = "stop"
            tool_call_chunks: dict[int, dict[str, Any]] = {}

            try:
                response = litellm.completion(**request_kwargs)

                for chunk in response:
                    choices = getattr(chunk, "choices", [])
                    if not choices:
                        continue

                    delta = getattr(choices[0], "delta", None)
                    if delta:
                        content = getattr(delta, "content", None)
                        if content:
                            full_text += content
                            yield content

                        # Handle streaming tool calls
                        delta_tool_calls = getattr(delta, "tool_calls", None)
                        if delta_tool_calls:
                            for tc in delta_tool_calls:
                                idx = tc.index
                                if idx not in tool_call_chunks:
                                    tool_call_chunks[idx] = {
                                        "id": "",
                                        "name": "",
                                        "arguments": "",
                                    }
                                if tc.id:
                                    tool_call_chunks[idx]["id"] = tc.id
                                if tc.function:
                                    if tc.function.name:
                                        tool_call_chunks[idx]["name"] = tc.function.name
                                    if tc.function.arguments:
                                        tool_call_chunks[idx]["arguments"] += tc.function.arguments

                    chunk_finish = getattr(choices[0], "finish_reason", None)
                    if chunk_finish:
                        finish_reason = parse_finish_reason(chunk_finish)

                    # Extract usage from final chunk
                    usage = getattr(chunk, "usage", None)
                    if usage:
                        input_tokens = getattr(usage, "prompt_tokens", 0)
                        output_tokens = getattr(usage, "completion_tokens", 0)
                        total_tokens = getattr(usage, "total_tokens", 0)

                # Convert accumulated tool calls
                for idx in sorted(tool_call_chunks.keys()):
                    tc_data = tool_call_chunks[idx]
                    try:
                        args = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {}
                    except json.JSONDecodeError:
                        args = {}
                    tool_calls.append(ToolCall(
                        id=tc_data["id"],
                        name=tc_data["name"],
                        arguments=args,
                    ))

                result_holder.append(CompletionResult(
                    text=full_text,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=total_tokens,
                    tool_calls=tool_calls,
                    finish_reason=finish_reason,
                    provider=resolved_provider,
                ))

            except Exception as e:
                error_msg = sanitize_error_message(str(e))
                raise TraciaError(
                    code=TraciaErrorCode.PROVIDER_ERROR,
                    message=f"LLM provider error: {error_msg}",
                ) from e

        return generate_chunks(), result_holder, resolved_provider

    async def astream(
        self,
        model: str,
        messages: list[LocalPromptMessage],
        *,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        stop: list[str] | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        api_key: str | None = None,
        timeout: float | None = None,
    ) -> tuple[AsyncIterator[str], list[CompletionResult], LLMProvider]:
        """Make an async streaming completion request.

        Args:
            model: The model name.
            messages: The messages to send.
            provider: The LLM provider.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.
            top_p: Top-p sampling.
            stop: Stop sequences.
            tools: Tool definitions.
            tool_choice: Tool choice setting.
            api_key: Provider API key.
            timeout: Request timeout in seconds.

        Returns:
            A tuple of (async chunk iterator, result holder list, provider).

        Raises:
            TraciaError: If the request fails.
        """
        try:
            import litellm
        except ImportError as e:
            raise TraciaError(
                code=TraciaErrorCode.MISSING_PROVIDER_SDK,
                message="litellm is not installed. Install it with: pip install litellm",
            ) from e

        resolved_provider = resolve_provider(model, provider)
        resolved_api_key = get_provider_api_key(resolved_provider, api_key)

        # Build the request
        litellm_messages = convert_messages(messages)
        litellm_tools = convert_tools(tools)
        litellm_tool_choice = convert_tool_choice(tool_choice)

        request_kwargs: dict[str, Any] = {
            "model": get_litellm_model(model, resolved_provider),
            "messages": litellm_messages,
            "api_key": resolved_api_key,
            "stream": True,
        }

        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if max_tokens is not None:
            request_kwargs["max_tokens"] = max_tokens
        if top_p is not None:
            request_kwargs["top_p"] = top_p
        if stop is not None:
            request_kwargs["stop"] = stop
        if litellm_tools is not None:
            request_kwargs["tools"] = litellm_tools
        if litellm_tool_choice is not None:
            request_kwargs["tool_choice"] = litellm_tool_choice
        if timeout is not None:
            request_kwargs["timeout"] = timeout

        result_holder: list[CompletionResult] = []

        async def generate_chunks() -> AsyncIterator[str]:
            full_text = ""
            input_tokens = 0
            output_tokens = 0
            total_tokens = 0
            tool_calls: list[ToolCall] = []
            finish_reason: FinishReason = "stop"
            tool_call_chunks: dict[int, dict[str, Any]] = {}

            try:
                response = await litellm.acompletion(**request_kwargs)

                async for chunk in response:
                    choices = getattr(chunk, "choices", [])
                    if not choices:
                        continue

                    delta = getattr(choices[0], "delta", None)
                    if delta:
                        content = getattr(delta, "content", None)
                        if content:
                            full_text += content
                            yield content

                        # Handle streaming tool calls
                        delta_tool_calls = getattr(delta, "tool_calls", None)
                        if delta_tool_calls:
                            for tc in delta_tool_calls:
                                idx = tc.index
                                if idx not in tool_call_chunks:
                                    tool_call_chunks[idx] = {
                                        "id": "",
                                        "name": "",
                                        "arguments": "",
                                    }
                                if tc.id:
                                    tool_call_chunks[idx]["id"] = tc.id
                                if tc.function:
                                    if tc.function.name:
                                        tool_call_chunks[idx]["name"] = tc.function.name
                                    if tc.function.arguments:
                                        tool_call_chunks[idx]["arguments"] += tc.function.arguments

                    chunk_finish = getattr(choices[0], "finish_reason", None)
                    if chunk_finish:
                        finish_reason = parse_finish_reason(chunk_finish)

                    # Extract usage from final chunk
                    usage = getattr(chunk, "usage", None)
                    if usage:
                        input_tokens = getattr(usage, "prompt_tokens", 0)
                        output_tokens = getattr(usage, "completion_tokens", 0)
                        total_tokens = getattr(usage, "total_tokens", 0)

                # Convert accumulated tool calls
                for idx in sorted(tool_call_chunks.keys()):
                    tc_data = tool_call_chunks[idx]
                    try:
                        args = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {}
                    except json.JSONDecodeError:
                        args = {}
                    tool_calls.append(ToolCall(
                        id=tc_data["id"],
                        name=tc_data["name"],
                        arguments=args,
                    ))

                result_holder.append(CompletionResult(
                    text=full_text,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=total_tokens,
                    tool_calls=tool_calls,
                    finish_reason=finish_reason,
                    provider=resolved_provider,
                ))

            except Exception as e:
                error_msg = sanitize_error_message(str(e))
                raise TraciaError(
                    code=TraciaErrorCode.PROVIDER_ERROR,
                    message=f"LLM provider error: {error_msg}",
                ) from e

        return generate_chunks(), result_holder, resolved_provider

def build_assistant_message(
    text: str,
    tool_calls: list[ToolCall],
) -> LocalPromptMessage:
    """Build an assistant message from completion result.

    Args:
        text: The text content.
        tool_calls: Any tool calls made.

    Returns:
        The assistant message.
    """
    if not tool_calls:
        return LocalPromptMessage(role="assistant", content=text)

    content: list[ContentPart] = []

    if text:
        content.append(TextPart(type="text", text=text))

    for tc in tool_calls:
        content.append(ToolCallPart(
            type="tool_call",
            id=tc.id,
            name=tc.name,
            arguments=tc.arguments,
        ))

    return LocalPromptMessage(role="assistant", content=content)
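
Not part of the diff above: a minimal synchronous usage sketch inferred from the `LLMClient.complete` signature in `_llm.py`. It assumes `litellm` is installed and `OPENAI_API_KEY` is set, and it imports from the private modules shown in this diff; whatever `tracia/__init__.py` re-exports publicly is not visible in this hunk.

# Illustrative sketch only; not part of the packaged code.
from tracia._llm import LLMClient
from tracia._types import LocalPromptMessage

client = LLMClient()
result = client.complete(
    "gpt-4o-mini",
    [LocalPromptMessage(role="user", content="Say hello in one word.")],
    temperature=0.2,
    max_tokens=16,
)
# CompletionResult exposes text, token counts, tool_calls, finish_reason, provider.
print(result.text, result.total_tokens, result.provider)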
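
A streaming sketch under the same assumptions: `stream()` returns a chunk iterator, a result-holder list, and the resolved provider, and the final `CompletionResult` is appended to the holder only after the iterator has been fully consumed.

# Illustrative sketch only; reuses LLMClient and LocalPromptMessage from the sketch above.
chunks, result_holder, provider = LLMClient().stream(
    "gpt-4o-mini",
    [LocalPromptMessage(role="user", content="Count to three.")],
)
for piece in chunks:
    print(piece, end="", flush=True)

final = result_holder[0]  # populated once the generator has finished
print("\n", provider, final.finish_reason, final.output_tokens)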