tracia 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tracia/_session.py ADDED
@@ -0,0 +1,244 @@
1
+ """Session management for the Tracia SDK."""
2
+
3
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, overload

from ._streaming import AsyncLocalStream, LocalStream
from ._types import RunLocalInput, RunLocalResult, StreamResult
from ._utils import generate_trace_id

if TYPE_CHECKING:
    from ._client import Tracia
13
+
14
+
15
class TraciaSession:
    """Session for managing related traces.

    A session automatically links related runs by managing trace IDs and
    parent span IDs. This is useful for multi-turn conversations or
    related operations that should be grouped together.

    Every run made through the session is sent with the session's trace ID
    (generated on the first run if none was supplied). Once a run has
    produced a span, later runs are sent with that span as their
    ``parent_span_id``.
    """

    def __init__(
        self,
        tracia: "Tracia",
        initial_trace_id: str | None = None,
        initial_parent_span_id: str | None = None,
    ) -> None:
        """Initialize the session.

        Args:
            tracia: The Tracia client instance used to perform runs.
            initial_trace_id: Optional trace ID to reuse. If omitted, one is
                generated on the first run.
            initial_parent_span_id: Optional span ID that the first run is
                parented to.
        """
        self._tracia = tracia
        self._trace_id = initial_trace_id
        self._last_span_id = initial_parent_span_id

    @property
    def trace_id(self) -> str | None:
        """Get the current trace ID, or ``None`` if not yet assigned."""
        return self._trace_id

    @property
    def last_span_id(self) -> str | None:
        """Get the span ID the next run will be parented to, if any."""
        return self._last_span_id

    def reset(self) -> None:
        """Reset the session, clearing trace and span IDs."""
        self._trace_id = None
        self._last_span_id = None

    def _update_from_result(self, trace_id: str, span_id: str) -> None:
        """Update session state from a finished (or started) run.

        The trace ID is only adopted if the session does not have one yet.
        The span ID always becomes the parent for the next run.

        Args:
            trace_id: The trace ID reported by the run.
            span_id: The span ID reported by the run.
        """
        if self._trace_id is None:
            self._trace_id = trace_id
        self._last_span_id = span_id

    def _apply_session_context(self, kwargs: dict[str, Any]) -> None:
        """Inject the session's trace and parent-span IDs into call kwargs.

        Generates a trace ID first if the session has none. A caller-supplied
        ``trace_id`` is always overridden. A caller-supplied
        ``parent_span_id`` is only overridden once the session has a span.

        Args:
            kwargs: The keyword arguments that will be forwarded to the
                client; mutated in place.
        """
        if self._trace_id is None:
            self._trace_id = generate_trace_id()
        kwargs["trace_id"] = self._trace_id
        if self._last_span_id is not None:
            kwargs["parent_span_id"] = self._last_span_id

    @overload
    def run_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: Literal[False] = ...,
        **kwargs: Any,
    ) -> RunLocalResult: ...

    @overload
    def run_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: Literal[True],
        **kwargs: Any,
    ) -> LocalStream: ...

    @overload
    def run_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: bool,
        **kwargs: Any,
    ) -> RunLocalResult | LocalStream: ...

    def run_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: bool = False,
        **kwargs: Any,
    ) -> RunLocalResult | LocalStream:
        """Run a local prompt with session context.

        Automatically includes trace_id and parent_span_id from the session.

        Args:
            messages: The messages to send.
            model: The model name.
            stream: Whether to stream the response.
            **kwargs: Additional arguments forwarded to the client's
                ``run_local``.

        Returns:
            The result, or a stream when ``stream`` is true.
        """
        self._apply_session_context(kwargs)

        result = self._tracia.run_local(
            messages=messages,
            model=model,
            stream=stream,
            **kwargs,
        )

        if stream:
            return self._wrap_stream(result)

        self._update_from_result(result.trace_id, result.span_id)
        return result

    def _wrap_stream(self, stream: LocalStream) -> LocalStream:
        """Record a sync stream's IDs as the session's latest run.

        The stream exposes its trace and span IDs as soon as it is created,
        so they are recorded immediately. The next run is therefore parented
        to this stream regardless of when (or whether) the caller consumes it.

        Args:
            stream: The stream returned by the client.

        Returns:
            The same stream object.
        """
        self._update_from_result(stream.trace_id, stream.span_id)
        return stream

    @overload
    async def arun_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: Literal[False] = ...,
        **kwargs: Any,
    ) -> RunLocalResult: ...

    @overload
    async def arun_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: Literal[True],
        **kwargs: Any,
    ) -> AsyncLocalStream: ...

    @overload
    async def arun_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: bool,
        **kwargs: Any,
    ) -> RunLocalResult | AsyncLocalStream: ...

    async def arun_local(
        self,
        *,
        messages: list[dict[str, Any]],
        model: str,
        stream: bool = False,
        **kwargs: Any,
    ) -> RunLocalResult | AsyncLocalStream:
        """Run a local prompt with session context asynchronously.

        Automatically includes trace_id and parent_span_id from the session.

        Args:
            messages: The messages to send.
            model: The model name.
            stream: Whether to stream the response.
            **kwargs: Additional arguments forwarded to the client's
                ``arun_local``.

        Returns:
            The result, or an async stream when ``stream`` is true.
        """
        self._apply_session_context(kwargs)

        result = await self._tracia.arun_local(
            messages=messages,
            model=model,
            stream=stream,
            **kwargs,
        )

        if stream:
            return self._wrap_async_stream(result)

        self._update_from_result(result.trace_id, result.span_id)
        return result

    def _wrap_async_stream(self, stream: AsyncLocalStream) -> AsyncLocalStream:
        """Record an async stream's IDs as the session's latest run.

        State is updated eagerly rather than from a future done-callback.
        ``asyncio.Future`` callbacks are scheduled with ``call_soon``, so they
        can run after the caller's next ``arun_local`` has already read
        ``_last_span_id``. That run would then get a stale parent span.

        Args:
            stream: The async stream returned by the client.

        Returns:
            The same stream object.
        """
        self._update_from_result(stream.trace_id, stream.span_id)
        return stream
