tracia 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.
tracia/_llm.py ADDED
@@ -0,0 +1,898 @@
1
+ """LiteLLM wrapper for unified LLM access."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator
9
+
10
+ from ._constants import ENV_VAR_MAP
11
+ from ._errors import TraciaError, TraciaErrorCode, sanitize_error_message
12
+ from ._types import (
13
+ ContentPart,
14
+ FinishReason,
15
+ LLMProvider,
16
+ LocalPromptMessage,
17
+ TextPart,
18
+ ToolCall,
19
+ ToolCallPart,
20
+ ToolChoice,
21
+ ToolDefinition,
22
+ )
23
+
24
+ if TYPE_CHECKING:
25
+ from litellm import ModelResponse
26
+
27
+
28
+ @dataclass
29
+ class CompletionResult:
30
+ """Result from an LLM completion."""
31
+
32
+ text: str
33
+ input_tokens: int
34
+ output_tokens: int
35
+ total_tokens: int
36
+ tool_calls: list[ToolCall]
37
+ finish_reason: FinishReason
38
+ provider: LLMProvider
39
+
40
+
41
+ # Model to provider mapping for common models
42
+ _MODEL_PROVIDER_MAP: dict[str, LLMProvider] = {
43
+ # OpenAI
44
+ "gpt-3.5-turbo": LLMProvider.OPENAI,
45
+ "gpt-4": LLMProvider.OPENAI,
46
+ "gpt-4-turbo": LLMProvider.OPENAI,
47
+ "gpt-4o": LLMProvider.OPENAI,
48
+ "gpt-4o-mini": LLMProvider.OPENAI,
49
+ "gpt-4.1": LLMProvider.OPENAI,
50
+ "gpt-4.1-mini": LLMProvider.OPENAI,
51
+ "gpt-4.1-nano": LLMProvider.OPENAI,
52
+ "gpt-4.5-preview": LLMProvider.OPENAI,
53
+ "gpt-5": LLMProvider.OPENAI,
54
+ "o1": LLMProvider.OPENAI,
55
+ "o1-mini": LLMProvider.OPENAI,
56
+ "o1-preview": LLMProvider.OPENAI,
57
+ "o3": LLMProvider.OPENAI,
58
+ "o3-mini": LLMProvider.OPENAI,
59
+ "o4-mini": LLMProvider.OPENAI,
60
+ # Anthropic
61
+ "claude-3-haiku-20240307": LLMProvider.ANTHROPIC,
62
+ "claude-3-sonnet-20240229": LLMProvider.ANTHROPIC,
63
+ "claude-3-opus-20240229": LLMProvider.ANTHROPIC,
64
+ "claude-3-5-haiku-20241022": LLMProvider.ANTHROPIC,
65
+ "claude-3-5-sonnet-20241022": LLMProvider.ANTHROPIC,
66
+ "claude-sonnet-4-20250514": LLMProvider.ANTHROPIC,
67
+ "claude-opus-4-20250514": LLMProvider.ANTHROPIC,
68
+ # Google
69
+ "gemini-2.0-flash": LLMProvider.GOOGLE,
70
+ "gemini-2.0-flash-lite": LLMProvider.GOOGLE,
71
+ "gemini-2.5-pro": LLMProvider.GOOGLE,
72
+ "gemini-2.5-flash": LLMProvider.GOOGLE,
73
+ }
74
+
75
+
76
+ def resolve_provider(model: str, explicit_provider: LLMProvider | None) -> LLMProvider:
77
+ """Resolve the provider for a model.
78
+
79
+ Args:
80
+ model: The model name.
81
+ explicit_provider: Explicitly specified provider.
82
+
83
+ Returns:
84
+ The resolved provider.
85
+
86
+ Raises:
87
+ TraciaError: If the provider cannot be determined.
88
+ """
89
+ if explicit_provider is not None:
90
+ return explicit_provider
91
+
92
+ # Check the model registry
93
+ if model in _MODEL_PROVIDER_MAP:
94
+ return _MODEL_PROVIDER_MAP[model]
95
+
96
+ # Try prefix-based detection
97
+ if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3") or model.startswith("o4"):
98
+ return LLMProvider.OPENAI
99
+ if model.startswith("claude-"):
100
+ return LLMProvider.ANTHROPIC
101
+ if model.startswith("gemini-"):
102
+ return LLMProvider.GOOGLE
103
+
104
+ raise TraciaError(
105
+ code=TraciaErrorCode.UNSUPPORTED_MODEL,
106
+ message=f"Cannot determine provider for model '{model}'. Please specify the provider explicitly.",
107
+ )
108
+
109
+
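# Usage sketch (illustrative, not from the package file): how resolve_provider is
# expected to behave for a mapped model, a prefix-matched model, an explicit
# override, and an unknown model. The model names below are made up.
from tracia._errors import TraciaError
from tracia._llm import resolve_provider
from tracia._types import LLMProvider

assert resolve_provider("gpt-4o", None) is LLMProvider.OPENAI                     # exact registry hit
assert resolve_provider("claude-next-preview", None) is LLMProvider.ANTHROPIC     # "claude-" prefix
assert resolve_provider("my-finetune", LLMProvider.OPENAI) is LLMProvider.OPENAI  # explicit wins
try:
    resolve_provider("my-finetune", None)        # no registry entry, no known prefix
except TraciaError:
    pass                                         # raised with code UNSUPPORTED_MODEL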
110
+ def get_litellm_model(model: str, provider: LLMProvider) -> str:
111
+ """Get the litellm-compatible model name.
112
+
113
+ LiteLLM requires a ``gemini/`` prefix to route Google AI Studio models
114
+ correctly. Without it, litellm defaults to the Vertex AI path which
115
+ requires Application Default Credentials instead of an API key.
116
+
117
+ Args:
118
+ model: The user-facing model name (e.g. ``gemini-2.0-flash``).
119
+ provider: The resolved provider.
120
+
121
+ Returns:
122
+ The model string suitable for ``litellm.completion()``.
123
+ """
124
+ if provider == LLMProvider.GOOGLE and not model.startswith("gemini/"):
125
+ return f"gemini/{model}"
126
+ return model
127
+
128
+
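# Usage sketch (illustrative, not from the package file): only Google models gain
# the "gemini/" routing prefix; already-prefixed names and other providers pass
# through unchanged.
from tracia._llm import get_litellm_model
from tracia._types import LLMProvider

assert get_litellm_model("gemini-2.0-flash", LLMProvider.GOOGLE) == "gemini/gemini-2.0-flash"
assert get_litellm_model("gemini/gemini-2.0-flash", LLMProvider.GOOGLE) == "gemini/gemini-2.0-flash"
assert get_litellm_model("gpt-4o", LLMProvider.OPENAI) == "gpt-4o"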
129
+ def get_provider_api_key(
130
+ provider: LLMProvider,
131
+ provider_api_key: str | None = None,
132
+ ) -> str:
133
+ """Get the API key for a provider.
134
+
135
+ Args:
136
+ provider: The LLM provider.
137
+ provider_api_key: Explicitly provided API key.
138
+
139
+ Returns:
140
+ The API key.
141
+
142
+ Raises:
143
+ TraciaError: If no API key is found.
144
+ """
145
+ if provider_api_key:
146
+ return provider_api_key
147
+
148
+ env_var = ENV_VAR_MAP.get(provider.value)
149
+ if env_var:
150
+ key = os.environ.get(env_var)
151
+ if key:
152
+ return key
153
+
154
+ raise TraciaError(
155
+ code=TraciaErrorCode.MISSING_PROVIDER_API_KEY,
156
+ message=f"No API key found for provider '{provider.value}'. "
157
+ f"Set the {ENV_VAR_MAP.get(provider.value, 'PROVIDER_API_KEY')} environment variable "
158
+ "or pass provider_api_key parameter.",
159
+ )
160
+
161
+
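# Usage sketch (illustrative, not from the package file): the lookup order is the
# explicit argument first, then the provider's environment variable from
# ENV_VAR_MAP. The exact variable name (e.g. OPENAI_API_KEY) lives in
# tracia._constants and is assumed here; the key values are dummies.
import os

from tracia._constants import ENV_VAR_MAP
from tracia._llm import get_provider_api_key
from tracia._types import LLMProvider

assert get_provider_api_key(LLMProvider.OPENAI, "sk-explicit") == "sk-explicit"

env_var = ENV_VAR_MAP[LLMProvider.OPENAI.value]      # fallback environment variable
os.environ[env_var] = "sk-from-env"
assert get_provider_api_key(LLMProvider.OPENAI) == "sk-from-env"
# With neither an argument nor the variable set, a TraciaError with code
# MISSING_PROVIDER_API_KEY is raised instead.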
162
+ def convert_messages(
163
+ messages: list[LocalPromptMessage],
164
+ ) -> list[dict[str, Any]]:
165
+ """Convert Tracia messages to LiteLLM/OpenAI format.
166
+
167
+ Args:
168
+ messages: The Tracia messages.
169
+
170
+ Returns:
171
+ Messages in LiteLLM format.
172
+ """
173
+ result: list[dict[str, Any]] = []
174
+
175
+ for msg in messages:
176
+ # Handle tool role
177
+ if msg.role == "tool":
178
+ content = msg.content if isinstance(msg.content, str) else str(msg.content)
179
+ result.append({
180
+ "role": "tool",
181
+ "tool_call_id": msg.tool_call_id,
182
+ "content": content,
183
+ })
184
+ continue
185
+
186
+ # Handle developer role (map to system)
187
+ role = "system" if msg.role == "developer" else msg.role
188
+
189
+ # Handle string content
190
+ if isinstance(msg.content, str):
191
+ result.append({"role": role, "content": msg.content})
192
+ continue
193
+
194
+ # Handle list content (text parts and tool calls)
195
+ content_parts: list[dict[str, Any]] = []
196
+ tool_calls: list[dict[str, Any]] = []
197
+
198
+ for part in msg.content:
199
+ if isinstance(part, TextPart) or (isinstance(part, dict) and part.get("type") == "text"):
200
+ text = part.text if isinstance(part, TextPart) else part.get("text", "")
201
+ content_parts.append({"type": "text", "text": text})
202
+ elif isinstance(part, ToolCallPart) or (isinstance(part, dict) and part.get("type") == "tool_call"):
203
+ if isinstance(part, ToolCallPart):
204
+ tc_id = part.id
205
+ tc_name = part.name
206
+ tc_args = part.arguments
207
+ else:
208
+ tc_id = part.get("id", "")
209
+ tc_name = part.get("name", "")
210
+ tc_args = part.get("arguments", {})
211
+
212
+ tool_calls.append({
213
+ "id": tc_id,
214
+ "type": "function",
215
+ "function": {
216
+ "name": tc_name,
217
+ "arguments": json.dumps(tc_args) if isinstance(tc_args, dict) else tc_args,
218
+ },
219
+ })
220
+
221
+ # Build the message
222
+ msg_dict: dict[str, Any] = {"role": role}
223
+
224
+ if content_parts:
225
+ # A single text part with no tool calls collapses to a plain string
226
+ if len(content_parts) == 1 and not tool_calls:
227
+ msg_dict["content"] = content_parts[0]["text"]
228
+ else:
229
+ msg_dict["content"] = content_parts
230
+ elif not tool_calls:
231
+ msg_dict["content"] = ""
232
+
233
+ if tool_calls:
234
+ msg_dict["tool_calls"] = tool_calls
235
+
236
+ result.append(msg_dict)
237
+
238
+ return result
239
+
240
+
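# Usage sketch (illustrative, not from the package file): the OpenAI-style payload
# produced for a developer message, a tool-calling assistant turn, and a tool
# result. Constructor keywords (role literals, tool_call_id) follow the attribute
# usage in convert_messages and are assumed to be accepted by LocalPromptMessage.
from tracia._llm import convert_messages
from tracia._types import LocalPromptMessage, TextPart, ToolCallPart

msgs = [
    LocalPromptMessage(role="developer", content="Be terse."),      # emitted as "system"
    LocalPromptMessage(role="assistant", content=[
        TextPart(type="text", text="Checking the weather."),
        ToolCallPart(type="tool_call", id="call_1", name="get_weather",
                     arguments={"city": "Paris"}),
    ]),
    LocalPromptMessage(role="tool", content='{"temp_c": 21}', tool_call_id="call_1"),
]
payload = convert_messages(msgs)
# payload[0] == {"role": "system", "content": "Be terse."}
# payload[1]["tool_calls"][0]["function"]["arguments"] is a JSON string ('{"city": "Paris"}')
# payload[2] == {"role": "tool", "tool_call_id": "call_1", "content": '{"temp_c": 21}'}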
241
+ def convert_tools(tools: list[ToolDefinition] | None) -> list[dict[str, Any]] | None:
242
+ """Convert tool definitions to LiteLLM format.
243
+
244
+ Args:
245
+ tools: The tool definitions.
246
+
247
+ Returns:
248
+ Tools in LiteLLM format.
249
+ """
250
+ if not tools:
251
+ return None
252
+
253
+ result = []
254
+ for tool in tools:
255
+ result.append({
256
+ "type": "function",
257
+ "function": {
258
+ "name": tool.name,
259
+ "description": tool.description,
260
+ "parameters": tool.parameters.model_dump(exclude_none=True),
261
+ },
262
+ })
263
+ return result
264
+
265
+
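# Usage sketch (illustrative, not from the package file): empty input short-circuits
# to None; otherwise each ToolDefinition becomes an OpenAI-style function spec whose
# "parameters" value is parameters.model_dump(exclude_none=True). The exact fields
# of that pydantic parameters model are not shown in this module.
from tracia._llm import convert_tools

assert convert_tools(None) is None
assert convert_tools([]) is None
# For a ToolDefinition named "get_weather", the output has the shape:
# [{"type": "function",
#   "function": {"name": "get_weather",
#                "description": "Look up current weather",
#                "parameters": {... JSON schema from model_dump ...}}}]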
266
+ def convert_tool_choice(tool_choice: ToolChoice | None) -> str | dict[str, Any] | None:
267
+ """Convert tool choice to LiteLLM format.
268
+
269
+ Args:
270
+ tool_choice: The tool choice.
271
+
272
+ Returns:
273
+ Tool choice in LiteLLM format.
274
+ """
275
+ if tool_choice is None:
276
+ return None
277
+
278
+ if isinstance(tool_choice, str):
279
+ return tool_choice
280
+
281
+ if isinstance(tool_choice, dict) and "tool" in tool_choice:
282
+ return {"type": "function", "function": {"name": tool_choice["tool"]}}
283
+
284
+ return None
285
+
286
+
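# Usage sketch (illustrative, not from the package file): string choices pass
# through, a {"tool": name} dict becomes the OpenAI "function" form, and anything
# else collapses to None.
from tracia._llm import convert_tool_choice

assert convert_tool_choice(None) is None
assert convert_tool_choice("auto") == "auto"
assert convert_tool_choice({"tool": "get_weather"}) == {
    "type": "function",
    "function": {"name": "get_weather"},
}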
287
+ def parse_finish_reason(reason: str | None) -> FinishReason:
288
+ """Parse the finish reason from LiteLLM response.
289
+
290
+ Args:
291
+ reason: The raw finish reason.
292
+
293
+ Returns:
294
+ The normalized finish reason.
295
+ """
296
+ if reason == "tool_calls":
297
+ return "tool_calls"
298
+ if reason == "length":
299
+ return "max_tokens"
300
+ return "stop"
301
+
302
+
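# Usage sketch (illustrative, not from the package file): every reason other than
# "tool_calls" and "length" (including None) normalizes to "stop".
from tracia._llm import parse_finish_reason

assert parse_finish_reason("tool_calls") == "tool_calls"
assert parse_finish_reason("length") == "max_tokens"
assert parse_finish_reason("content_filter") == "stop"
assert parse_finish_reason(None) == "stop"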
303
+ def extract_tool_calls(response: "ModelResponse") -> list[ToolCall]:
304
+ """Extract tool calls from a LiteLLM response.
305
+
306
+ Args:
307
+ response: The LiteLLM response.
308
+
309
+ Returns:
310
+ The extracted tool calls.
311
+ """
312
+ tool_calls: list[ToolCall] = []
313
+
314
+ choices = getattr(response, "choices", [])
315
+ if not choices:
316
+ return tool_calls
317
+
318
+ message = getattr(choices[0], "message", None)
319
+ if not message:
320
+ return tool_calls
321
+
322
+ raw_tool_calls = getattr(message, "tool_calls", None)
323
+ if not raw_tool_calls:
324
+ return tool_calls
325
+
326
+ for tc in raw_tool_calls:
327
+ func = getattr(tc, "function", None)
328
+ if func:
329
+ try:
330
+ args = json.loads(func.arguments) if isinstance(func.arguments, str) else func.arguments
331
+ except json.JSONDecodeError:
332
+ args = {}
333
+
334
+ tool_calls.append(ToolCall(
335
+ id=tc.id,
336
+ name=func.name,
337
+ arguments=args,
338
+ ))
339
+
340
+ return tool_calls
341
+
342
+
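# Usage sketch (illustrative, not from the package file): extract_tool_calls only
# relies on attribute access, so a SimpleNamespace stand-in for a LiteLLM
# ModelResponse is enough to show the parsing, including the malformed-JSON
# fallback to an empty dict.
from types import SimpleNamespace

from tracia._llm import extract_tool_calls

fake_response = SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(tool_calls=[
    SimpleNamespace(id="call_1",
                    function=SimpleNamespace(name="get_weather", arguments='{"city": "Paris"}')),
    SimpleNamespace(id="call_2",
                    function=SimpleNamespace(name="broken", arguments='{not json')),
]))])

calls = extract_tool_calls(fake_response)
assert calls[0].name == "get_weather" and calls[0].arguments == {"city": "Paris"}
assert calls[1].arguments == {}          # invalid JSON arguments degrade to {}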
343
+ class LLMClient:
344
+ """Client for making LLM calls via LiteLLM."""
345
+
346
+ def complete(
347
+ self,
348
+ model: str,
349
+ messages: list[LocalPromptMessage],
350
+ *,
351
+ provider: LLMProvider | None = None,
352
+ temperature: float | None = None,
353
+ max_tokens: int | None = None,
354
+ top_p: float | None = None,
355
+ stop: list[str] | None = None,
356
+ tools: list[ToolDefinition] | None = None,
357
+ tool_choice: ToolChoice | None = None,
358
+ api_key: str | None = None,
359
+ timeout: float | None = None,
360
+ ) -> CompletionResult:
361
+ """Make a synchronous completion request.
362
+
363
+ Args:
364
+ model: The model name.
365
+ messages: The messages to send.
366
+ provider: The LLM provider.
367
+ temperature: Sampling temperature.
368
+ max_tokens: Maximum output tokens.
369
+ top_p: Top-p sampling.
370
+ stop: Stop sequences.
371
+ tools: Tool definitions.
372
+ tool_choice: Tool choice setting.
373
+ api_key: Provider API key.
374
+ timeout: Request timeout in seconds.
375
+
376
+ Returns:
377
+ The completion result.
378
+
379
+ Raises:
380
+ TraciaError: If the request fails.
381
+ """
382
+ try:
383
+ import litellm
384
+ except ImportError as e:
385
+ raise TraciaError(
386
+ code=TraciaErrorCode.MISSING_PROVIDER_SDK,
387
+ message="litellm is not installed. Install it with: pip install litellm",
388
+ ) from e
389
+
390
+ resolved_provider = resolve_provider(model, provider)
391
+ resolved_api_key = get_provider_api_key(resolved_provider, api_key)
392
+
393
+ # Build the request
394
+ litellm_messages = convert_messages(messages)
395
+ litellm_tools = convert_tools(tools)
396
+ litellm_tool_choice = convert_tool_choice(tool_choice)
397
+
398
+ request_kwargs: dict[str, Any] = {
399
+ "model": get_litellm_model(model, resolved_provider),
400
+ "messages": litellm_messages,
401
+ "api_key": resolved_api_key,
402
+ }
403
+
404
+ if temperature is not None:
405
+ request_kwargs["temperature"] = temperature
406
+ if max_tokens is not None:
407
+ request_kwargs["max_tokens"] = max_tokens
408
+ if top_p is not None:
409
+ request_kwargs["top_p"] = top_p
410
+ if stop is not None:
411
+ request_kwargs["stop"] = stop
412
+ if litellm_tools is not None:
413
+ request_kwargs["tools"] = litellm_tools
414
+ if litellm_tool_choice is not None:
415
+ request_kwargs["tool_choice"] = litellm_tool_choice
416
+ if timeout is not None:
417
+ request_kwargs["timeout"] = timeout
418
+
419
+ try:
420
+ response = litellm.completion(**request_kwargs)
421
+ except Exception as e:
422
+ error_msg = sanitize_error_message(str(e))
423
+ raise TraciaError(
424
+ code=TraciaErrorCode.PROVIDER_ERROR,
425
+ message=f"LLM provider error: {error_msg}",
426
+ ) from e
427
+
428
+ # Extract result
429
+ usage = getattr(response, "usage", None)
430
+ choices = getattr(response, "choices", [])
431
+ message = choices[0].message if choices else None
432
+ content = getattr(message, "content", "") or ""
433
+ finish_reason = choices[0].finish_reason if choices else "stop"
434
+
435
+ return CompletionResult(
436
+ text=content,
437
+ input_tokens=getattr(usage, "prompt_tokens", 0) if usage else 0,
438
+ output_tokens=getattr(usage, "completion_tokens", 0) if usage else 0,
439
+ total_tokens=getattr(usage, "total_tokens", 0) if usage else 0,
440
+ tool_calls=extract_tool_calls(response),
441
+ finish_reason=parse_finish_reason(finish_reason),
442
+ provider=resolved_provider,
443
+ )
444
+
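# Usage sketch (illustrative, not from the package file): a minimal synchronous
# call. Assumes litellm is installed and the OpenAI key is available via the
# environment variable named in ENV_VAR_MAP; the model and prompt are examples.
from tracia._llm import LLMClient
from tracia._types import LocalPromptMessage

client = LLMClient()
result = client.complete(
    model="gpt-4o-mini",
    messages=[
        LocalPromptMessage(role="system", content="You are a concise assistant."),
        LocalPromptMessage(role="user", content="Name one prime number."),
    ],
    temperature=0.2,
    max_tokens=32,
)
print(result.text, result.finish_reason, result.total_tokens)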
445
+ async def acomplete(
446
+ self,
447
+ model: str,
448
+ messages: list[LocalPromptMessage],
449
+ *,
450
+ provider: LLMProvider | None = None,
451
+ temperature: float | None = None,
452
+ max_tokens: int | None = None,
453
+ top_p: float | None = None,
454
+ stop: list[str] | None = None,
455
+ tools: list[ToolDefinition] | None = None,
456
+ tool_choice: ToolChoice | None = None,
457
+ api_key: str | None = None,
458
+ timeout: float | None = None,
459
+ ) -> CompletionResult:
460
+ """Make an asynchronous completion request.
461
+
462
+ Args:
463
+ model: The model name.
464
+ messages: The messages to send.
465
+ provider: The LLM provider.
466
+ temperature: Sampling temperature.
467
+ max_tokens: Maximum output tokens.
468
+ top_p: Top-p sampling.
469
+ stop: Stop sequences.
470
+ tools: Tool definitions.
471
+ tool_choice: Tool choice setting.
472
+ api_key: Provider API key.
473
+ timeout: Request timeout in seconds.
474
+
475
+ Returns:
476
+ The completion result.
477
+
478
+ Raises:
479
+ TraciaError: If the request fails.
480
+ """
481
+ try:
482
+ import litellm
483
+ except ImportError as e:
484
+ raise TraciaError(
485
+ code=TraciaErrorCode.MISSING_PROVIDER_SDK,
486
+ message="litellm is not installed. Install it with: pip install litellm",
487
+ ) from e
488
+
489
+ resolved_provider = resolve_provider(model, provider)
490
+ resolved_api_key = get_provider_api_key(resolved_provider, api_key)
491
+
492
+ # Build the request
493
+ litellm_messages = convert_messages(messages)
494
+ litellm_tools = convert_tools(tools)
495
+ litellm_tool_choice = convert_tool_choice(tool_choice)
496
+
497
+ request_kwargs: dict[str, Any] = {
498
+ "model": get_litellm_model(model, resolved_provider),
499
+ "messages": litellm_messages,
500
+ "api_key": resolved_api_key,
501
+ }
502
+
503
+ if temperature is not None:
504
+ request_kwargs["temperature"] = temperature
505
+ if max_tokens is not None:
506
+ request_kwargs["max_tokens"] = max_tokens
507
+ if top_p is not None:
508
+ request_kwargs["top_p"] = top_p
509
+ if stop is not None:
510
+ request_kwargs["stop"] = stop
511
+ if litellm_tools is not None:
512
+ request_kwargs["tools"] = litellm_tools
513
+ if litellm_tool_choice is not None:
514
+ request_kwargs["tool_choice"] = litellm_tool_choice
515
+ if timeout is not None:
516
+ request_kwargs["timeout"] = timeout
517
+
518
+ try:
519
+ response = await litellm.acompletion(**request_kwargs)
520
+ except Exception as e:
521
+ error_msg = sanitize_error_message(str(e))
522
+ raise TraciaError(
523
+ code=TraciaErrorCode.PROVIDER_ERROR,
524
+ message=f"LLM provider error: {error_msg}",
525
+ ) from e
526
+
527
+ # Extract result
528
+ usage = getattr(response, "usage", None)
529
+ choices = getattr(response, "choices", [])
530
+ message = choices[0].message if choices else None
531
+ content = getattr(message, "content", "") or ""
532
+ finish_reason = choices[0].finish_reason if choices else "stop"
533
+
534
+ return CompletionResult(
535
+ text=content,
536
+ input_tokens=getattr(usage, "prompt_tokens", 0) if usage else 0,
537
+ output_tokens=getattr(usage, "completion_tokens", 0) if usage else 0,
538
+ total_tokens=getattr(usage, "total_tokens", 0) if usage else 0,
539
+ tool_calls=extract_tool_calls(response),
540
+ finish_reason=parse_finish_reason(finish_reason),
541
+ provider=resolved_provider,
542
+ )
543
+
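# Usage sketch (illustrative, not from the package file): the async variant mirrors
# complete() and is awaited; the same assumptions about litellm and API keys apply.
import asyncio

from tracia._llm import LLMClient
from tracia._types import LocalPromptMessage

async def main() -> None:
    client = LLMClient()
    result = await client.acomplete(
        model="gpt-4o-mini",
        messages=[LocalPromptMessage(role="user", content="Say hello.")],
    )
    print(result.text, result.input_tokens, result.output_tokens)

asyncio.run(main())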
544
+ def stream(
545
+ self,
546
+ model: str,
547
+ messages: list[LocalPromptMessage],
548
+ *,
549
+ provider: LLMProvider | None = None,
550
+ temperature: float | None = None,
551
+ max_tokens: int | None = None,
552
+ top_p: float | None = None,
553
+ stop: list[str] | None = None,
554
+ tools: list[ToolDefinition] | None = None,
555
+ tool_choice: ToolChoice | None = None,
556
+ api_key: str | None = None,
557
+ timeout: float | None = None,
558
+ ) -> tuple[Iterator[str], list[CompletionResult], LLMProvider]:
559
+ """Make a streaming completion request.
560
+
561
+ Args:
562
+ model: The model name.
563
+ messages: The messages to send.
564
+ provider: The LLM provider.
565
+ temperature: Sampling temperature.
566
+ max_tokens: Maximum output tokens.
567
+ top_p: Top-p sampling.
568
+ stop: Stop sequences.
569
+ tools: Tool definitions.
570
+ tool_choice: Tool choice setting.
571
+ api_key: Provider API key.
572
+ timeout: Request timeout in seconds.
573
+
574
+ Returns:
575
+ A tuple of (chunk iterator, result holder list, provider). The holder receives a single CompletionResult once the iterator is fully consumed.
576
+
577
+ Raises:
578
+ TraciaError: If the request fails.
579
+ """
580
+ try:
581
+ import litellm
582
+ except ImportError as e:
583
+ raise TraciaError(
584
+ code=TraciaErrorCode.MISSING_PROVIDER_SDK,
585
+ message="litellm is not installed. Install it with: pip install litellm",
586
+ ) from e
587
+
588
+ resolved_provider = resolve_provider(model, provider)
589
+ resolved_api_key = get_provider_api_key(resolved_provider, api_key)
590
+
591
+ # Build the request
592
+ litellm_messages = convert_messages(messages)
593
+ litellm_tools = convert_tools(tools)
594
+ litellm_tool_choice = convert_tool_choice(tool_choice)
595
+
596
+ request_kwargs: dict[str, Any] = {
597
+ "model": get_litellm_model(model, resolved_provider),
598
+ "messages": litellm_messages,
599
+ "api_key": resolved_api_key,
600
+ "stream": True,
601
+ }
602
+
603
+ if temperature is not None:
604
+ request_kwargs["temperature"] = temperature
605
+ if max_tokens is not None:
606
+ request_kwargs["max_tokens"] = max_tokens
607
+ if top_p is not None:
608
+ request_kwargs["top_p"] = top_p
609
+ if stop is not None:
610
+ request_kwargs["stop"] = stop
611
+ if litellm_tools is not None:
612
+ request_kwargs["tools"] = litellm_tools
613
+ if litellm_tool_choice is not None:
614
+ request_kwargs["tool_choice"] = litellm_tool_choice
615
+ if timeout is not None:
616
+ request_kwargs["timeout"] = timeout
617
+
618
+ result_holder: list[CompletionResult] = []
619
+
620
+ def generate_chunks() -> Iterator[str]:
621
+ full_text = ""
622
+ input_tokens = 0
623
+ output_tokens = 0
624
+ total_tokens = 0
625
+ tool_calls: list[ToolCall] = []
626
+ finish_reason: FinishReason = "stop"
627
+ tool_call_chunks: dict[int, dict[str, Any]] = {}
628
+
629
+ try:
630
+ response = litellm.completion(**request_kwargs)
631
+
632
+ for chunk in response:
633
+ choices = getattr(chunk, "choices", [])
634
+ if not choices:
635
+ continue
636
+
637
+ delta = getattr(choices[0], "delta", None)
638
+ if delta:
639
+ content = getattr(delta, "content", None)
640
+ if content:
641
+ full_text += content
642
+ yield content
643
+
644
+ # Handle streaming tool calls
645
+ delta_tool_calls = getattr(delta, "tool_calls", None)
646
+ if delta_tool_calls:
647
+ for tc in delta_tool_calls:
648
+ idx = tc.index
649
+ if idx not in tool_call_chunks:
650
+ tool_call_chunks[idx] = {
651
+ "id": "",
652
+ "name": "",
653
+ "arguments": "",
654
+ }
655
+ if tc.id:
656
+ tool_call_chunks[idx]["id"] = tc.id
657
+ if tc.function:
658
+ if tc.function.name:
659
+ tool_call_chunks[idx]["name"] = tc.function.name
660
+ if tc.function.arguments:
661
+ tool_call_chunks[idx]["arguments"] += tc.function.arguments
662
+
663
+ chunk_finish = getattr(choices[0], "finish_reason", None)
664
+ if chunk_finish:
665
+ finish_reason = parse_finish_reason(chunk_finish)
666
+
667
+ # Extract usage from final chunk
668
+ usage = getattr(chunk, "usage", None)
669
+ if usage:
670
+ input_tokens = getattr(usage, "prompt_tokens", 0)
671
+ output_tokens = getattr(usage, "completion_tokens", 0)
672
+ total_tokens = getattr(usage, "total_tokens", 0)
673
+
674
+ # Convert accumulated tool calls
675
+ for idx in sorted(tool_call_chunks.keys()):
676
+ tc_data = tool_call_chunks[idx]
677
+ try:
678
+ args = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {}
679
+ except json.JSONDecodeError:
680
+ args = {}
681
+ tool_calls.append(ToolCall(
682
+ id=tc_data["id"],
683
+ name=tc_data["name"],
684
+ arguments=args,
685
+ ))
686
+
687
+ result_holder.append(CompletionResult(
688
+ text=full_text,
689
+ input_tokens=input_tokens,
690
+ output_tokens=output_tokens,
691
+ total_tokens=total_tokens,
692
+ tool_calls=tool_calls,
693
+ finish_reason=finish_reason,
694
+ provider=resolved_provider,
695
+ ))
696
+
697
+ except Exception as e:
698
+ error_msg = sanitize_error_message(str(e))
699
+ raise TraciaError(
700
+ code=TraciaErrorCode.PROVIDER_ERROR,
701
+ message=f"LLM provider error: {error_msg}",
702
+ ) from e
703
+
704
+ return generate_chunks(), result_holder, resolved_provider
705
+
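# Usage sketch (illustrative, not from the package file): stream() hands back the
# chunk iterator immediately; the CompletionResult lands in the holder list only
# after the iterator has been fully consumed. Whether streamed usage counts are
# populated depends on the provider and litellm's streaming options.
from tracia._llm import LLMClient
from tracia._types import LocalPromptMessage

client = LLMClient()
chunks, holder, provider = client.stream(
    model="gpt-4o-mini",
    messages=[LocalPromptMessage(role="user", content="Count to three.")],
)
for chunk in chunks:
    print(chunk, end="", flush=True)     # text deltas as they arrive
final = holder[0]                        # available once the loop above finishes
print("\nfinish_reason:", final.finish_reason, "provider:", provider.value)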
706
+ async def astream(
707
+ self,
708
+ model: str,
709
+ messages: list[LocalPromptMessage],
710
+ *,
711
+ provider: LLMProvider | None = None,
712
+ temperature: float | None = None,
713
+ max_tokens: int | None = None,
714
+ top_p: float | None = None,
715
+ stop: list[str] | None = None,
716
+ tools: list[ToolDefinition] | None = None,
717
+ tool_choice: ToolChoice | None = None,
718
+ api_key: str | None = None,
719
+ timeout: float | None = None,
720
+ ) -> tuple[AsyncIterator[str], list[CompletionResult], LLMProvider]:
721
+ """Make an async streaming completion request.
722
+
723
+ Args:
724
+ model: The model name.
725
+ messages: The messages to send.
726
+ provider: The LLM provider.
727
+ temperature: Sampling temperature.
728
+ max_tokens: Maximum output tokens.
729
+ top_p: Top-p sampling.
730
+ stop: Stop sequences.
731
+ tools: Tool definitions.
732
+ tool_choice: Tool choice setting.
733
+ api_key: Provider API key.
734
+ timeout: Request timeout in seconds.
735
+
736
+ Returns:
737
+ A tuple of (async chunk iterator, result holder list, provider). The holder receives a single CompletionResult once the iterator is fully consumed.
738
+
739
+ Raises:
740
+ TraciaError: If the request fails.
741
+ """
742
+ try:
743
+ import litellm
744
+ except ImportError as e:
745
+ raise TraciaError(
746
+ code=TraciaErrorCode.MISSING_PROVIDER_SDK,
747
+ message="litellm is not installed. Install it with: pip install litellm",
748
+ ) from e
749
+
750
+ resolved_provider = resolve_provider(model, provider)
751
+ resolved_api_key = get_provider_api_key(resolved_provider, api_key)
752
+
753
+ # Build the request
754
+ litellm_messages = convert_messages(messages)
755
+ litellm_tools = convert_tools(tools)
756
+ litellm_tool_choice = convert_tool_choice(tool_choice)
757
+
758
+ request_kwargs: dict[str, Any] = {
759
+ "model": get_litellm_model(model, resolved_provider),
760
+ "messages": litellm_messages,
761
+ "api_key": resolved_api_key,
762
+ "stream": True,
763
+ }
764
+
765
+ if temperature is not None:
766
+ request_kwargs["temperature"] = temperature
767
+ if max_tokens is not None:
768
+ request_kwargs["max_tokens"] = max_tokens
769
+ if top_p is not None:
770
+ request_kwargs["top_p"] = top_p
771
+ if stop is not None:
772
+ request_kwargs["stop"] = stop
773
+ if litellm_tools is not None:
774
+ request_kwargs["tools"] = litellm_tools
775
+ if litellm_tool_choice is not None:
776
+ request_kwargs["tool_choice"] = litellm_tool_choice
777
+ if timeout is not None:
778
+ request_kwargs["timeout"] = timeout
779
+
780
+ result_holder: list[CompletionResult] = []
781
+
782
+ async def generate_chunks() -> AsyncIterator[str]:
783
+ full_text = ""
784
+ input_tokens = 0
785
+ output_tokens = 0
786
+ total_tokens = 0
787
+ tool_calls: list[ToolCall] = []
788
+ finish_reason: FinishReason = "stop"
789
+ tool_call_chunks: dict[int, dict[str, Any]] = {}
790
+
791
+ try:
792
+ response = await litellm.acompletion(**request_kwargs)
793
+
794
+ async for chunk in response:
795
+ choices = getattr(chunk, "choices", [])
796
+ if not choices:
797
+ continue
798
+
799
+ delta = getattr(choices[0], "delta", None)
800
+ if delta:
801
+ content = getattr(delta, "content", None)
802
+ if content:
803
+ full_text += content
804
+ yield content
805
+
806
+ # Handle streaming tool calls
807
+ delta_tool_calls = getattr(delta, "tool_calls", None)
808
+ if delta_tool_calls:
809
+ for tc in delta_tool_calls:
810
+ idx = tc.index
811
+ if idx not in tool_call_chunks:
812
+ tool_call_chunks[idx] = {
813
+ "id": "",
814
+ "name": "",
815
+ "arguments": "",
816
+ }
817
+ if tc.id:
818
+ tool_call_chunks[idx]["id"] = tc.id
819
+ if tc.function:
820
+ if tc.function.name:
821
+ tool_call_chunks[idx]["name"] = tc.function.name
822
+ if tc.function.arguments:
823
+ tool_call_chunks[idx]["arguments"] += tc.function.arguments
824
+
825
+ chunk_finish = getattr(choices[0], "finish_reason", None)
826
+ if chunk_finish:
827
+ finish_reason = parse_finish_reason(chunk_finish)
828
+
829
+ # Extract usage from final chunk
830
+ usage = getattr(chunk, "usage", None)
831
+ if usage:
832
+ input_tokens = getattr(usage, "prompt_tokens", 0)
833
+ output_tokens = getattr(usage, "completion_tokens", 0)
834
+ total_tokens = getattr(usage, "total_tokens", 0)
835
+
836
+ # Convert accumulated tool calls
837
+ for idx in sorted(tool_call_chunks.keys()):
838
+ tc_data = tool_call_chunks[idx]
839
+ try:
840
+ args = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {}
841
+ except json.JSONDecodeError:
842
+ args = {}
843
+ tool_calls.append(ToolCall(
844
+ id=tc_data["id"],
845
+ name=tc_data["name"],
846
+ arguments=args,
847
+ ))
848
+
849
+ result_holder.append(CompletionResult(
850
+ text=full_text,
851
+ input_tokens=input_tokens,
852
+ output_tokens=output_tokens,
853
+ total_tokens=total_tokens,
854
+ tool_calls=tool_calls,
855
+ finish_reason=finish_reason,
856
+ provider=resolved_provider,
857
+ ))
858
+
859
+ except Exception as e:
860
+ error_msg = sanitize_error_message(str(e))
861
+ raise TraciaError(
862
+ code=TraciaErrorCode.PROVIDER_ERROR,
863
+ message=f"LLM provider error: {error_msg}",
864
+ ) from e
865
+
866
+ return generate_chunks(), result_holder, resolved_provider
867
+
868
+
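# Usage sketch (illustrative, not from the package file): astream() is awaited to
# obtain the tuple, then consumed with "async for"; as with stream(), the holder
# is filled only after the async iterator is exhausted.
import asyncio

from tracia._llm import LLMClient
from tracia._types import LocalPromptMessage

async def main() -> None:
    client = LLMClient()
    chunks, holder, _provider = await client.astream(
        model="gpt-4o-mini",
        messages=[LocalPromptMessage(role="user", content="Stream a haiku.")],
    )
    async for chunk in chunks:
        print(chunk, end="", flush=True)
    print("\ntool calls:", holder[0].tool_calls)

asyncio.run(main())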
869
+ def build_assistant_message(
870
+ text: str,
871
+ tool_calls: list[ToolCall],
872
+ ) -> LocalPromptMessage:
873
+ """Build an assistant message from completion result.
874
+
875
+ Args:
876
+ text: The text content.
877
+ tool_calls: Any tool calls made.
878
+
879
+ Returns:
880
+ The assistant message.
881
+ """
882
+ if not tool_calls:
883
+ return LocalPromptMessage(role="assistant", content=text)
884
+
885
+ content: list[ContentPart] = []
886
+
887
+ if text:
888
+ content.append(TextPart(type="text", text=text))
889
+
890
+ for tc in tool_calls:
891
+ content.append(ToolCallPart(
892
+ type="tool_call",
893
+ id=tc.id,
894
+ name=tc.name,
895
+ arguments=tc.arguments,
896
+ ))
897
+
898
+ return LocalPromptMessage(role="assistant", content=content)
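# Usage sketch (illustrative, not from the package file): plain text stays a string
# message, while any tool calls force the structured content-part form.
from tracia._llm import build_assistant_message
from tracia._types import ToolCall

plain = build_assistant_message("Hello!", [])
assert plain.role == "assistant" and plain.content == "Hello!"

with_tools = build_assistant_message(
    "Checking...",
    [ToolCall(id="call_1", name="get_weather", arguments={"city": "Paris"})],
)
assert [type(p).__name__ for p in with_tools.content] == ["TextPart", "ToolCallPart"]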