tracia 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tracia/_client.py ADDED
@@ -0,0 +1,1100 @@
1
+ """Main Tracia client implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import time
7
+ import warnings
8
+ from concurrent.futures import Future, ThreadPoolExecutor
9
+ from threading import Event, Lock
10
+ from typing import Any, Callable, Literal, overload
11
+
12
+ from ._constants import (
13
+ BASE_URL,
14
+ DEFAULT_TIMEOUT_MS,
15
+ MAX_PENDING_SPANS,
16
+ SPAN_RETRY_ATTEMPTS,
17
+ SPAN_RETRY_DELAY_MS,
18
+ SPAN_STATUS_ERROR,
19
+ SPAN_STATUS_SUCCESS,
20
+ )
21
+ from ._errors import TraciaError, TraciaErrorCode, sanitize_error_message
22
+ from ._http import AsyncHttpClient, HttpClient
23
+ from ._llm import LLMClient, build_assistant_message, resolve_provider
24
+ from ._session import TraciaSession
25
+ from ._streaming import AsyncLocalStream, LocalStream
26
+ from ._types import (
27
+ CreateSpanPayload,
28
+ LocalPromptMessage,
29
+ LLMProvider,
30
+ RunLocalResult,
31
+ StreamResult,
32
+ TokenUsage,
33
+ ToolCall,
34
+ ToolDefinition,
35
+ ToolChoice,
36
+ )
37
+ from ._utils import (
38
+ generate_span_id,
39
+ generate_trace_id,
40
+ interpolate_message_content,
41
+ is_valid_span_id_format,
42
+ )
43
+ from .resources import Prompts, Spans
44
+
45
+
46
+ class Tracia:
47
+ """Main Tracia client for LLM prompt management and tracing.
48
+
49
+ Example usage:
50
+ ```python
51
+ from tracia import Tracia
52
+
53
+ client = Tracia(api_key="your_api_key")
54
+
55
+ # Run a local prompt
56
+ result = client.run_local(
57
+ model="gpt-4o",
58
+ messages=[{"role": "user", "content": "Hello!"}]
59
+ )
60
+ print(result.text)
61
+
62
+ # Run with streaming
63
+ stream = client.run_local(
64
+ model="gpt-4o",
65
+ messages=[{"role": "user", "content": "Tell me a story"}],
66
+ stream=True
67
+ )
68
+ for chunk in stream:
69
+ print(chunk, end="")
70
+
71
+ # Create a session for multi-turn conversations
72
+ session = client.create_session()
73
+ r1 = session.run_local(model="gpt-4o", messages=[...])
74
+ r2 = session.run_local(model="gpt-4o", messages=[...])  # Linked to r1 in the same trace
75
+
76
+ # Clean up
77
+ client.close()
78
+ ```
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ api_key: str,
84
+ *,
85
+ base_url: str = BASE_URL,
86
+ on_span_error: Callable[[Exception, str], None] | None = None,
87
+ ) -> None:
88
+ """Initialize the Tracia client.
89
+
90
+ Args:
91
+ api_key: Your Tracia API key.
92
+ base_url: The base URL for the Tracia API.
93
+ on_span_error: Optional callback for span creation errors.
94
+ """
95
+ self._api_key = api_key
96
+ self._base_url = base_url
97
+ self._on_span_error = on_span_error
98
+
99
+ # HTTP clients
100
+ self._http_client = HttpClient(api_key, base_url)
101
+ self._async_http_client: AsyncHttpClient | None = None
102
+
103
+ # LLM client
104
+ self._llm_client = LLMClient()
105
+
106
+ # Resources
107
+ self.prompts = Prompts(self._http_client)
108
+ self.spans = Spans(self._http_client)
109
+
110
+ # Pending spans management
111
+ self._pending_spans: dict[str, Future[None]] = {}
112
+ self._pending_spans_lock = Lock()
113
+ self._executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="tracia-span-")
114
+ self._closed = False
115
+
116
+ def _get_async_http_client(self) -> AsyncHttpClient:
117
+ """Get or create the async HTTP client."""
118
+ if self._async_http_client is None:
119
+ self._async_http_client = AsyncHttpClient(self._api_key, self._base_url)
120
+ # Update resources with async client
121
+ self.prompts._async_client = self._async_http_client
122
+ self.spans._async_client = self._async_http_client
123
+ return self._async_http_client
124
+
125
+ def _validate_run_local_input(
126
+ self,
127
+ messages: list[dict[str, Any] | LocalPromptMessage],
128
+ model: str,
129
+ span_id: str | None,
130
+ ) -> None:
131
+ """Validate run_local input parameters."""
132
+ if not model or not model.strip():
133
+ raise TraciaError(
134
+ code=TraciaErrorCode.INVALID_REQUEST,
135
+ message="model is required and cannot be empty",
136
+ )
137
+
138
+ if not messages:
139
+ raise TraciaError(
140
+ code=TraciaErrorCode.INVALID_REQUEST,
141
+ message="messages is required and cannot be empty",
142
+ )
143
+
144
+ if span_id is not None and not is_valid_span_id_format(span_id):
145
+ raise TraciaError(
146
+ code=TraciaErrorCode.INVALID_REQUEST,
147
+ message=f"Invalid span_id format: {span_id}. Expected sp_XXXXXXXXXXXXXXXX or tr_XXXXXXXXXXXXXXXX",
148
+ )
149
+
150
+ # Validate tool messages have tool_call_id
151
+ for msg in messages:
152
+ if isinstance(msg, dict):
153
+ if msg.get("role") == "tool" and not msg.get("tool_call_id") and not msg.get("toolCallId"):
154
+ raise TraciaError(
155
+ code=TraciaErrorCode.INVALID_REQUEST,
156
+ message="Tool messages must have tool_call_id",
157
+ )
158
+ elif isinstance(msg, LocalPromptMessage):
159
+ if msg.role == "tool" and not msg.tool_call_id:
160
+ raise TraciaError(
161
+ code=TraciaErrorCode.INVALID_REQUEST,
162
+ message="Tool messages must have tool_call_id",
163
+ )
164
+
165
+ def _convert_messages(
166
+ self,
167
+ messages: list[dict[str, Any] | LocalPromptMessage],
168
+ variables: dict[str, str] | None,
169
+ ) -> list[LocalPromptMessage]:
170
+ """Convert dict messages to LocalPromptMessage and interpolate variables."""
171
+ result: list[LocalPromptMessage] = []
172
+
173
+ for msg in messages:
174
+ if isinstance(msg, LocalPromptMessage):
175
+ pm = msg
176
+ else:
177
+ pm = LocalPromptMessage(
178
+ role=msg["role"],
179
+ content=msg.get("content", ""),
180
+ tool_call_id=msg.get("tool_call_id") or msg.get("toolCallId"),
181
+ tool_name=msg.get("tool_name") or msg.get("toolName"),
182
+ )
183
+
184
+ # Interpolate variables (skip tool messages)
185
+ if pm.role != "tool" and variables:
186
+ pm = LocalPromptMessage(
187
+ role=pm.role,
188
+ content=interpolate_message_content(pm.content, variables),
189
+ tool_call_id=pm.tool_call_id,
190
+ tool_name=pm.tool_name,
191
+ )
192
+
193
+ result.append(pm)
194
+
195
+ return result
196
+
197
+ def _schedule_span_creation(
198
+ self,
199
+ payload: CreateSpanPayload,
200
+ ) -> None:
201
+ """Schedule span creation in the background with retry logic."""
202
+ span_id = payload.span_id
203
+
204
+ def create_span_with_retry() -> None:
205
+ last_error: Exception | None = None
206
+
207
+ for attempt in range(SPAN_RETRY_ATTEMPTS + 1):
208
+ try:
209
+ self._http_client.post(
210
+ "/api/v1/spans",
211
+ payload.model_dump(by_alias=True, exclude_none=True),
212
+ )
213
+ return
214
+ except Exception as e:
215
+ last_error = e
216
+ if attempt < SPAN_RETRY_ATTEMPTS:
217
+ delay = SPAN_RETRY_DELAY_MS * (attempt + 1) / 1000.0
218
+ time.sleep(delay)
219
+
220
+ # All retries failed
221
+ if self._on_span_error and last_error:
222
+ try:
223
+ self._on_span_error(last_error, span_id)
224
+ except Exception:
225
+ pass
226
+
227
+ # Evict old spans if at capacity
228
+ with self._pending_spans_lock:
229
+ if len(self._pending_spans) >= MAX_PENDING_SPANS:
230
+ # Remove the oldest span
231
+ oldest_key = next(iter(self._pending_spans))
232
+ del self._pending_spans[oldest_key]
233
+
234
+ future = self._executor.submit(create_span_with_retry)
235
+ self._pending_spans[span_id] = future
236
+
237
+ # Clean up when done
238
+ def on_done(f: Future[None]) -> None:
239
+ with self._pending_spans_lock:
240
+ self._pending_spans.pop(span_id, None)
241
+
242
+ future.add_done_callback(on_done)
243
+
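Span uploads run on a background thread pool and are retried a few times; when every attempt fails, the exception and the span ID are handed to the `on_span_error` callback supplied at construction time. A minimal sketch of wiring that callback to the standard `logging` module (the logger name and message format are illustrative, not part of the SDK):

```python
import logging

from tracia import Tracia

logger = logging.getLogger("tracia.spans")


def log_span_failure(error: Exception, span_id: str) -> None:
    # Invoked once per span after every retry attempt has failed.
    logger.warning("failed to record span %s: %s", span_id, error)


client = Tracia(api_key="your_api_key", on_span_error=log_span_failure)
```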
244
+ @overload
245
+ def run_local(
246
+ self,
247
+ *,
248
+ messages: list[dict[str, Any] | LocalPromptMessage],
249
+ model: str,
250
+ stream: Literal[False] = False,
251
+ provider: LLMProvider | None = None,
252
+ temperature: float | None = None,
253
+ max_output_tokens: int | None = None,
254
+ top_p: float | None = None,
255
+ stop_sequences: list[str] | None = None,
256
+ timeout_ms: int | None = None,
257
+ variables: dict[str, str] | None = None,
258
+ provider_api_key: str | None = None,
259
+ tags: list[str] | None = None,
260
+ user_id: str | None = None,
261
+ session_id: str | None = None,
262
+ send_trace: bool | None = None,
263
+ span_id: str | None = None,
264
+ tools: list[ToolDefinition] | None = None,
265
+ tool_choice: ToolChoice | None = None,
266
+ trace_id: str | None = None,
267
+ parent_span_id: str | None = None,
268
+ ) -> RunLocalResult: ...
269
+
270
+ @overload
271
+ def run_local(
272
+ self,
273
+ *,
274
+ messages: list[dict[str, Any] | LocalPromptMessage],
275
+ model: str,
276
+ stream: Literal[True],
277
+ provider: LLMProvider | None = None,
278
+ temperature: float | None = None,
279
+ max_output_tokens: int | None = None,
280
+ top_p: float | None = None,
281
+ stop_sequences: list[str] | None = None,
282
+ timeout_ms: int | None = None,
283
+ variables: dict[str, str] | None = None,
284
+ provider_api_key: str | None = None,
285
+ tags: list[str] | None = None,
286
+ user_id: str | None = None,
287
+ session_id: str | None = None,
288
+ send_trace: bool | None = None,
289
+ span_id: str | None = None,
290
+ tools: list[ToolDefinition] | None = None,
291
+ tool_choice: ToolChoice | None = None,
292
+ trace_id: str | None = None,
293
+ parent_span_id: str | None = None,
294
+ ) -> LocalStream: ...
295
+
296
+ def run_local(
297
+ self,
298
+ *,
299
+ messages: list[dict[str, Any] | LocalPromptMessage],
300
+ model: str,
301
+ stream: bool = False,
302
+ provider: LLMProvider | None = None,
303
+ temperature: float | None = None,
304
+ max_output_tokens: int | None = None,
305
+ top_p: float | None = None,
306
+ stop_sequences: list[str] | None = None,
307
+ timeout_ms: int | None = None,
308
+ variables: dict[str, str] | None = None,
309
+ provider_api_key: str | None = None,
310
+ tags: list[str] | None = None,
311
+ user_id: str | None = None,
312
+ session_id: str | None = None,
313
+ send_trace: bool | None = None,
314
+ span_id: str | None = None,
315
+ tools: list[ToolDefinition] | None = None,
316
+ tool_choice: ToolChoice | None = None,
317
+ trace_id: str | None = None,
318
+ parent_span_id: str | None = None,
319
+ ) -> RunLocalResult | LocalStream:
320
+ """Run a local prompt with an LLM.
321
+
322
+ Args:
323
+ messages: The messages to send.
324
+ model: The model name (e.g., "gpt-4o", "claude-sonnet-4-20250514").
325
+ stream: Whether to stream the response.
326
+ provider: The LLM provider (auto-detected if not specified).
327
+ temperature: Sampling temperature.
328
+ max_output_tokens: Maximum output tokens.
329
+ top_p: Top-p sampling parameter.
330
+ stop_sequences: Stop sequences.
331
+ timeout_ms: Request timeout in milliseconds.
332
+ variables: Variables to interpolate into messages.
333
+ provider_api_key: API key for the LLM provider.
334
+ tags: Tags for the span.
335
+ user_id: User ID for the span.
336
+ session_id: Session ID for the span.
337
+ send_trace: Whether to send trace data (default True).
338
+ span_id: Custom span ID.
339
+ tools: Tool definitions for function calling.
340
+ tool_choice: Tool choice setting.
341
+ trace_id: Trace ID for linking spans.
342
+ parent_span_id: Parent span ID for nested spans.
343
+
344
+ Returns:
345
+ RunLocalResult for non-streaming, LocalStream for streaming.
346
+
347
+ Raises:
348
+ TraciaError: If the request fails.
349
+ """
350
+ # Validate input
351
+ self._validate_run_local_input(messages, model, span_id)
352
+
353
+ # Convert and interpolate messages
354
+ prompt_messages = self._convert_messages(messages, variables)
355
+
356
+ # Generate IDs
357
+ effective_span_id = span_id or generate_span_id()
358
+ effective_trace_id = trace_id or generate_trace_id()
359
+ should_send_trace = send_trace is not False
360
+
361
+ # Calculate timeout
362
+ timeout_seconds = (timeout_ms or DEFAULT_TIMEOUT_MS) / 1000.0
363
+
364
+ if stream:
365
+ return self._run_local_streaming(
366
+ messages=prompt_messages,
367
+ model=model,
368
+ provider=provider,
369
+ temperature=temperature,
370
+ max_output_tokens=max_output_tokens,
371
+ top_p=top_p,
372
+ stop_sequences=stop_sequences,
373
+ timeout=timeout_seconds,
374
+ provider_api_key=provider_api_key,
375
+ tags=tags,
376
+ user_id=user_id,
377
+ session_id=session_id,
378
+ send_trace=should_send_trace,
379
+ span_id=effective_span_id,
380
+ trace_id=effective_trace_id,
381
+ parent_span_id=parent_span_id,
382
+ tools=tools,
383
+ tool_choice=tool_choice,
384
+ variables=variables,
385
+ )
386
+ else:
387
+ return self._run_local_non_streaming(
388
+ messages=prompt_messages,
389
+ model=model,
390
+ provider=provider,
391
+ temperature=temperature,
392
+ max_output_tokens=max_output_tokens,
393
+ top_p=top_p,
394
+ stop_sequences=stop_sequences,
395
+ timeout=timeout_seconds,
396
+ provider_api_key=provider_api_key,
397
+ tags=tags,
398
+ user_id=user_id,
399
+ session_id=session_id,
400
+ send_trace=should_send_trace,
401
+ span_id=effective_span_id,
402
+ trace_id=effective_trace_id,
403
+ parent_span_id=parent_span_id,
404
+ tools=tools,
405
+ tool_choice=tool_choice,
406
+ variables=variables,
407
+ )
408
+
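A sketch of a plain (non-streaming) call that exercises the tracing-related parameters; the `{{topic}}` placeholder syntax and the snake_case result attributes (`span_id`, `usage.total_tokens`) are assumptions inferred from the conventions used elsewhere in this file, not confirmed API:

```python
from tracia import Tracia

client = Tracia(api_key="your_api_key")

result = client.run_local(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Summarize {{topic}} in one sentence."}],
    variables={"topic": "vector databases"},  # assumed {{...}} placeholder syntax
    tags=["docs-example"],
    user_id="user_123",
)
print(result.text)
print(result.span_id, result.usage.total_tokens)  # assumed snake_case attributes

client.close()
```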
409
+ def _run_local_non_streaming(
410
+ self,
411
+ *,
412
+ messages: list[LocalPromptMessage],
413
+ model: str,
414
+ provider: LLMProvider | None,
415
+ temperature: float | None,
416
+ max_output_tokens: int | None,
417
+ top_p: float | None,
418
+ stop_sequences: list[str] | None,
419
+ timeout: float,
420
+ provider_api_key: str | None,
421
+ tags: list[str] | None,
422
+ user_id: str | None,
423
+ session_id: str | None,
424
+ send_trace: bool,
425
+ span_id: str,
426
+ trace_id: str,
427
+ parent_span_id: str | None,
428
+ tools: list[ToolDefinition] | None,
429
+ tool_choice: ToolChoice | None,
430
+ variables: dict[str, str] | None,
431
+ ) -> RunLocalResult:
432
+ """Run local prompt without streaming."""
433
+ start_time = time.time()
434
+ error_message: str | None = None
435
+ result_text = ""
436
+ result_tool_calls: list[ToolCall] = []
437
+ result_usage = TokenUsage(inputTokens=0, outputTokens=0, totalTokens=0)
438
+ result_provider = resolve_provider(model, provider)
439
+ finish_reason = "stop"
440
+
441
+ try:
442
+ completion = self._llm_client.complete(
443
+ model=model,
444
+ messages=messages,
445
+ provider=result_provider,
446
+ temperature=temperature,
447
+ max_tokens=max_output_tokens,
448
+ top_p=top_p,
449
+ stop=stop_sequences,
450
+ tools=tools,
451
+ tool_choice=tool_choice,
452
+ api_key=provider_api_key,
453
+ timeout=timeout,
454
+ )
455
+
456
+ result_text = completion.text
457
+ result_tool_calls = completion.tool_calls
458
+ result_provider = completion.provider
459
+ finish_reason = completion.finish_reason
460
+ result_usage = TokenUsage(
461
+ inputTokens=completion.input_tokens,
462
+ outputTokens=completion.output_tokens,
463
+ totalTokens=completion.total_tokens,
464
+ )
465
+
466
+ except TraciaError:
467
+ raise
468
+ except Exception as e:
469
+ error_message = sanitize_error_message(str(e))
470
+ raise TraciaError(
471
+ code=TraciaErrorCode.PROVIDER_ERROR,
472
+ message=f"LLM provider error: {error_message}",
473
+ ) from e
474
+ finally:
475
+ latency_ms = int((time.time() - start_time) * 1000)
476
+
477
+ if send_trace:
478
+ payload = CreateSpanPayload(
479
+ spanId=span_id,
480
+ model=model,
481
+ provider=result_provider,
482
+ input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
483
+ variables=variables,
484
+ output=result_text if not error_message else None,
485
+ status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
486
+ error=error_message,
487
+ latencyMs=latency_ms,
488
+ inputTokens=result_usage.input_tokens,
489
+ outputTokens=result_usage.output_tokens,
490
+ totalTokens=result_usage.total_tokens,
491
+ tags=tags,
492
+ userId=user_id,
493
+ sessionId=session_id,
494
+ temperature=temperature,
495
+ maxOutputTokens=max_output_tokens,
496
+ topP=top_p,
497
+ tools=tools,
498
+ toolCalls=result_tool_calls if result_tool_calls else None,
499
+ traceId=trace_id,
500
+ parentSpanId=parent_span_id,
501
+ )
502
+ self._schedule_span_creation(payload)
503
+
504
+ return RunLocalResult(
505
+ text=result_text,
506
+ spanId=span_id,
507
+ traceId=trace_id,
508
+ latencyMs=latency_ms,
509
+ usage=result_usage,
510
+ cost=None,
511
+ provider=result_provider,
512
+ model=model,
513
+ toolCalls=result_tool_calls,
514
+ finishReason=finish_reason,
515
+ message=build_assistant_message(result_text, result_tool_calls),
516
+ )
517
+
518
+ def _run_local_streaming(
519
+ self,
520
+ *,
521
+ messages: list[LocalPromptMessage],
522
+ model: str,
523
+ provider: LLMProvider | None,
524
+ temperature: float | None,
525
+ max_output_tokens: int | None,
526
+ top_p: float | None,
527
+ stop_sequences: list[str] | None,
528
+ timeout: float,
529
+ provider_api_key: str | None,
530
+ tags: list[str] | None,
531
+ user_id: str | None,
532
+ session_id: str | None,
533
+ send_trace: bool,
534
+ span_id: str,
535
+ trace_id: str,
536
+ parent_span_id: str | None,
537
+ tools: list[ToolDefinition] | None,
538
+ tool_choice: ToolChoice | None,
539
+ variables: dict[str, str] | None,
540
+ ) -> LocalStream:
541
+ """Run local prompt with streaming."""
542
+ start_time = time.time()
543
+ abort_event = Event()
544
+ result_future: Future[StreamResult] = Future()
545
+
546
+ chunks_iter, result_holder, resolved_provider = self._llm_client.stream(
547
+ model=model,
548
+ messages=messages,
549
+ provider=provider,
550
+ temperature=temperature,
551
+ max_tokens=max_output_tokens,
552
+ top_p=top_p,
553
+ stop=stop_sequences,
554
+ tools=tools,
555
+ tool_choice=tool_choice,
556
+ api_key=provider_api_key,
557
+ timeout=timeout,
558
+ )
559
+
560
+ def wrapped_chunks():
561
+ error_message: str | None = None
562
+ aborted = False
563
+
564
+ try:
565
+ for chunk in chunks_iter:
566
+ if abort_event.is_set():
567
+ aborted = True
568
+ break
569
+ yield chunk
570
+ except Exception as e:
571
+ error_message = sanitize_error_message(str(e))
572
+ raise
573
+ finally:
574
+ latency_ms = int((time.time() - start_time) * 1000)
575
+
576
+ # Get completion result
577
+ completion = result_holder[0] if result_holder else None
578
+
579
+ result_text = completion.text if completion else ""
580
+ result_tool_calls = completion.tool_calls if completion else []
581
+ result_usage = TokenUsage(
582
+ inputTokens=completion.input_tokens if completion else 0,
583
+ outputTokens=completion.output_tokens if completion else 0,
584
+ totalTokens=completion.total_tokens if completion else 0,
585
+ )
586
+ finish_reason = completion.finish_reason if completion else "stop"
587
+
588
+ if send_trace:
589
+ payload = CreateSpanPayload(
590
+ spanId=span_id,
591
+ model=model,
592
+ provider=resolved_provider,
593
+ input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
594
+ variables=variables,
595
+ output=result_text if not error_message else None,
596
+ status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
597
+ error=error_message,
598
+ latencyMs=latency_ms,
599
+ inputTokens=result_usage.input_tokens,
600
+ outputTokens=result_usage.output_tokens,
601
+ totalTokens=result_usage.total_tokens,
602
+ tags=tags,
603
+ userId=user_id,
604
+ sessionId=session_id,
605
+ temperature=temperature,
606
+ maxOutputTokens=max_output_tokens,
607
+ topP=top_p,
608
+ tools=tools,
609
+ toolCalls=result_tool_calls if result_tool_calls else None,
610
+ traceId=trace_id,
611
+ parentSpanId=parent_span_id,
612
+ )
613
+ self._schedule_span_creation(payload)
614
+
615
+ # Set the result
616
+ stream_result = StreamResult(
617
+ text=result_text,
618
+ spanId=span_id,
619
+ traceId=trace_id,
620
+ latencyMs=latency_ms,
621
+ usage=result_usage,
622
+ cost=None,
623
+ provider=resolved_provider,
624
+ model=model,
625
+ toolCalls=result_tool_calls,
626
+ finishReason=finish_reason,
627
+ message=build_assistant_message(result_text, result_tool_calls),
628
+ aborted=aborted,
629
+ )
630
+ if not result_future.done():
+ result_future.set_result(stream_result)
631
+
632
+ return LocalStream(
633
+ span_id=span_id,
634
+ trace_id=trace_id,
635
+ chunks=wrapped_chunks(),
636
+ result_holder=result_holder,
637
+ result_future=result_future,
638
+ abort_event=abort_event,
639
+ )
640
+
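The streaming path works by wrapping the provider's chunk iterator so that latency measurement, abort handling, and the span upload all happen in the wrapper's `finally` block, which runs even if the caller abandons the stream early. A stripped-down, self-contained sketch of that pattern (the names here are illustrative, not the SDK's API):

```python
import time
from threading import Event
from typing import Callable, Iterator


def wrap_stream(
    chunks: Iterator[str],
    abort_event: Event,
    on_finish: Callable[[int, bool], None],
) -> Iterator[str]:
    start = time.time()
    aborted = False
    try:
        for chunk in chunks:
            if abort_event.is_set():
                aborted = True
                break
            yield chunk
    finally:
        # Runs whether the stream completed, raised, or was closed early.
        on_finish(int((time.time() - start) * 1000), aborted)
```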
641
+ @overload
642
+ async def arun_local(
643
+ self,
644
+ *,
645
+ messages: list[dict[str, Any] | LocalPromptMessage],
646
+ model: str,
647
+ stream: Literal[False] = False,
648
+ provider: LLMProvider | None = None,
649
+ temperature: float | None = None,
650
+ max_output_tokens: int | None = None,
651
+ top_p: float | None = None,
652
+ stop_sequences: list[str] | None = None,
653
+ timeout_ms: int | None = None,
654
+ variables: dict[str, str] | None = None,
655
+ provider_api_key: str | None = None,
656
+ tags: list[str] | None = None,
657
+ user_id: str | None = None,
658
+ session_id: str | None = None,
659
+ send_trace: bool | None = None,
660
+ span_id: str | None = None,
661
+ tools: list[ToolDefinition] | None = None,
662
+ tool_choice: ToolChoice | None = None,
663
+ trace_id: str | None = None,
664
+ parent_span_id: str | None = None,
665
+ ) -> RunLocalResult: ...
666
+
667
+ @overload
668
+ async def arun_local(
669
+ self,
670
+ *,
671
+ messages: list[dict[str, Any] | LocalPromptMessage],
672
+ model: str,
673
+ stream: Literal[True],
674
+ provider: LLMProvider | None = None,
675
+ temperature: float | None = None,
676
+ max_output_tokens: int | None = None,
677
+ top_p: float | None = None,
678
+ stop_sequences: list[str] | None = None,
679
+ timeout_ms: int | None = None,
680
+ variables: dict[str, str] | None = None,
681
+ provider_api_key: str | None = None,
682
+ tags: list[str] | None = None,
683
+ user_id: str | None = None,
684
+ session_id: str | None = None,
685
+ send_trace: bool | None = None,
686
+ span_id: str | None = None,
687
+ tools: list[ToolDefinition] | None = None,
688
+ tool_choice: ToolChoice | None = None,
689
+ trace_id: str | None = None,
690
+ parent_span_id: str | None = None,
691
+ ) -> AsyncLocalStream: ...
692
+
693
+ async def arun_local(
694
+ self,
695
+ *,
696
+ messages: list[dict[str, Any] | LocalPromptMessage],
697
+ model: str,
698
+ stream: bool = False,
699
+ provider: LLMProvider | None = None,
700
+ temperature: float | None = None,
701
+ max_output_tokens: int | None = None,
702
+ top_p: float | None = None,
703
+ stop_sequences: list[str] | None = None,
704
+ timeout_ms: int | None = None,
705
+ variables: dict[str, str] | None = None,
706
+ provider_api_key: str | None = None,
707
+ tags: list[str] | None = None,
708
+ user_id: str | None = None,
709
+ session_id: str | None = None,
710
+ send_trace: bool | None = None,
711
+ span_id: str | None = None,
712
+ tools: list[ToolDefinition] | None = None,
713
+ tool_choice: ToolChoice | None = None,
714
+ trace_id: str | None = None,
715
+ parent_span_id: str | None = None,
716
+ ) -> RunLocalResult | AsyncLocalStream:
717
+ """Run a local prompt asynchronously.
718
+
719
+ See run_local for parameter documentation; when stream=True this returns an AsyncLocalStream instead of a LocalStream.
720
+ """
721
+ # Validate input
722
+ self._validate_run_local_input(messages, model, span_id)
723
+
724
+ # Convert and interpolate messages
725
+ prompt_messages = self._convert_messages(messages, variables)
726
+
727
+ # Generate IDs
728
+ effective_span_id = span_id or generate_span_id()
729
+ effective_trace_id = trace_id or generate_trace_id()
730
+ should_send_trace = send_trace is not False
731
+
732
+ # Calculate timeout
733
+ timeout_seconds = (timeout_ms or DEFAULT_TIMEOUT_MS) / 1000.0
734
+
735
+ if stream:
736
+ return await self._arun_local_streaming(
737
+ messages=prompt_messages,
738
+ model=model,
739
+ provider=provider,
740
+ temperature=temperature,
741
+ max_output_tokens=max_output_tokens,
742
+ top_p=top_p,
743
+ stop_sequences=stop_sequences,
744
+ timeout=timeout_seconds,
745
+ provider_api_key=provider_api_key,
746
+ tags=tags,
747
+ user_id=user_id,
748
+ session_id=session_id,
749
+ send_trace=should_send_trace,
750
+ span_id=effective_span_id,
751
+ trace_id=effective_trace_id,
752
+ parent_span_id=parent_span_id,
753
+ tools=tools,
754
+ tool_choice=tool_choice,
755
+ variables=variables,
756
+ )
757
+ else:
758
+ return await self._arun_local_non_streaming(
759
+ messages=prompt_messages,
760
+ model=model,
761
+ provider=provider,
762
+ temperature=temperature,
763
+ max_output_tokens=max_output_tokens,
764
+ top_p=top_p,
765
+ stop_sequences=stop_sequences,
766
+ timeout=timeout_seconds,
767
+ provider_api_key=provider_api_key,
768
+ tags=tags,
769
+ user_id=user_id,
770
+ session_id=session_id,
771
+ send_trace=should_send_trace,
772
+ span_id=effective_span_id,
773
+ trace_id=effective_trace_id,
774
+ parent_span_id=parent_span_id,
775
+ tools=tools,
776
+ tool_choice=tool_choice,
777
+ variables=variables,
778
+ )
779
+
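`arun_local` mirrors `run_local` for asyncio callers; a minimal non-streaming sketch that closes the client with `aclose()`:

```python
import asyncio

from tracia import Tracia


async def main() -> None:
    client = Tracia(api_key="your_api_key")
    try:
        result = await client.arun_local(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello!"}],
        )
        print(result.text)
    finally:
        await client.aclose()


asyncio.run(main())
```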
780
+ async def _arun_local_non_streaming(
781
+ self,
782
+ *,
783
+ messages: list[LocalPromptMessage],
784
+ model: str,
785
+ provider: LLMProvider | None,
786
+ temperature: float | None,
787
+ max_output_tokens: int | None,
788
+ top_p: float | None,
789
+ stop_sequences: list[str] | None,
790
+ timeout: float,
791
+ provider_api_key: str | None,
792
+ tags: list[str] | None,
793
+ user_id: str | None,
794
+ session_id: str | None,
795
+ send_trace: bool,
796
+ span_id: str,
797
+ trace_id: str,
798
+ parent_span_id: str | None,
799
+ tools: list[ToolDefinition] | None,
800
+ tool_choice: ToolChoice | None,
801
+ variables: dict[str, str] | None,
802
+ ) -> RunLocalResult:
803
+ """Run local prompt without streaming (async)."""
804
+ start_time = time.time()
805
+ error_message: str | None = None
806
+ result_text = ""
807
+ result_tool_calls: list[ToolCall] = []
808
+ result_usage = TokenUsage(inputTokens=0, outputTokens=0, totalTokens=0)
809
+ result_provider = resolve_provider(model, provider)
810
+ finish_reason = "stop"
811
+ latency_ms = 0
812
+
813
+ try:
814
+ completion = await self._llm_client.acomplete(
815
+ model=model,
816
+ messages=messages,
817
+ provider=result_provider,
818
+ temperature=temperature,
819
+ max_tokens=max_output_tokens,
820
+ top_p=top_p,
821
+ stop=stop_sequences,
822
+ tools=tools,
823
+ tool_choice=tool_choice,
824
+ api_key=provider_api_key,
825
+ timeout=timeout,
826
+ )
827
+
828
+ result_text = completion.text
829
+ result_tool_calls = completion.tool_calls
830
+ result_provider = completion.provider
831
+ finish_reason = completion.finish_reason
832
+ result_usage = TokenUsage(
833
+ inputTokens=completion.input_tokens,
834
+ outputTokens=completion.output_tokens,
835
+ totalTokens=completion.total_tokens,
836
+ )
837
+
838
+ except TraciaError:
839
+ raise
840
+ except Exception as e:
841
+ error_message = sanitize_error_message(str(e))
842
+ raise TraciaError(
843
+ code=TraciaErrorCode.PROVIDER_ERROR,
844
+ message=f"LLM provider error: {error_message}",
845
+ ) from e
846
+ finally:
847
+ latency_ms = int((time.time() - start_time) * 1000)
848
+
849
+ if send_trace:
850
+ payload = CreateSpanPayload(
851
+ spanId=span_id,
852
+ model=model,
853
+ provider=result_provider,
854
+ input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
855
+ variables=variables,
856
+ output=result_text if not error_message else None,
857
+ status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
858
+ error=error_message,
859
+ latencyMs=latency_ms,
860
+ inputTokens=result_usage.input_tokens,
861
+ outputTokens=result_usage.output_tokens,
862
+ totalTokens=result_usage.total_tokens,
863
+ tags=tags,
864
+ userId=user_id,
865
+ sessionId=session_id,
866
+ temperature=temperature,
867
+ maxOutputTokens=max_output_tokens,
868
+ topP=top_p,
869
+ tools=tools,
870
+ toolCalls=result_tool_calls if result_tool_calls else None,
871
+ traceId=trace_id,
872
+ parentSpanId=parent_span_id,
873
+ )
874
+ self._schedule_span_creation(payload)
875
+
876
+ return RunLocalResult(
877
+ text=result_text,
878
+ spanId=span_id,
879
+ traceId=trace_id,
880
+ latencyMs=latency_ms,
881
+ usage=result_usage,
882
+ cost=None,
883
+ provider=result_provider,
884
+ model=model,
885
+ toolCalls=result_tool_calls,
886
+ finishReason=finish_reason,
887
+ message=build_assistant_message(result_text, result_tool_calls),
888
+ )
889
+
890
+ async def _arun_local_streaming(
891
+ self,
892
+ *,
893
+ messages: list[LocalPromptMessage],
894
+ model: str,
895
+ provider: LLMProvider | None,
896
+ temperature: float | None,
897
+ max_output_tokens: int | None,
898
+ top_p: float | None,
899
+ stop_sequences: list[str] | None,
900
+ timeout: float,
901
+ provider_api_key: str | None,
902
+ tags: list[str] | None,
903
+ user_id: str | None,
904
+ session_id: str | None,
905
+ send_trace: bool,
906
+ span_id: str,
907
+ trace_id: str,
908
+ parent_span_id: str | None,
909
+ tools: list[ToolDefinition] | None,
910
+ tool_choice: ToolChoice | None,
911
+ variables: dict[str, str] | None,
912
+ ) -> AsyncLocalStream:
913
+ """Run local prompt with streaming (async)."""
914
+ start_time = time.time()
915
+ abort_event = asyncio.Event()
916
+ loop = asyncio.get_running_loop()
917
+ result_future: asyncio.Future[StreamResult] = loop.create_future()
918
+
919
+ chunks_iter, result_holder, resolved_provider = await self._llm_client.astream(
920
+ model=model,
921
+ messages=messages,
922
+ provider=provider,
923
+ temperature=temperature,
924
+ max_tokens=max_output_tokens,
925
+ top_p=top_p,
926
+ stop=stop_sequences,
927
+ tools=tools,
928
+ tool_choice=tool_choice,
929
+ api_key=provider_api_key,
930
+ timeout=timeout,
931
+ )
932
+
933
+ async def wrapped_chunks():
934
+ error_message: str | None = None
935
+ aborted = False
936
+
937
+ try:
938
+ async for chunk in chunks_iter:
939
+ if abort_event.is_set():
940
+ aborted = True
941
+ break
942
+ yield chunk
943
+ except Exception as e:
944
+ error_message = sanitize_error_message(str(e))
945
+ raise
946
+ finally:
947
+ latency_ms = int((time.time() - start_time) * 1000)
948
+
949
+ # Get completion result
950
+ completion = result_holder[0] if result_holder else None
951
+
952
+ result_text = completion.text if completion else ""
953
+ result_tool_calls = completion.tool_calls if completion else []
954
+ result_usage = TokenUsage(
955
+ inputTokens=completion.input_tokens if completion else 0,
956
+ outputTokens=completion.output_tokens if completion else 0,
957
+ totalTokens=completion.total_tokens if completion else 0,
958
+ )
959
+ finish_reason = completion.finish_reason if completion else "stop"
960
+
961
+ if send_trace:
962
+ payload = CreateSpanPayload(
963
+ spanId=span_id,
964
+ model=model,
965
+ provider=resolved_provider,
966
+ input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
967
+ variables=variables,
968
+ output=result_text if not error_message else None,
969
+ status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
970
+ error=error_message,
971
+ latencyMs=latency_ms,
972
+ inputTokens=result_usage.input_tokens,
973
+ outputTokens=result_usage.output_tokens,
974
+ totalTokens=result_usage.total_tokens,
975
+ tags=tags,
976
+ userId=user_id,
977
+ sessionId=session_id,
978
+ temperature=temperature,
979
+ maxOutputTokens=max_output_tokens,
980
+ topP=top_p,
981
+ tools=tools,
982
+ toolCalls=result_tool_calls if result_tool_calls else None,
983
+ traceId=trace_id,
984
+ parentSpanId=parent_span_id,
985
+ )
986
+ self._schedule_span_creation(payload)
987
+
988
+ # Set the result
989
+ stream_result = StreamResult(
990
+ text=result_text,
991
+ spanId=span_id,
992
+ traceId=trace_id,
993
+ latencyMs=latency_ms,
994
+ usage=result_usage,
995
+ cost=None,
996
+ provider=resolved_provider,
997
+ model=model,
998
+ toolCalls=result_tool_calls,
999
+ finishReason=finish_reason,
1000
+ message=build_assistant_message(result_text, result_tool_calls),
1001
+ aborted=aborted,
1002
+ )
1003
+ if not result_future.done():
1004
+ result_future.set_result(stream_result)
1005
+
1006
+ return AsyncLocalStream(
1007
+ span_id=span_id,
1008
+ trace_id=trace_id,
1009
+ chunks=wrapped_chunks(),
1010
+ result_holder=result_holder,
1011
+ result_future=result_future,
1012
+ abort_event=abort_event,
1013
+ )
1014
+
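With `stream=True`, `arun_local` returns an `AsyncLocalStream` built around the async generator above. A sketch of consuming it; that the stream object is itself async-iterable mirrors the sync `LocalStream` usage in the class docstring and is an assumption:

```python
import asyncio

from tracia import Tracia


async def main() -> None:
    async with Tracia(api_key="your_api_key") as client:
        stream = await client.arun_local(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Tell me a story"}],
            stream=True,
        )
        async for chunk in stream:  # assumes AsyncLocalStream is async-iterable
            print(chunk, end="")


asyncio.run(main())
```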
1015
+ def create_session(
1016
+ self,
1017
+ *,
1018
+ trace_id: str | None = None,
1019
+ parent_span_id: str | None = None,
1020
+ ) -> TraciaSession:
1021
+ """Create a new session for linked runs.
1022
+
1023
+ Args:
1024
+ trace_id: Optional initial trace ID.
1025
+ parent_span_id: Optional initial parent span ID.
1026
+
1027
+ Returns:
1028
+ A new TraciaSession instance.
1029
+ """
1030
+ return TraciaSession(self, trace_id, parent_span_id)
1031
+
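Besides standalone sessions, `create_session` accepts an existing trace so that a session's runs are grouped under a span created elsewhere. A sketch, with the `RunLocalResult` attribute names (`trace_id`, `span_id`) assumed from the snake_case convention used for `TokenUsage` in this file:

```python
from tracia import Tracia

client = Tracia(api_key="your_api_key")

parent = client.run_local(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Plan a weekend trip"}],
)

# Runs in this session are recorded under the parent's trace/span
# (as implied by the trace_id / parent_span_id parameters).
session = client.create_session(
    trace_id=parent.trace_id,       # assumed snake_case attribute
    parent_span_id=parent.span_id,  # assumed snake_case attribute
)
followup = session.run_local(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Now add a budget"}],
)
print(followup.text)

client.close()
```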
1032
+ def flush(self) -> None:
1033
+ """Wait for all pending span creations to complete."""
1034
+ with self._pending_spans_lock:
1035
+ futures = list(self._pending_spans.values())
1036
+
1037
+ for future in futures:
1038
+ try:
1039
+ future.result(timeout=30.0)
1040
+ except Exception:
1041
+ pass
1042
+
1043
+ async def aflush(self) -> None:
1044
+ """Wait for all pending span creations to complete (async)."""
1045
+ # Delegate to the blocking flush() in a worker thread so the event loop stays responsive
1046
+ loop = asyncio.get_running_loop()
1047
+ await loop.run_in_executor(None, self.flush)
1048
+
1049
+ def close(self) -> None:
1050
+ """Close the client and release resources."""
1051
+ if self._closed:
1052
+ return
1053
+ self._closed = True
1054
+ self.flush()
1055
+ self._http_client.close()
1056
+ if self._async_http_client:
1057
+ try:
1058
+ loop = asyncio.get_running_loop()
1059
+ loop.create_task(self._async_http_client.aclose())
1060
+ except RuntimeError:
1061
+ # No running loop - run synchronously
1062
+ try:
1063
+ asyncio.run(self._async_http_client.aclose())
1064
+ except Exception:
1065
+ pass
1066
+ self._executor.shutdown(wait=False)
1067
+
1068
+ def __del__(self) -> None:
1069
+ try:
1070
+ if not self._closed:
1071
+ warnings.warn(
1072
+ "Unclosed Tracia client. Use 'client.close()' or 'with Tracia(...) as client:'.",
1073
+ ResourceWarning,
1074
+ stacklevel=1,
1075
+ )
1076
+ except Exception:
1077
+ pass
1078
+
1079
+ async def aclose(self) -> None:
1080
+ """Close the client and release resources (async)."""
1081
+ if self._closed:
1082
+ return
1083
+ self._closed = True
1084
+ await self.aflush()
1085
+ self._http_client.close()
1086
+ if self._async_http_client:
1087
+ await self._async_http_client.aclose()
1088
+ self._executor.shutdown(wait=False)
1089
+
1090
+ def __enter__(self) -> "Tracia":
1091
+ return self
1092
+
1093
+ def __exit__(self, *args: Any) -> None:
1094
+ self.close()
1095
+
1096
+ async def __aenter__(self) -> "Tracia":
1097
+ return self
1098
+
1099
+ async def __aexit__(self, *args: Any) -> None:
1100
+ await self.aclose()
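The context-manager protocol defined above ties cleanup to scope, and `flush()` can be called explicitly if the process might exit before background span uploads complete. A minimal sketch:

```python
from tracia import Tracia

with Tracia(api_key="your_api_key") as client:
    result = client.run_local(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(result.text)
    # Optional: block until pending background span uploads have finished.
    client.flush()
# __exit__ calls close(), which flushes again and releases HTTP resources.
```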