arize-phoenix 5.3.1 (py3-none-any.whl) → 5.5.0 (py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: arize-phoenix
3
- Version: 5.3.1
3
+ Version: 5.5.0
4
4
  Summary: AI Observability and Evaluation
5
5
  Project-URL: Documentation, https://docs.arize.com/phoenix/
6
6
  Project-URL: Issues, https://github.com/Arize-ai/phoenix/issues
@@ -56,6 +56,7 @@ Requires-Dist: uvicorn
56
56
  Requires-Dist: websockets
57
57
  Requires-Dist: wrapt
58
58
  Provides-Extra: container
59
+ Requires-Dist: anthropic; extra == 'container'
59
60
  Requires-Dist: fast-hdbscan>=0.2.0; extra == 'container'
60
61
  Requires-Dist: numba>=0.60.0; extra == 'container'
61
62
  Requires-Dist: openai>=1.0.0; extra == 'container'
@@ -6,7 +6,7 @@ phoenix/exceptions.py,sha256=n2L2KKuecrdflB9MsCdAYCiSEvGJptIsfRkXMoJle7A,169
6
6
  phoenix/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
7
7
  phoenix/services.py,sha256=aTxhcOA1pZHB6U-B3TEcp6fqDF5oT0xCUvEUNMZVTUQ,5175
8
8
  phoenix/settings.py,sha256=ht-0oN-sMV6SPXrk7Tu1EZlngpAYkGNLYPhO8DyrdQI,661
9
- phoenix/version.py,sha256=XehO8pb4GklFJT4padrwyHGgMITY1M6KU3ARht53Mw4,22
9
+ phoenix/version.py,sha256=zFTHldBmR5ReiC3uSZ8VkZOEirtsq_l6QbUJYRBHlTs,22
10
10
  phoenix/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  phoenix/core/embedding_dimension.py,sha256=zKGbcvwOXgLf-yrJBpQyKtd-LEOPRKHnUToyAU8Owis,87
12
12
  phoenix/core/model.py,sha256=km_a--PBHOuA337ClRw9xqhOHhrUT6Rl9pz_zV0JYkQ,4843
@@ -94,7 +94,7 @@ phoenix/server/api/exceptions.py,sha256=TA0JuY2YRnj35qGuMSQ8d0ToHum9gWm9W--3fSKH
94
94
  phoenix/server/api/interceptor.py,sha256=ykDnoC_apUd-llVli3m1CW18kNSIgjz2qZ6m5JmPDu8,1294
95
95
  phoenix/server/api/queries.py,sha256=sE_w4X9sz0QJBGNQbOMgiZiKpR24jvxCpv6AQeGeE9Y,27153
96
96
  phoenix/server/api/schema.py,sha256=JKxcZd0UBX6VnVic9tnKAzehJ2fGYNWAu6Gr1tR1PMI,637
97
- phoenix/server/api/subscriptions.py,sha256=l5t8f23J_3znUS6gFrK3XOPkkX9BN5dHHMteASWbhVk,17083
97
+ phoenix/server/api/subscriptions.py,sha256=hxIj8WRNk1m1aXVW7PDlpPt5TwRGNbjp2yxibRZRqrc,21972
98
98
  phoenix/server/api/utils.py,sha256=Kl47G-1A7QKTDrc75BU2QK6HupsG6MWuXxy351FOfKQ,858
99
99
  phoenix/server/api/dataloaders/__init__.py,sha256=jNYvfXjnZzgA2HWTG7AZdqWGla3ZysBUDUei8Zkz6N8,3290
100
100
  phoenix/server/api/dataloaders/annotation_summaries.py,sha256=Wv8AORZoGd5TJ4Y-em8iqJu87AMpZP7lWOTr-SML-x8,5560
@@ -264,15 +264,15 @@ phoenix/server/static/apple-touch-icon-76x76.png,sha256=CT_xT12I0u2i0WU8JzBZBuOQ
264
264
  phoenix/server/static/apple-touch-icon.png,sha256=fOfpjqGpWYbJ0eAurKsyoZP1EAs6ZVooBJ_SGk2ZkDs,3801
265
265
  phoenix/server/static/favicon.ico,sha256=bY0vvCKRftemZfPShwZtE93DiiQdaYaozkPGwNFr6H8,34494
266
266
  phoenix/server/static/modernizr.js,sha256=mvK-XtkNqjOral-QvzoqsyOMECXIMu5BQwSVN_wcU9c,2564
267
- phoenix/server/static/.vite/manifest.json,sha256=SieG3iZksdblyz3-x7sbYAjWMF6-abUBPpCo2fI-M6M,1929
268
- phoenix/server/static/assets/components-CeBEsPFA.js,sha256=vQY-UcFt8MuIlB-NCAw02xsIt8Tv26aLTlYCC8GwFQQ,272882
269
- phoenix/server/static/assets/index-CRTTIb3y.js,sha256=_HHh0kyNV3oT1-vYiwkXC25jInuLGhLMyr62OLhRNS8,7162
270
- phoenix/server/static/assets/pages-vHjpYRiN.js,sha256=WEnxplw-UwjrizzaVl3c6umZ_mEq4gl8sTo5vijmSLE,586223
271
- phoenix/server/static/assets/vendor-D-NIjePD.js,sha256=tba6ip4z7SOnhlPkBkyt9CUJiDDzZwi6IRKCnbwHvuA,10898294
267
+ phoenix/server/static/.vite/manifest.json,sha256=ifveuWhOubKDOfkJLSeHrUMz3kwtdip_8ctJ4KPRfP8,1929
268
+ phoenix/server/static/assets/components-DOjuTDx9.js,sha256=JXFIP0SiBdyEvvh53dKsdoi-6o1lOUFY_IIMsGBwMxg,272352
269
+ phoenix/server/static/assets/index-C15MiAe0.js,sha256=pM97A7RT-WTM1F2TqKVvP8VgNWGHC7OehcBqq-sjzR4,7162
270
+ phoenix/server/static/assets/pages-uwuWR0d3.js,sha256=GWoWNnx_8-q56uUQWbBRM_fZv8zxw1ZZXvWbog8Q0-0,591038
271
+ phoenix/server/static/assets/vendor-6IcPAw_j.js,sha256=kA0afCcdEfjdTlLbngra3925MiR8wVgYF5d7XBBViXQ,10898278
272
272
  phoenix/server/static/assets/vendor-DxkFTwjz.css,sha256=nZrkr0u6NNElFGvpWHk9GTHeGoibCXCli1bE7mXZGZg,1816
273
- phoenix/server/static/assets/vendor-arizeai-DoY5jUTO.js,sha256=RkuJgULcP6CtTAhyLSzI9bNRiJnq84MwkYCF6MQ2sCw,306507
274
- phoenix/server/static/assets/vendor-codemirror-CIhY_nEU.js,sha256=6g23qgUXzCj6VuoJC9Uhq7Y9qb3_WN3u767NLDhfGkM,392703
275
- phoenix/server/static/assets/vendor-recharts-Dgcm35Jq.js,sha256=9ifuAYGnEKtBtbTV9sFe95rP4e1NR9j8DOVEb-hRKzM,282859
273
+ phoenix/server/static/assets/vendor-arizeai-uC2sozJ6.js,sha256=o5Su5D9_jBumfVpDXwuQC7Hzpkze4LlvVA9t0Pvh1pk,306512
274
+ phoenix/server/static/assets/vendor-codemirror-DVE2_WBr.js,sha256=AxNgfQZIek1HfON1zHobnrzHzgBeLHWu6GCltpvct9U,392705
275
+ phoenix/server/static/assets/vendor-recharts-DwrexFA4.js,sha256=QSkNeG0tfn_fdrjYf-dRBXVygFjusBsyxzQ-lepdH8w,282859
276
276
  phoenix/server/static/assets/vendor-three-DwGkEfCM.js,sha256=0D12ZgKzfKCTSdSTKJBFR2RZO_xxeMXrqDp0AszZqHY,620972
277
277
  phoenix/server/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
278
278
  phoenix/server/templates/index.html,sha256=ULNlp0-agU8pH4wCIapE1goqnwkbctAoQRgrfwH_irg,4379
@@ -286,7 +286,7 @@ phoenix/trace/attributes.py,sha256=B_OrzVaxZwFkrAFXZyicYoIti1UdUysURsvUS2GyW1U,1
286
286
  phoenix/trace/errors.py,sha256=wB1z8qdPckngdfU-TORToekvg3344oNFAA83_hC2yFY,180
287
287
  phoenix/trace/evaluation_conventions.py,sha256=t8jydM3U0-T5YpiQKRJ3tWdWGlHtzKyttYdw-ddvPOk,1048
288
288
  phoenix/trace/exporter.py,sha256=bUXh8fjJIbHurrnt4bAm-cCWqUN5FqNsIc8DZzzklkQ,4695
289
- phoenix/trace/fixtures.py,sha256=qKYXz33MD7FYkSni7ccyFjXMV5iub7serp1zMGqh1P8,18982
289
+ phoenix/trace/fixtures.py,sha256=fa2yMcjX4sshT88Ec4ce0EpSngfZQ9vXVcssNbcyYAU,19202
290
290
  phoenix/trace/otel.py,sha256=c3Iy9sUTYBliwYfChZIQ9_6xTti1BijY7t_4s2tqbyw,10028
291
291
  phoenix/trace/projects.py,sha256=2BwlNjFE-uwpqYtCu5YyBiYZk9wRPpM13vh3-Cv7GkA,2157
292
292
  phoenix/trace/schemas.py,sha256=HpWSyzec0yDHEQXEDuwyLbhpvKrqkGps8BJqGiIFj8Y,5978
@@ -313,9 +313,9 @@ phoenix/utilities/project.py,sha256=8IJuMM4yUMoooPi37sictGj8Etu9rGmq6RFtc9848cQ,
313
313
  phoenix/utilities/re.py,sha256=nr_B0txj_7CXc45953X6vr2KCRSWMuaXJSEkL8s8Sjc,2036
314
314
  phoenix/utilities/span_store.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
315
315
  phoenix/utilities/template_formatters.py,sha256=zWY6gjAMmspYqlmn-ZjKeY5SExylwePk7a1BUdDxVGI,2236
316
- arize_phoenix-5.3.1.dist-info/METADATA,sha256=8vSIv2Shr5LH4dkS6cshYjQos1ERiCIaQcLc6RZryds,12052
317
- arize_phoenix-5.3.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
318
- arize_phoenix-5.3.1.dist-info/entry_points.txt,sha256=Pgpn8Upxx9P8z8joPXZWl2LlnAlGc3gcQoVchb06X1Q,94
319
- arize_phoenix-5.3.1.dist-info/licenses/IP_NOTICE,sha256=JBqyyCYYxGDfzQ0TtsQgjts41IJoa-hiwDrBjCb9gHM,469
320
- arize_phoenix-5.3.1.dist-info/licenses/LICENSE,sha256=HFkW9REuMOkvKRACuwLPT0hRydHb3zNg-fdFt94td18,3794
321
- arize_phoenix-5.3.1.dist-info/RECORD,,
316
+ arize_phoenix-5.5.0.dist-info/METADATA,sha256=F32DnF7HFmK_M0NKDsD5mXY91llrxN1Php04_nmzo40,12099
317
+ arize_phoenix-5.5.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
318
+ arize_phoenix-5.5.0.dist-info/entry_points.txt,sha256=Pgpn8Upxx9P8z8joPXZWl2LlnAlGc3gcQoVchb06X1Q,94
319
+ arize_phoenix-5.5.0.dist-info/licenses/IP_NOTICE,sha256=JBqyyCYYxGDfzQ0TtsQgjts41IJoa-hiwDrBjCb9gHM,469
320
+ arize_phoenix-5.5.0.dist-info/licenses/LICENSE,sha256=HFkW9REuMOkvKRACuwLPT0hRydHb3zNg-fdFt94td18,3794
321
+ arize_phoenix-5.5.0.dist-info/RECORD,,
@@ -1,4 +1,5 @@
1
1
  import json
2
+ from abc import ABC, abstractmethod
2
3
  from collections import defaultdict
3
4
  from dataclasses import fields
4
5
  from datetime import datetime
@@ -9,6 +10,7 @@ from typing import (
9
10
  Annotated,
10
11
  Any,
11
12
  AsyncIterator,
13
+ Callable,
12
14
  DefaultDict,
13
15
  Dict,
14
16
  Iterable,
@@ -16,6 +18,7 @@ from typing import (
16
18
  List,
17
19
  Optional,
18
20
  Tuple,
21
+ Type,
19
22
  Union,
20
23
  )
21
24
 
@@ -45,6 +48,7 @@ from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatComple
45
48
  from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
46
49
  from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
47
50
  from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
51
+ from phoenix.server.api.types.Span import Span, to_gql_span
48
52
  from phoenix.server.dml_event import SpanInsertEvent
49
53
  from phoenix.trace.attributes import unflatten
50
54
  from phoenix.utilities.json import jsonify
@@ -55,14 +59,13 @@ from phoenix.utilities.template_formatters import (
55
59
  )
56
60
 
57
61
  if TYPE_CHECKING:
62
+ from anthropic.types import MessageParam
58
63
  from openai.types import CompletionUsage
59
- from openai.types.chat import (
60
- ChatCompletionMessageParam,
61
- )
64
+ from openai.types.chat import ChatCompletionMessageParam
62
65
 
63
66
  PLAYGROUND_PROJECT_NAME = "playground"
64
67
 
65
- ToolCallIndex: TypeAlias = int
68
+ ToolCallID: TypeAlias = str
66
69
 
67
70
 
68
71
  @strawberry.enum
@@ -94,8 +97,14 @@ class ToolCallChunk:
94
97
  function: FunctionCallChunk
95
98
 
96
99
 
97
- ChatCompletionChunk: TypeAlias = Annotated[
98
- Union[TextChunk, ToolCallChunk], strawberry.union("ChatCompletionChunk")
100
+ @strawberry.type
101
+ class FinishedChatCompletion:
102
+ span: Span
103
+
104
+
105
+ ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
106
+ Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
107
+ strawberry.union("ChatCompletionSubscriptionPayload"),
99
108
  ]
100
109
 
101
110
 
@@ -120,39 +129,202 @@ class ChatCompletionInput:
120
129
  api_key: Optional[str] = strawberry.field(default=None)
121
130
 
122
131
 
123
- def to_openai_chat_completion_param(
124
- role: ChatCompletionMessageRole, content: JSONScalarType
125
- ) -> "ChatCompletionMessageParam":
126
- from openai.types.chat import (
127
- ChatCompletionAssistantMessageParam,
128
- ChatCompletionSystemMessageParam,
129
- ChatCompletionUserMessageParam,
130
- )
132
+ PLAYGROUND_STREAMING_CLIENT_REGISTRY: Dict[
133
+ GenerativeProviderKey, Type["PlaygroundStreamingClient"]
134
+ ] = {}
131
135
 
132
- if role is ChatCompletionMessageRole.USER:
133
- return ChatCompletionUserMessageParam(
134
- {
135
- "content": content,
136
- "role": "user",
137
- }
138
- )
139
- if role is ChatCompletionMessageRole.SYSTEM:
140
- return ChatCompletionSystemMessageParam(
141
- {
142
- "content": content,
143
- "role": "system",
144
- }
136
+
137
+ def register_llm_client(
138
+ provider_key: GenerativeProviderKey,
139
+ ) -> Callable[[Type["PlaygroundStreamingClient"]], Type["PlaygroundStreamingClient"]]:
140
+ def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:
141
+ PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
142
+ return cls
143
+
144
+ return decorator
145
+
146
+
147
+ class PlaygroundStreamingClient(ABC):
148
+ def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ...
149
+
150
+ @abstractmethod
151
+ async def chat_completion_create(
152
+ self,
153
+ messages: List[Tuple[ChatCompletionMessageRole, str]],
154
+ tools: List[JSONScalarType],
155
+ **invocation_parameters: Any,
156
+ ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
157
+ # a yield statement is needed to satisfy the type-checker
158
+ # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
159
+ yield TextChunk(content="")
160
+
161
+ @property
162
+ @abstractmethod
163
+ def attributes(self) -> Dict[str, Any]: ...
164
+
165
+
166
+ @register_llm_client(GenerativeProviderKey.OPENAI)
167
+ class OpenAIStreamingClient(PlaygroundStreamingClient):
168
+ def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
169
+ from openai import AsyncOpenAI
170
+
171
+ self.client = AsyncOpenAI(api_key=api_key)
172
+ self.model_name = model.name
173
+ self._attributes: Dict[str, Any] = {}
174
+
175
+ async def chat_completion_create(
176
+ self,
177
+ messages: List[Tuple[ChatCompletionMessageRole, str]],
178
+ tools: List[JSONScalarType],
179
+ **invocation_parameters: Any,
180
+ ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
181
+ from openai import NOT_GIVEN
182
+ from openai.types.chat import ChatCompletionStreamOptionsParam
183
+
184
+ # Convert standard messages to OpenAI messages
185
+ openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
186
+ tool_call_ids: Dict[int, str] = {}
187
+ token_usage: Optional["CompletionUsage"] = None
188
+ async for chunk in await self.client.chat.completions.create(
189
+ messages=openai_messages,
190
+ model=self.model_name,
191
+ stream=True,
192
+ stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
193
+ tools=tools or NOT_GIVEN,
194
+ **invocation_parameters,
195
+ ):
196
+ if (usage := chunk.usage) is not None:
197
+ token_usage = usage
198
+ continue
199
+ choice = chunk.choices[0]
200
+ delta = choice.delta
201
+ if choice.finish_reason is None:
202
+ if isinstance(chunk_content := delta.content, str):
203
+ text_chunk = TextChunk(content=chunk_content)
204
+ yield text_chunk
205
+ if (tool_calls := delta.tool_calls) is not None:
206
+ for tool_call_index, tool_call in enumerate(tool_calls):
207
+ tool_call_id = (
208
+ tool_call.id
209
+ if tool_call.id is not None
210
+ else tool_call_ids[tool_call_index]
211
+ )
212
+ tool_call_ids[tool_call_index] = tool_call_id
213
+ if (function := tool_call.function) is not None:
214
+ tool_call_chunk = ToolCallChunk(
215
+ id=tool_call_id,
216
+ function=FunctionCallChunk(
217
+ name=function.name or "",
218
+ arguments=function.arguments or "",
219
+ ),
220
+ )
221
+ yield tool_call_chunk
222
+ if token_usage is not None:
223
+ self._attributes.update(_llm_token_counts(token_usage))
224
+
225
+ def to_openai_chat_completion_param(
226
+ self, role: ChatCompletionMessageRole, content: JSONScalarType
227
+ ) -> "ChatCompletionMessageParam":
228
+ from openai.types.chat import (
229
+ ChatCompletionAssistantMessageParam,
230
+ ChatCompletionSystemMessageParam,
231
+ ChatCompletionUserMessageParam,
145
232
  )
146
- if role is ChatCompletionMessageRole.AI:
147
- return ChatCompletionAssistantMessageParam(
148
- {
149
- "content": content,
150
- "role": "assistant",
151
- }
233
+
234
+ if role is ChatCompletionMessageRole.USER:
235
+ return ChatCompletionUserMessageParam(
236
+ {
237
+ "content": content,
238
+ "role": "user",
239
+ }
240
+ )
241
+ if role is ChatCompletionMessageRole.SYSTEM:
242
+ return ChatCompletionSystemMessageParam(
243
+ {
244
+ "content": content,
245
+ "role": "system",
246
+ }
247
+ )
248
+ if role is ChatCompletionMessageRole.AI:
249
+ return ChatCompletionAssistantMessageParam(
250
+ {
251
+ "content": content,
252
+ "role": "assistant",
253
+ }
254
+ )
255
+ if role is ChatCompletionMessageRole.TOOL:
256
+ raise NotImplementedError
257
+ assert_never(role)
258
+
259
+ @property
260
+ def attributes(self) -> Dict[str, Any]:
261
+ return self._attributes
262
+
263
+
264
+ @register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
265
+ class AzureOpenAIStreamingClient(OpenAIStreamingClient):
266
+ def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
267
+ from openai import AsyncAzureOpenAI
268
+
269
+ if model.endpoint is None or model.api_version is None:
270
+ raise ValueError("endpoint and api_version are required for Azure OpenAI models")
271
+ self.client = AsyncAzureOpenAI(
272
+ api_key=api_key,
273
+ azure_endpoint=model.endpoint,
274
+ api_version=model.api_version,
152
275
  )
153
- if role is ChatCompletionMessageRole.TOOL:
154
- raise NotImplementedError
155
- assert_never(role)
276
+
277
+
278
+ @register_llm_client(GenerativeProviderKey.ANTHROPIC)
279
+ class AnthropicStreamingClient(PlaygroundStreamingClient):
280
+ def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
281
+ import anthropic
282
+
283
+ self.client = anthropic.AsyncAnthropic(api_key=api_key)
284
+ self.model_name = model.name
285
+
286
+ async def chat_completion_create(
287
+ self,
288
+ messages: List[Tuple[ChatCompletionMessageRole, str]],
289
+ tools: List[JSONScalarType],
290
+ **invocation_parameters: Any,
291
+ ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
292
+ anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
293
+
294
+ anthropic_params = {
295
+ "messages": anthropic_messages,
296
+ "model": self.model_name,
297
+ "system": system_prompt,
298
+ "max_tokens": 1024,
299
+ **invocation_parameters,
300
+ }
301
+
302
+ async with self.client.messages.stream(**anthropic_params) as stream:
303
+ async for text in stream.text_stream:
304
+ yield TextChunk(content=text)
305
+
306
+ def _build_anthropic_messages(
307
+ self, messages: List[Tuple[ChatCompletionMessageRole, str]]
308
+ ) -> Tuple[List["MessageParam"], str]:
309
+ anthropic_messages: List["MessageParam"] = []
310
+ system_prompt = ""
311
+ for role, content in messages:
312
+ if role == ChatCompletionMessageRole.USER:
313
+ anthropic_messages.append({"role": "user", "content": content})
314
+ elif role == ChatCompletionMessageRole.AI:
315
+ anthropic_messages.append({"role": "assistant", "content": content})
316
+ elif role == ChatCompletionMessageRole.SYSTEM:
317
+ system_prompt += content + "\n"
318
+ elif role == ChatCompletionMessageRole.TOOL:
319
+ raise NotImplementedError
320
+ else:
321
+ assert_never(role)
322
+
323
+ return anthropic_messages, system_prompt
324
+
325
+ @property
326
+ def attributes(self) -> Dict[str, Any]:
327
+ return dict()
156
328
 
157
329
 
158
330
  @strawberry.type
@@ -160,31 +332,21 @@ class Subscription:
160
332
  @strawberry.subscription
161
333
  async def chat_completion(
162
334
  self, info: Info[Context, None], input: ChatCompletionInput
163
- ) -> AsyncIterator[ChatCompletionChunk]:
164
- from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
165
- from openai.types.chat import ChatCompletionStreamOptionsParam
335
+ ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
336
+ # Determine which LLM client to use based on provider_key
337
+ provider_key = input.model.provider_key
338
+ llm_client_class = PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)
339
+ if llm_client_class is None:
340
+ raise ValueError(f"No LLM client registered for provider '{provider_key}'")
166
341
 
167
- client: Union[AsyncAzureOpenAI, AsyncOpenAI]
168
-
169
- if input.model.provider_key == GenerativeProviderKey.AZURE_OPENAI:
170
- if input.model.endpoint is None or input.model.api_version is None:
171
- raise ValueError("endpoint and api_version are required for Azure OpenAI models")
172
- client = AsyncAzureOpenAI(
173
- api_key=input.api_key,
174
- azure_endpoint=input.model.endpoint,
175
- api_version=input.model.api_version,
176
- )
177
- else:
178
- client = AsyncOpenAI(api_key=input.api_key)
342
+ llm_client = llm_client_class(model=input.model, api_key=input.api_key)
179
343
 
180
- invocation_parameters = jsonify(input.invocation_parameters)
344
+ messages = [(message.role, message.content) for message in input.messages]
181
345
 
182
- messages: List[Tuple[ChatCompletionMessageRole, str]] = [
183
- (message.role, message.content) for message in input.messages
184
- ]
185
346
  if template_options := input.template:
186
347
  messages = list(_formatted_messages(messages, template_options))
187
- openai_messages = [to_openai_chat_completion_param(*message) for message in messages]
348
+
349
+ invocation_parameters = jsonify(input.invocation_parameters)
188
350
 
189
351
  in_memory_span_exporter = InMemorySpanExporter()
190
352
  tracer_provider = TracerProvider()
@@ -193,6 +355,7 @@ class Subscription:
193
355
  )
194
356
  tracer = tracer_provider.get_tracer(__name__)
195
357
  span_name = "ChatCompletion"
358
+
196
359
  with tracer.start_span(
197
360
  span_name,
198
361
  attributes=dict(
@@ -208,52 +371,29 @@ class Subscription:
208
371
  ) as span:
209
372
  response_chunks = []
210
373
  text_chunks: List[TextChunk] = []
211
- tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list)
212
- role: Optional[str] = None
213
- token_usage: Optional[CompletionUsage] = None
214
- async for chunk in await client.chat.completions.create(
215
- messages=openai_messages,
216
- model=input.model.name,
217
- stream=True,
218
- tools=input.tools or NOT_GIVEN,
219
- stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
374
+ tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
375
+
376
+ async for chunk in llm_client.chat_completion_create(
377
+ messages=messages,
378
+ tools=input.tools or [],
220
379
  **invocation_parameters,
221
380
  ):
222
381
  response_chunks.append(chunk)
223
- if (usage := chunk.usage) is not None:
224
- token_usage = usage
225
- continue
226
- choice = chunk.choices[0]
227
- delta = choice.delta
228
- if role is None:
229
- role = delta.role
230
- if choice.finish_reason is None:
231
- if isinstance(chunk_content := delta.content, str):
232
- text_chunk = TextChunk(content=chunk_content)
233
- yield text_chunk
234
- text_chunks.append(text_chunk)
235
- if (tool_calls := delta.tool_calls) is not None:
236
- for tool_call_index, tool_call in enumerate(tool_calls):
237
- if (function := tool_call.function) is not None:
238
- if (tool_call_id := tool_call.id) is None:
239
- first_tool_call_chunk = tool_call_chunks[tool_call_index][0]
240
- tool_call_id = first_tool_call_chunk.id
241
- tool_call_chunk = ToolCallChunk(
242
- id=tool_call_id,
243
- function=FunctionCallChunk(
244
- name=function.name or "",
245
- arguments=function.arguments or "",
246
- ),
247
- )
248
- yield tool_call_chunk
249
- tool_call_chunks[tool_call_index].append(tool_call_chunk)
382
+ if isinstance(chunk, TextChunk):
383
+ yield chunk
384
+ text_chunks.append(chunk)
385
+ elif isinstance(chunk, ToolCallChunk):
386
+ yield chunk
387
+ tool_call_chunks[chunk.id].append(chunk)
388
+
250
389
  span.set_status(StatusCode.OK)
251
- assert role is not None
390
+ llm_client_attributes = llm_client.attributes
391
+
252
392
  span.set_attributes(
253
393
  dict(
254
394
  chain(
255
395
  _output_value_and_mime_type(response_chunks),
256
- _llm_token_counts(token_usage) if token_usage is not None else [],
396
+ llm_client_attributes.items(),
257
397
  _llm_output_messages(text_chunks, tool_call_chunks),
258
398
  )
259
399
  )
@@ -265,8 +405,8 @@ class Subscription:
265
405
  assert (attributes := finished_span.attributes) is not None
266
406
  start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
267
407
  end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
268
- prompt_tokens = token_usage.prompt_tokens if token_usage is not None else 0
269
- completion_tokens = token_usage.completion_tokens if token_usage is not None else 0
408
+ prompt_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
409
+ completion_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
270
410
  trace_id = _hex(finished_span.context.trace_id)
271
411
  span_id = _hex(finished_span.context.span_id)
272
412
  status = finished_span.status
@@ -284,36 +424,35 @@ class Subscription:
284
424
  description="Traces from prompt playground",
285
425
  )
286
426
  )
287
- trace_rowid = await session.scalar(
288
- insert(models.Trace)
289
- .returning(models.Trace.id)
290
- .values(
291
- project_rowid=playground_project_id,
292
- trace_id=trace_id,
293
- start_time=start_time,
294
- end_time=end_time,
295
- )
427
+ playground_trace = models.Trace(
428
+ project_rowid=playground_project_id,
429
+ trace_id=trace_id,
430
+ start_time=start_time,
431
+ end_time=end_time,
296
432
  )
297
- await session.execute(
298
- insert(models.Span).values(
299
- trace_rowid=trace_rowid,
300
- span_id=span_id,
301
- parent_id=None,
302
- name=span_name,
303
- span_kind=LLM,
304
- start_time=start_time,
305
- end_time=end_time,
306
- attributes=unflatten(attributes.items()),
307
- events=finished_span.events,
308
- status_code=status.status_code.name,
309
- status_message=status.description or "",
310
- cumulative_error_count=int(not status.is_ok),
311
- cumulative_llm_token_count_prompt=prompt_tokens,
312
- cumulative_llm_token_count_completion=completion_tokens,
313
- llm_token_count_prompt=prompt_tokens,
314
- llm_token_count_completion=completion_tokens,
315
- )
433
+ playground_span = models.Span(
434
+ trace_rowid=playground_trace.id,
435
+ span_id=span_id,
436
+ parent_id=None,
437
+ name=span_name,
438
+ span_kind=LLM,
439
+ start_time=start_time,
440
+ end_time=end_time,
441
+ attributes=unflatten(attributes.items()),
442
+ events=finished_span.events,
443
+ status_code=status.status_code.name,
444
+ status_message=status.description or "",
445
+ cumulative_error_count=int(not status.is_ok),
446
+ cumulative_llm_token_count_prompt=prompt_tokens,
447
+ cumulative_llm_token_count_completion=completion_tokens,
448
+ llm_token_count_prompt=prompt_tokens,
449
+ llm_token_count_completion=completion_tokens,
450
+ trace=playground_trace,
316
451
  )
452
+ session.add(playground_trace)
453
+ session.add(playground_span)
454
+ await session.flush()
455
+ yield FinishedChatCompletion(span=to_gql_span(playground_span))
317
456
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
318
457
 
319
458
 
@@ -361,7 +500,7 @@ def _llm_input_messages(
361
500
 
362
501
  def _llm_output_messages(
363
502
  text_chunks: List[TextChunk],
364
- tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]],
503
+ tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]],
365
504
  ) -> Iterator[Tuple[str, Any]]:
366
505
  yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
367
506
  if content := "".join(chunk.content for chunk in text_chunks):