tracia/_streaming.py ADDED
@@ -0,0 +1,135 @@
1
+ """Streaming classes for the Tracia SDK."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from concurrent.futures import Future
7
+ from threading import Event
8
+ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator
9
+
10
+ from ._types import StreamResult
11
+
12
+ if TYPE_CHECKING:
13
+ from ._types import RunLocalResult
14
+
15
+
16
class LocalStream:
    """Synchronous stream wrapper for run_local with streaming.

    Iterating the stream yields text chunks from the underlying iterator.
    A stream can be iterated only once. Calling :meth:`abort` makes
    iteration stop instead of yielding the next chunk.
    """

    def __init__(
        self,
        span_id: str,
        trace_id: str,
        chunks: Iterator[str],
        result_holder: list[Any],
        result_future: "Future[StreamResult]",
        abort_event: Event,
    ) -> None:
        """Initialize the stream.

        Args:
            span_id: The span ID.
            trace_id: The trace ID.
            chunks: Iterator yielding text chunks.
            result_holder: List that will hold the completion result.
            result_future: Future that resolves to the stream result.
            abort_event: Event to signal abort.
        """
        self._span_id = span_id
        self._trace_id = trace_id
        self._chunks = chunks
        # NOTE(review): stored but not read anywhere in this class.
        self._result_holder = result_holder
        self._result_future = result_future
        self._abort_event = abort_event
        self._consumed = False

    @property
    def span_id(self) -> str:
        """Get the span ID."""
        return self._span_id

    @property
    def trace_id(self) -> str:
        """Get the trace ID."""
        return self._trace_id

    @property
    def result(self) -> "Future[StreamResult]":
        """Get the future that resolves to the stream result."""
        return self._result_future

    def __iter__(self) -> Iterator[str]:
        """Iterate over text chunks.

        Yields:
            Text chunks from the underlying iterator, stopping early once
            :meth:`abort` has been called.

        Raises:
            RuntimeError: When iteration starts on a stream that has already
                been iterated. Because this is a generator, the error
                surfaces on the first ``next()``, not on ``iter()``.
        """
        if self._consumed:
            raise RuntimeError("Stream already consumed")
        self._consumed = True

        try:
            for chunk in self._chunks:
                # Abort is checked after each chunk arrives, so a chunk
                # received after abort() is dropped rather than yielded.
                if self._abort_event.is_set():
                    break
                yield chunk
        finally:
            # Close the source on abort, early exit or error so whatever it
            # holds is released now rather than at garbage collection.
            # close() on an exhausted generator is a no-op.
            close = getattr(self._chunks, "close", None)
            if callable(close):
                close()

    def abort(self) -> None:
        """Abort the stream."""
        self._abort_event.set()
75
+
76
+
77
class AsyncLocalStream:
    """Asynchronous stream wrapper for run_local with streaming.

    ``async for`` over the stream yields text chunks from the underlying
    async iterator. A stream can be iterated only once. Calling
    :meth:`abort` makes iteration stop instead of yielding the next chunk.
    """

    def __init__(
        self,
        span_id: str,
        trace_id: str,
        chunks: AsyncIterator[str],
        result_holder: list[Any],
        result_future: "asyncio.Future[StreamResult]",
        abort_event: asyncio.Event,
    ) -> None:
        """Initialize the async stream.

        Args:
            span_id: The span ID.
            trace_id: The trace ID.
            chunks: Async iterator yielding text chunks.
            result_holder: List that will hold the completion result.
            result_future: Future that resolves to the stream result.
            abort_event: Event to signal abort.
        """
        self._span_id = span_id
        self._trace_id = trace_id
        self._chunks = chunks
        # NOTE(review): stored but not read anywhere in this class.
        self._result_holder = result_holder
        self._result_future = result_future
        self._abort_event = abort_event
        self._consumed = False

    @property
    def span_id(self) -> str:
        """Get the span ID."""
        return self._span_id

    @property
    def trace_id(self) -> str:
        """Get the trace ID."""
        return self._trace_id

    @property
    def result(self) -> "asyncio.Future[StreamResult]":
        """Get the future that resolves to the stream result."""
        return self._result_future

    async def __aiter__(self) -> AsyncIterator[str]:
        """Iterate over text chunks asynchronously.

        Yields:
            Text chunks from the underlying async iterator, stopping early
            once :meth:`abort` has been called.

        Raises:
            RuntimeError: When iteration starts on a stream that has already
                been iterated. Because this is an async generator, the error
                surfaces on the first ``__anext__()``.
        """
        if self._consumed:
            raise RuntimeError("Stream already consumed")
        self._consumed = True

        try:
            async for chunk in self._chunks:
                # Abort is checked after each chunk arrives, so a chunk
                # received after abort() is dropped rather than yielded.
                if self._abort_event.is_set():
                    break
                yield chunk
        finally:
            # Close the source on abort, early exit or error so whatever it
            # holds is released now rather than left for the event loop's
            # async-generator finalizer. aclose() on an exhausted async
            # generator is a no-op.
            aclose = getattr(self._chunks, "aclose", None)
            if aclose is not None:
                await aclose()

    def abort(self) -> None:
        """Abort the stream."""
        self._abort_event.set()