tracia 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracia/__init__.py +152 -3
- tracia/_client.py +1100 -0
- tracia/_constants.py +39 -0
- tracia/_errors.py +87 -0
- tracia/_http.py +362 -0
- tracia/_llm.py +898 -0
- tracia/_session.py +244 -0
- tracia/_streaming.py +135 -0
- tracia/_types.py +564 -0
- tracia/_utils.py +116 -0
- tracia/py.typed +0 -0
- tracia/resources/__init__.py +6 -0
- tracia/resources/prompts.py +273 -0
- tracia/resources/spans.py +227 -0
- tracia-0.1.1.dist-info/METADATA +277 -0
- tracia-0.1.1.dist-info/RECORD +18 -0
- tracia-0.0.1.dist-info/METADATA +0 -52
- tracia-0.0.1.dist-info/RECORD +0 -5
- {tracia-0.0.1.dist-info → tracia-0.1.1.dist-info}/WHEEL +0 -0
- {tracia-0.0.1.dist-info → tracia-0.1.1.dist-info}/licenses/LICENSE +0 -0
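The bulk of the release is the new `tracia/_client.py`, which adds a `Tracia` client with synchronous and asynchronous `run_local` entry points, background span reporting, and context-manager cleanup. Below is a minimal sketch of the documented synchronous usage, based on the class docstring and method signatures shown in the diff that follows; the API key and model name are placeholders.

```python
# Minimal sketch of the 0.1.1 sync API, mirroring the Tracia class docstring
# in tracia/_client.py below. "your_api_key" and "gpt-4o" are placeholders.
from tracia import Tracia

# The context-manager form calls close() on exit (__enter__/__exit__ are defined).
with Tracia(api_key="your_api_key") as client:
    # Non-streaming call: returns a RunLocalResult; .text holds the completion.
    result = client.run_local(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(result.text)

    # Streaming call: stream=True returns a LocalStream that yields text chunks.
    stream = client.run_local(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Tell me a story"}],
        stream=True,
    )
    for chunk in stream:
        print(chunk, end="")
```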
tracia/_client.py
ADDED
@@ -0,0 +1,1100 @@
"""Main Tracia client implementation."""

from __future__ import annotations

import asyncio
import time
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from typing import Any, Callable, Literal, overload

from ._constants import (
    BASE_URL,
    DEFAULT_TIMEOUT_MS,
    MAX_PENDING_SPANS,
    SPAN_RETRY_ATTEMPTS,
    SPAN_RETRY_DELAY_MS,
    SPAN_STATUS_ERROR,
    SPAN_STATUS_SUCCESS,
)
from ._errors import TraciaError, TraciaErrorCode, sanitize_error_message
from ._http import AsyncHttpClient, HttpClient
from ._llm import LLMClient, build_assistant_message, resolve_provider
from ._session import TraciaSession
from ._streaming import AsyncLocalStream, LocalStream
from ._types import (
    CreateSpanPayload,
    LocalPromptMessage,
    LLMProvider,
    RunLocalResult,
    StreamResult,
    TokenUsage,
    ToolCall,
    ToolDefinition,
    ToolChoice,
)
from ._utils import (
    generate_span_id,
    generate_trace_id,
    interpolate_message_content,
    is_valid_span_id_format,
)
from .resources import Prompts, Spans

class Tracia:
    """Main Tracia client for LLM prompt management and tracing.

    Example usage:
        ```python
        from tracia import Tracia

        client = Tracia(api_key="your_api_key")

        # Run a local prompt
        result = client.run_local(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello!"}]
        )
        print(result.text)

        # Run with streaming
        stream = client.run_local(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Tell me a story"}],
            stream=True
        )
        for chunk in stream:
            print(chunk, end="")

        # Create a session for multi-turn conversations
        session = client.create_session()
        r1 = session.run_local(model="gpt-4o", messages=[...])
        r2 = session.run_local(model="gpt-4o", messages=[...])  # Linked

        # Clean up
        client.close()
        ```
    """

    def __init__(
        self,
        api_key: str,
        *,
        base_url: str = BASE_URL,
        on_span_error: Callable[[Exception, str], None] | None = None,
    ) -> None:
        """Initialize the Tracia client.

        Args:
            api_key: Your Tracia API key.
            base_url: The base URL for the Tracia API.
            on_span_error: Optional callback for span creation errors.
        """
        self._api_key = api_key
        self._base_url = base_url
        self._on_span_error = on_span_error

        # HTTP clients
        self._http_client = HttpClient(api_key, base_url)
        self._async_http_client: AsyncHttpClient | None = None

        # LLM client
        self._llm_client = LLMClient()

        # Resources
        self.prompts = Prompts(self._http_client)
        self.spans = Spans(self._http_client)

        # Pending spans management
        self._pending_spans: dict[str, Future[None]] = {}
        self._pending_spans_lock = Lock()
        self._executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="tracia-span-")
        self._closed = False

    def _get_async_http_client(self) -> AsyncHttpClient:
        """Get or create the async HTTP client."""
        if self._async_http_client is None:
            self._async_http_client = AsyncHttpClient(self._api_key, self._base_url)
            # Update resources with async client
            self.prompts._async_client = self._async_http_client
            self.spans._async_client = self._async_http_client
        return self._async_http_client

    def _validate_run_local_input(
        self,
        messages: list[dict[str, Any] | LocalPromptMessage],
        model: str,
        span_id: str | None,
    ) -> None:
        """Validate run_local input parameters."""
        if not model or not model.strip():
            raise TraciaError(
                code=TraciaErrorCode.INVALID_REQUEST,
                message="model is required and cannot be empty",
            )

        if not messages or len(messages) == 0:
            raise TraciaError(
                code=TraciaErrorCode.INVALID_REQUEST,
                message="messages is required and cannot be empty",
            )

        if span_id is not None and not is_valid_span_id_format(span_id):
            raise TraciaError(
                code=TraciaErrorCode.INVALID_REQUEST,
                message=f"Invalid span_id format: {span_id}. Expected sp_XXXXXXXXXXXXXXXX or tr_XXXXXXXXXXXXXXXX",
            )

        # Validate tool messages have tool_call_id
        for msg in messages:
            if isinstance(msg, dict):
                if msg.get("role") == "tool" and not msg.get("tool_call_id") and not msg.get("toolCallId"):
                    raise TraciaError(
                        code=TraciaErrorCode.INVALID_REQUEST,
                        message="Tool messages must have tool_call_id",
                    )
            elif isinstance(msg, LocalPromptMessage):
                if msg.role == "tool" and not msg.tool_call_id:
                    raise TraciaError(
                        code=TraciaErrorCode.INVALID_REQUEST,
                        message="Tool messages must have tool_call_id",
                    )

    def _convert_messages(
        self,
        messages: list[dict[str, Any] | LocalPromptMessage],
        variables: dict[str, str] | None,
    ) -> list[LocalPromptMessage]:
        """Convert dict messages to LocalPromptMessage and interpolate variables."""
        result: list[LocalPromptMessage] = []

        for msg in messages:
            if isinstance(msg, LocalPromptMessage):
                pm = msg
            else:
                pm = LocalPromptMessage(
                    role=msg["role"],
                    content=msg.get("content", ""),
                    tool_call_id=msg.get("tool_call_id") or msg.get("toolCallId"),
                    tool_name=msg.get("tool_name") or msg.get("toolName"),
                )

            # Interpolate variables (skip tool messages)
            if pm.role != "tool" and variables:
                pm = LocalPromptMessage(
                    role=pm.role,
                    content=interpolate_message_content(pm.content, variables),
                    tool_call_id=pm.tool_call_id,
                    tool_name=pm.tool_name,
                )

            result.append(pm)

        return result

    def _schedule_span_creation(
        self,
        payload: CreateSpanPayload,
    ) -> None:
        """Schedule span creation in the background with retry logic."""
        span_id = payload.span_id

        def create_span_with_retry() -> None:
            last_error: Exception | None = None

            for attempt in range(SPAN_RETRY_ATTEMPTS + 1):
                try:
                    self._http_client.post(
                        "/api/v1/spans",
                        payload.model_dump(by_alias=True, exclude_none=True),
                    )
                    return
                except Exception as e:
                    last_error = e
                    if attempt < SPAN_RETRY_ATTEMPTS:
                        delay = SPAN_RETRY_DELAY_MS * (attempt + 1) / 1000.0
                        time.sleep(delay)

            # All retries failed
            if self._on_span_error and last_error:
                try:
                    self._on_span_error(last_error, span_id)
                except Exception:
                    pass

        # Evict old spans if at capacity
        with self._pending_spans_lock:
            if len(self._pending_spans) >= MAX_PENDING_SPANS:
                # Remove the oldest span
                oldest_key = next(iter(self._pending_spans))
                del self._pending_spans[oldest_key]

            future = self._executor.submit(create_span_with_retry)
            self._pending_spans[span_id] = future

        # Clean up when done
        def on_done(f: Future[None]) -> None:
            with self._pending_spans_lock:
                self._pending_spans.pop(span_id, None)

        future.add_done_callback(on_done)

    @overload
    def run_local(
        self,
        *,
        messages: list[dict[str, Any] | LocalPromptMessage],
        model: str,
        stream: Literal[False] = False,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        top_p: float | None = None,
        stop_sequences: list[str] | None = None,
        timeout_ms: int | None = None,
        variables: dict[str, str] | None = None,
        provider_api_key: str | None = None,
        tags: list[str] | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        send_trace: bool | None = None,
        span_id: str | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> RunLocalResult: ...

    @overload
    def run_local(
        self,
        *,
        messages: list[dict[str, Any] | LocalPromptMessage],
        model: str,
        stream: Literal[True],
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        top_p: float | None = None,
        stop_sequences: list[str] | None = None,
        timeout_ms: int | None = None,
        variables: dict[str, str] | None = None,
        provider_api_key: str | None = None,
        tags: list[str] | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        send_trace: bool | None = None,
        span_id: str | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> LocalStream: ...

    def run_local(
        self,
        *,
        messages: list[dict[str, Any] | LocalPromptMessage],
        model: str,
        stream: bool = False,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        top_p: float | None = None,
        stop_sequences: list[str] | None = None,
        timeout_ms: int | None = None,
        variables: dict[str, str] | None = None,
        provider_api_key: str | None = None,
        tags: list[str] | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        send_trace: bool | None = None,
        span_id: str | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> RunLocalResult | LocalStream:
        """Run a local prompt with an LLM.

        Args:
            messages: The messages to send.
            model: The model name (e.g., "gpt-4o", "claude-sonnet-4-20250514").
            stream: Whether to stream the response.
            provider: The LLM provider (auto-detected if not specified).
            temperature: Sampling temperature.
            max_output_tokens: Maximum output tokens.
            top_p: Top-p sampling parameter.
            stop_sequences: Stop sequences.
            timeout_ms: Request timeout in milliseconds.
            variables: Variables to interpolate into messages.
            provider_api_key: API key for the LLM provider.
            tags: Tags for the span.
            user_id: User ID for the span.
            session_id: Session ID for the span.
            send_trace: Whether to send trace data (default True).
            span_id: Custom span ID.
            tools: Tool definitions for function calling.
            tool_choice: Tool choice setting.
            trace_id: Trace ID for linking spans.
            parent_span_id: Parent span ID for nested spans.

        Returns:
            RunLocalResult for non-streaming, LocalStream for streaming.

        Raises:
            TraciaError: If the request fails.
        """
        # Validate input
        self._validate_run_local_input(messages, model, span_id)

        # Convert and interpolate messages
        prompt_messages = self._convert_messages(messages, variables)

        # Generate IDs
        effective_span_id = span_id or generate_span_id()
        effective_trace_id = trace_id or generate_trace_id()
        should_send_trace = send_trace is not False

        # Calculate timeout
        timeout_seconds = (timeout_ms or DEFAULT_TIMEOUT_MS) / 1000.0

        if stream:
            return self._run_local_streaming(
                messages=prompt_messages,
                model=model,
                provider=provider,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                top_p=top_p,
                stop_sequences=stop_sequences,
                timeout=timeout_seconds,
                provider_api_key=provider_api_key,
                tags=tags,
                user_id=user_id,
                session_id=session_id,
                send_trace=should_send_trace,
                span_id=effective_span_id,
                trace_id=effective_trace_id,
                parent_span_id=parent_span_id,
                tools=tools,
                tool_choice=tool_choice,
                variables=variables,
            )
        else:
            return self._run_local_non_streaming(
                messages=prompt_messages,
                model=model,
                provider=provider,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                top_p=top_p,
                stop_sequences=stop_sequences,
                timeout=timeout_seconds,
                provider_api_key=provider_api_key,
                tags=tags,
                user_id=user_id,
                session_id=session_id,
                send_trace=should_send_trace,
                span_id=effective_span_id,
                trace_id=effective_trace_id,
                parent_span_id=parent_span_id,
                tools=tools,
                tool_choice=tool_choice,
                variables=variables,
            )

    def _run_local_non_streaming(
        self,
        *,
        messages: list[LocalPromptMessage],
        model: str,
        provider: LLMProvider | None,
        temperature: float | None,
        max_output_tokens: int | None,
        top_p: float | None,
        stop_sequences: list[str] | None,
        timeout: float,
        provider_api_key: str | None,
        tags: list[str] | None,
        user_id: str | None,
        session_id: str | None,
        send_trace: bool,
        span_id: str,
        trace_id: str,
        parent_span_id: str | None,
        tools: list[ToolDefinition] | None,
        tool_choice: ToolChoice | None,
        variables: dict[str, str] | None,
    ) -> RunLocalResult:
        """Run local prompt without streaming."""
        start_time = time.time()
        error_message: str | None = None
        result_text = ""
        result_tool_calls: list[ToolCall] = []
        result_usage = TokenUsage(inputTokens=0, outputTokens=0, totalTokens=0)
        result_provider = resolve_provider(model, provider)
        finish_reason = "stop"

        try:
            completion = self._llm_client.complete(
                model=model,
                messages=messages,
                provider=result_provider,
                temperature=temperature,
                max_tokens=max_output_tokens,
                top_p=top_p,
                stop=stop_sequences,
                tools=tools,
                tool_choice=tool_choice,
                api_key=provider_api_key,
                timeout=timeout,
            )

            result_text = completion.text
            result_tool_calls = completion.tool_calls
            result_provider = completion.provider
            finish_reason = completion.finish_reason
            result_usage = TokenUsage(
                inputTokens=completion.input_tokens,
                outputTokens=completion.output_tokens,
                totalTokens=completion.total_tokens,
            )

        except TraciaError:
            raise
        except Exception as e:
            error_message = sanitize_error_message(str(e))
            raise TraciaError(
                code=TraciaErrorCode.PROVIDER_ERROR,
                message=f"LLM provider error: {error_message}",
            ) from e
        finally:
            latency_ms = int((time.time() - start_time) * 1000)

            if send_trace:
                payload = CreateSpanPayload(
                    spanId=span_id,
                    model=model,
                    provider=result_provider,
                    input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
                    variables=variables,
                    output=result_text if not error_message else None,
                    status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
                    error=error_message,
                    latencyMs=latency_ms,
                    inputTokens=result_usage.input_tokens,
                    outputTokens=result_usage.output_tokens,
                    totalTokens=result_usage.total_tokens,
                    tags=tags,
                    userId=user_id,
                    sessionId=session_id,
                    temperature=temperature,
                    maxOutputTokens=max_output_tokens,
                    topP=top_p,
                    tools=tools,
                    toolCalls=result_tool_calls if result_tool_calls else None,
                    traceId=trace_id,
                    parentSpanId=parent_span_id,
                )
                self._schedule_span_creation(payload)

        return RunLocalResult(
            text=result_text,
            spanId=span_id,
            traceId=trace_id,
            latencyMs=latency_ms,
            usage=result_usage,
            cost=None,
            provider=result_provider,
            model=model,
            toolCalls=result_tool_calls,
            finishReason=finish_reason,
            message=build_assistant_message(result_text, result_tool_calls),
        )

    def _run_local_streaming(
        self,
        *,
        messages: list[LocalPromptMessage],
        model: str,
        provider: LLMProvider | None,
        temperature: float | None,
        max_output_tokens: int | None,
        top_p: float | None,
        stop_sequences: list[str] | None,
        timeout: float,
        provider_api_key: str | None,
        tags: list[str] | None,
        user_id: str | None,
        session_id: str | None,
        send_trace: bool,
        span_id: str,
        trace_id: str,
        parent_span_id: str | None,
        tools: list[ToolDefinition] | None,
        tool_choice: ToolChoice | None,
        variables: dict[str, str] | None,
    ) -> LocalStream:
        """Run local prompt with streaming."""
        start_time = time.time()
        abort_event = Event()
        result_future: Future[StreamResult] = Future()

        chunks_iter, result_holder, resolved_provider = self._llm_client.stream(
            model=model,
            messages=messages,
            provider=provider,
            temperature=temperature,
            max_tokens=max_output_tokens,
            top_p=top_p,
            stop=stop_sequences,
            tools=tools,
            tool_choice=tool_choice,
            api_key=provider_api_key,
            timeout=timeout,
        )

        def wrapped_chunks():
            error_message: str | None = None
            aborted = False

            try:
                for chunk in chunks_iter:
                    if abort_event.is_set():
                        aborted = True
                        break
                    yield chunk
            except Exception as e:
                error_message = sanitize_error_message(str(e))
                raise
            finally:
                latency_ms = int((time.time() - start_time) * 1000)

                # Get completion result
                completion = result_holder[0] if result_holder else None

                result_text = completion.text if completion else ""
                result_tool_calls = completion.tool_calls if completion else []
                result_usage = TokenUsage(
                    inputTokens=completion.input_tokens if completion else 0,
                    outputTokens=completion.output_tokens if completion else 0,
                    totalTokens=completion.total_tokens if completion else 0,
                )
                finish_reason = completion.finish_reason if completion else "stop"

                if send_trace:
                    payload = CreateSpanPayload(
                        spanId=span_id,
                        model=model,
                        provider=resolved_provider,
                        input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
                        variables=variables,
                        output=result_text if not error_message else None,
                        status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
                        error=error_message,
                        latencyMs=latency_ms,
                        inputTokens=result_usage.input_tokens,
                        outputTokens=result_usage.output_tokens,
                        totalTokens=result_usage.total_tokens,
                        tags=tags,
                        userId=user_id,
                        sessionId=session_id,
                        temperature=temperature,
                        maxOutputTokens=max_output_tokens,
                        topP=top_p,
                        tools=tools,
                        toolCalls=result_tool_calls if result_tool_calls else None,
                        traceId=trace_id,
                        parentSpanId=parent_span_id,
                    )
                    self._schedule_span_creation(payload)

                # Set the result
                stream_result = StreamResult(
                    text=result_text,
                    spanId=span_id,
                    traceId=trace_id,
                    latencyMs=latency_ms,
                    usage=result_usage,
                    cost=None,
                    provider=resolved_provider,
                    model=model,
                    toolCalls=result_tool_calls,
                    finishReason=finish_reason,
                    message=build_assistant_message(result_text, result_tool_calls),
                    aborted=aborted,
                )
                result_future.set_result(stream_result)

        return LocalStream(
            span_id=span_id,
            trace_id=trace_id,
            chunks=wrapped_chunks(),
            result_holder=result_holder,
            result_future=result_future,
            abort_event=abort_event,
        )

    @overload
    async def arun_local(
        self,
        *,
        messages: list[dict[str, Any] | LocalPromptMessage],
        model: str,
        stream: Literal[False] = False,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        top_p: float | None = None,
        stop_sequences: list[str] | None = None,
        timeout_ms: int | None = None,
        variables: dict[str, str] | None = None,
        provider_api_key: str | None = None,
        tags: list[str] | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        send_trace: bool | None = None,
        span_id: str | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> RunLocalResult: ...

    @overload
    async def arun_local(
        self,
        *,
        messages: list[dict[str, Any] | LocalPromptMessage],
        model: str,
        stream: Literal[True],
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        top_p: float | None = None,
        stop_sequences: list[str] | None = None,
        timeout_ms: int | None = None,
        variables: dict[str, str] | None = None,
        provider_api_key: str | None = None,
        tags: list[str] | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        send_trace: bool | None = None,
        span_id: str | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> AsyncLocalStream: ...

    async def arun_local(
        self,
        *,
        messages: list[dict[str, Any] | LocalPromptMessage],
        model: str,
        stream: bool = False,
        provider: LLMProvider | None = None,
        temperature: float | None = None,
        max_output_tokens: int | None = None,
        top_p: float | None = None,
        stop_sequences: list[str] | None = None,
        timeout_ms: int | None = None,
        variables: dict[str, str] | None = None,
        provider_api_key: str | None = None,
        tags: list[str] | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        send_trace: bool | None = None,
        span_id: str | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> RunLocalResult | AsyncLocalStream:
        """Run a local prompt asynchronously.

        See run_local for parameter documentation.
        """
        # Validate input
        self._validate_run_local_input(messages, model, span_id)

        # Convert and interpolate messages
        prompt_messages = self._convert_messages(messages, variables)

        # Generate IDs
        effective_span_id = span_id or generate_span_id()
        effective_trace_id = trace_id or generate_trace_id()
        should_send_trace = send_trace is not False

        # Calculate timeout
        timeout_seconds = (timeout_ms or DEFAULT_TIMEOUT_MS) / 1000.0

        if stream:
            return await self._arun_local_streaming(
                messages=prompt_messages,
                model=model,
                provider=provider,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                top_p=top_p,
                stop_sequences=stop_sequences,
                timeout=timeout_seconds,
                provider_api_key=provider_api_key,
                tags=tags,
                user_id=user_id,
                session_id=session_id,
                send_trace=should_send_trace,
                span_id=effective_span_id,
                trace_id=effective_trace_id,
                parent_span_id=parent_span_id,
                tools=tools,
                tool_choice=tool_choice,
                variables=variables,
            )
        else:
            return await self._arun_local_non_streaming(
                messages=prompt_messages,
                model=model,
                provider=provider,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                top_p=top_p,
                stop_sequences=stop_sequences,
                timeout=timeout_seconds,
                provider_api_key=provider_api_key,
                tags=tags,
                user_id=user_id,
                session_id=session_id,
                send_trace=should_send_trace,
                span_id=effective_span_id,
                trace_id=effective_trace_id,
                parent_span_id=parent_span_id,
                tools=tools,
                tool_choice=tool_choice,
                variables=variables,
            )

    async def _arun_local_non_streaming(
        self,
        *,
        messages: list[LocalPromptMessage],
        model: str,
        provider: LLMProvider | None,
        temperature: float | None,
        max_output_tokens: int | None,
        top_p: float | None,
        stop_sequences: list[str] | None,
        timeout: float,
        provider_api_key: str | None,
        tags: list[str] | None,
        user_id: str | None,
        session_id: str | None,
        send_trace: bool,
        span_id: str,
        trace_id: str,
        parent_span_id: str | None,
        tools: list[ToolDefinition] | None,
        tool_choice: ToolChoice | None,
        variables: dict[str, str] | None,
    ) -> RunLocalResult:
        """Run local prompt without streaming (async)."""
        start_time = time.time()
        error_message: str | None = None
        result_text = ""
        result_tool_calls: list[ToolCall] = []
        result_usage = TokenUsage(inputTokens=0, outputTokens=0, totalTokens=0)
        result_provider = resolve_provider(model, provider)
        finish_reason = "stop"
        latency_ms = 0

        try:
            completion = await self._llm_client.acomplete(
                model=model,
                messages=messages,
                provider=result_provider,
                temperature=temperature,
                max_tokens=max_output_tokens,
                top_p=top_p,
                stop=stop_sequences,
                tools=tools,
                tool_choice=tool_choice,
                api_key=provider_api_key,
                timeout=timeout,
            )

            result_text = completion.text
            result_tool_calls = completion.tool_calls
            result_provider = completion.provider
            finish_reason = completion.finish_reason
            result_usage = TokenUsage(
                inputTokens=completion.input_tokens,
                outputTokens=completion.output_tokens,
                totalTokens=completion.total_tokens,
            )

        except TraciaError:
            raise
        except Exception as e:
            error_message = sanitize_error_message(str(e))
            raise TraciaError(
                code=TraciaErrorCode.PROVIDER_ERROR,
                message=f"LLM provider error: {error_message}",
            ) from e
        finally:
            latency_ms = int((time.time() - start_time) * 1000)

            if send_trace:
                payload = CreateSpanPayload(
                    spanId=span_id,
                    model=model,
                    provider=result_provider,
                    input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
                    variables=variables,
                    output=result_text if not error_message else None,
                    status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
                    error=error_message,
                    latencyMs=latency_ms,
                    inputTokens=result_usage.input_tokens,
                    outputTokens=result_usage.output_tokens,
                    totalTokens=result_usage.total_tokens,
                    tags=tags,
                    userId=user_id,
                    sessionId=session_id,
                    temperature=temperature,
                    maxOutputTokens=max_output_tokens,
                    topP=top_p,
                    tools=tools,
                    toolCalls=result_tool_calls if result_tool_calls else None,
                    traceId=trace_id,
                    parentSpanId=parent_span_id,
                )
                self._schedule_span_creation(payload)

        return RunLocalResult(
            text=result_text,
            spanId=span_id,
            traceId=trace_id,
            latencyMs=latency_ms,
            usage=result_usage,
            cost=None,
            provider=result_provider,
            model=model,
            toolCalls=result_tool_calls,
            finishReason=finish_reason,
            message=build_assistant_message(result_text, result_tool_calls),
        )

    async def _arun_local_streaming(
        self,
        *,
        messages: list[LocalPromptMessage],
        model: str,
        provider: LLMProvider | None,
        temperature: float | None,
        max_output_tokens: int | None,
        top_p: float | None,
        stop_sequences: list[str] | None,
        timeout: float,
        provider_api_key: str | None,
        tags: list[str] | None,
        user_id: str | None,
        session_id: str | None,
        send_trace: bool,
        span_id: str,
        trace_id: str,
        parent_span_id: str | None,
        tools: list[ToolDefinition] | None,
        tool_choice: ToolChoice | None,
        variables: dict[str, str] | None,
    ) -> AsyncLocalStream:
        """Run local prompt with streaming (async)."""
        start_time = time.time()
        abort_event = asyncio.Event()
        loop = asyncio.get_running_loop()
        result_future: asyncio.Future[StreamResult] = loop.create_future()

        chunks_iter, result_holder, resolved_provider = await self._llm_client.astream(
            model=model,
            messages=messages,
            provider=provider,
            temperature=temperature,
            max_tokens=max_output_tokens,
            top_p=top_p,
            stop=stop_sequences,
            tools=tools,
            tool_choice=tool_choice,
            api_key=provider_api_key,
            timeout=timeout,
        )

        async def wrapped_chunks():
            error_message: str | None = None
            aborted = False

            try:
                async for chunk in chunks_iter:
                    if abort_event.is_set():
                        aborted = True
                        break
                    yield chunk
            except Exception as e:
                error_message = sanitize_error_message(str(e))
                raise
            finally:
                latency_ms = int((time.time() - start_time) * 1000)

                # Get completion result
                completion = result_holder[0] if result_holder else None

                result_text = completion.text if completion else ""
                result_tool_calls = completion.tool_calls if completion else []
                result_usage = TokenUsage(
                    inputTokens=completion.input_tokens if completion else 0,
                    outputTokens=completion.output_tokens if completion else 0,
                    totalTokens=completion.total_tokens if completion else 0,
                )
                finish_reason = completion.finish_reason if completion else "stop"

                if send_trace:
                    payload = CreateSpanPayload(
                        spanId=span_id,
                        model=model,
                        provider=resolved_provider,
                        input={"messages": [m.model_dump(by_alias=True, exclude_none=True) for m in messages]},
                        variables=variables,
                        output=result_text if not error_message else None,
                        status=SPAN_STATUS_ERROR if error_message else SPAN_STATUS_SUCCESS,
                        error=error_message,
                        latencyMs=latency_ms,
                        inputTokens=result_usage.input_tokens,
                        outputTokens=result_usage.output_tokens,
                        totalTokens=result_usage.total_tokens,
                        tags=tags,
                        userId=user_id,
                        sessionId=session_id,
                        temperature=temperature,
                        maxOutputTokens=max_output_tokens,
                        topP=top_p,
                        tools=tools,
                        toolCalls=result_tool_calls if result_tool_calls else None,
                        traceId=trace_id,
                        parentSpanId=parent_span_id,
                    )
                    self._schedule_span_creation(payload)

                # Set the result
                stream_result = StreamResult(
                    text=result_text,
                    spanId=span_id,
                    traceId=trace_id,
                    latencyMs=latency_ms,
                    usage=result_usage,
                    cost=None,
                    provider=resolved_provider,
                    model=model,
                    toolCalls=result_tool_calls,
                    finishReason=finish_reason,
                    message=build_assistant_message(result_text, result_tool_calls),
                    aborted=aborted,
                )
                if not result_future.done():
                    result_future.set_result(stream_result)

        return AsyncLocalStream(
            span_id=span_id,
            trace_id=trace_id,
            chunks=wrapped_chunks(),
            result_holder=result_holder,
            result_future=result_future,
            abort_event=abort_event,
        )

    def create_session(
        self,
        *,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> TraciaSession:
        """Create a new session for linked runs.

        Args:
            trace_id: Optional initial trace ID.
            parent_span_id: Optional initial parent span ID.

        Returns:
            A new TraciaSession instance.
        """
        return TraciaSession(self, trace_id, parent_span_id)

    def flush(self) -> None:
        """Wait for all pending span creations to complete."""
        with self._pending_spans_lock:
            futures = list(self._pending_spans.values())

        for future in futures:
            try:
                future.result(timeout=30.0)
            except Exception:
                pass

    async def aflush(self) -> None:
        """Wait for all pending span creations to complete (async)."""
        # In async context, we still use the sync executor
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.flush)

    def close(self) -> None:
        """Close the client and release resources."""
        if self._closed:
            return
        self._closed = True
        self.flush()
        self._http_client.close()
        if self._async_http_client:
            try:
                loop = asyncio.get_running_loop()
                loop.create_task(self._async_http_client.aclose())
            except RuntimeError:
                # No running loop - run synchronously
                try:
                    asyncio.run(self._async_http_client.aclose())
                except Exception:
                    pass
        self._executor.shutdown(wait=False)

    def __del__(self) -> None:
        try:
            if not self._closed:
                warnings.warn(
                    "Unclosed Tracia client. Use 'client.close()' or 'with Tracia(...) as client:'.",
                    ResourceWarning,
                    stacklevel=1,
                )
        except Exception:
            pass

    async def aclose(self) -> None:
        """Close the client and release resources (async)."""
        if self._closed:
            return
        self._closed = True
        await self.aflush()
        self._http_client.close()
        if self._async_http_client:
            await self._async_http_client.aclose()
        self._executor.shutdown(wait=False)

    def __enter__(self) -> "Tracia":
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    async def __aenter__(self) -> "Tracia":
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.aclose()
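For completeness, a hedged sketch of the async entry points added above: `arun_local`, `aclose`, and the async context manager (`__aenter__`/`__aexit__`). Whether `AsyncLocalStream` supports `async for` directly is an assumption modeled on the sync `LocalStream` example in the class docstring; the API key and model name are placeholders.

```python
# Hedged sketch of the async API in tracia/_client.py above. Assumes
# AsyncLocalStream is async-iterable like the sync LocalStream example;
# "your_api_key" and "gpt-4o" are placeholders.
import asyncio

from tracia import Tracia


async def main() -> None:
    # The async context manager awaits aclose() on exit.
    async with Tracia(api_key="your_api_key") as client:
        # Non-streaming async call: returns a RunLocalResult.
        result = await client.arun_local(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello!"}],
        )
        print(result.text)

        # Streaming async call: stream=True returns an AsyncLocalStream.
        stream = await client.arun_local(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Tell me a story"}],
            stream=True,
        )
        async for chunk in stream:  # assumption: AsyncLocalStream supports async iteration
            print(chunk, end="")


asyncio.run(main())
```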