arize-phoenix 5.6.0__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (34)
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/RECORD +34 -25
  3. phoenix/config.py +42 -0
  4. phoenix/server/api/helpers/playground_clients.py +671 -0
  5. phoenix/server/api/helpers/playground_registry.py +70 -0
  6. phoenix/server/api/helpers/playground_spans.py +325 -0
  7. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  8. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  9. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  10. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  11. phoenix/server/api/mutations/__init__.py +4 -0
  12. phoenix/server/api/mutations/chat_mutations.py +374 -0
  13. phoenix/server/api/queries.py +41 -52
  14. phoenix/server/api/schema.py +42 -10
  15. phoenix/server/api/subscriptions.py +326 -595
  16. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  17. phoenix/server/api/types/GenerativeProvider.py +27 -3
  18. phoenix/server/api/types/Span.py +37 -0
  19. phoenix/server/api/types/TemplateLanguage.py +9 -0
  20. phoenix/server/app.py +61 -13
  21. phoenix/server/main.py +14 -1
  22. phoenix/server/static/.vite/manifest.json +9 -9
  23. phoenix/server/static/assets/{components-C70HJiXz.js → components-Csu8UKOs.js} +114 -114
  24. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-Bk5C9EA7.js} +1 -1
  25. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-UeWaKXNs.js} +328 -268
  26. phoenix/server/templates/index.html +1 -0
  27. phoenix/services.py +4 -0
  28. phoenix/session/session.py +15 -1
  29. phoenix/utilities/template_formatters.py +11 -1
  30. phoenix/version.py +1 -1
  31. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  32. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  33. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  34. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/subscriptions.py
@@ -1,419 +1,73 @@
-import json
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Mapping
-from dataclasses import asdict
-from datetime import datetime, timezone
-from enum import Enum
-from itertools import chain
-from traceback import format_exc
-from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
+import logging
+from asyncio import FIRST_COMPLETED, Task, create_task, wait
+from collections.abc import Iterator
+from typing import (
+    Any,
+    AsyncIterator,
+    Collection,
+    Iterable,
+    Mapping,
+    Optional,
+    TypeVar,
+)
 
 import strawberry
-from openinference.instrumentation import safe_json_dumps
-from openinference.semconv.trace import (
-    MessageAttributes,
-    OpenInferenceMimeTypeValues,
-    OpenInferenceSpanKindValues,
-    SpanAttributes,
-    ToolAttributes,
-    ToolCallAttributes,
-)
-from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
-from opentelemetry.trace import StatusCode
-from sqlalchemy import insert, select
-from strawberry import UNSET
-from strawberry.scalars import JSON as JSONScalarType
+from openinference.semconv.trace import SpanAttributes
+from sqlalchemy import and_, func, insert, select
+from sqlalchemy.orm import load_only
+from strawberry.relay.types import GlobalID
 from strawberry.types import Info
 from typing_extensions import TypeAlias, assert_never
 
-from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
 from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest
-from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
-from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
+from phoenix.server.api.helpers.playground_clients import (
+    PlaygroundStreamingClient,
+    initialize_playground_clients,
+)
+from phoenix.server.api.helpers.playground_registry import (
+    PLAYGROUND_CLIENT_REGISTRY,
+)
+from phoenix.server.api.helpers.playground_spans import streaming_llm_span
+from phoenix.server.api.input_types.ChatCompletionInput import (
+    ChatCompletionInput,
+    ChatCompletionOverDatasetInput,
+)
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
-from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
-from phoenix.server.api.types.Span import Span, to_gql_span
-from phoenix.server.dml_event import SpanInsertEvent
-from phoenix.trace.attributes import unflatten
-from phoenix.trace.schemas import (
-    SpanEvent,
-    SpanException,
+from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+    ChatCompletionOverDatasetSubscriptionResult,
+    ChatCompletionSubscriptionError,
+    ChatCompletionSubscriptionPayload,
+    FinishedChatCompletion,
 )
-from phoenix.utilities.json import jsonify
+from phoenix.server.api.types.Dataset import Dataset
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.Experiment import to_gql_experiment
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.Span import to_gql_span
+from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
+from phoenix.server.dml_event import SpanInsertEvent
+from phoenix.trace.attributes import get_attribute_value
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
     MustacheTemplateFormatter,
     TemplateFormatter,
+    TemplateFormatterError,
 )
 
-if TYPE_CHECKING:
-    from anthropic.types import MessageParam
-    from openai.types import CompletionUsage
-    from openai.types.chat import (
-        ChatCompletionMessageParam,
-        ChatCompletionMessageToolCallParam,
-    )
-
-PLAYGROUND_PROJECT_NAME = "playground"
-
-ToolCallID: TypeAlias = str
-SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
-
-
-@strawberry.enum
-class TemplateLanguage(Enum):
-    MUSTACHE = "MUSTACHE"
-    F_STRING = "F_STRING"
-
+GenericType = TypeVar("GenericType")
 
-@strawberry.input
-class TemplateOptions:
-    variables: JSONScalarType
-    language: TemplateLanguage
+logger = logging.getLogger(__name__)
 
+initialize_playground_clients()
 
-@strawberry.type
-class TextChunk:
-    content: str
-
-
-@strawberry.type
-class FunctionCallChunk:
-    name: str
-    arguments: str
-
-
-@strawberry.type
-class ToolCallChunk:
-    id: str
-    function: FunctionCallChunk
-
-
-@strawberry.type
-class ChatCompletionSubscriptionError:
-    message: str
-
-
-@strawberry.type
-class FinishedChatCompletion:
-    span: Span
-
-
-ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
-
-ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-    Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
-    strawberry.union("ChatCompletionSubscriptionPayload"),
+ChatCompletionMessage: TypeAlias = tuple[
+    ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
 ]
-
-
-@strawberry.input
-class GenerativeModelInput:
-    provider_key: GenerativeProviderKey
-    name: str
-    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
-    endpoint: Optional[str] = UNSET
-    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
-    api_version: Optional[str] = UNSET
-    """ The API version to use for the model. """
-
-
-@strawberry.input
-class ChatCompletionInput:
-    messages: list[ChatCompletionMessageInput]
-    model: GenerativeModelInput
-    invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
-    tools: Optional[list[JSONScalarType]] = UNSET
-    template: Optional[TemplateOptions] = UNSET
-    api_key: Optional[str] = strawberry.field(default=None)
-
-
-PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
-    GenerativeProviderKey, type["PlaygroundStreamingClient"]
-] = {}
-
-
-def register_llm_client(
-    provider_key: GenerativeProviderKey,
-) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
-    def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
-        PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
-        return cls
-
-    return decorator
-
-
-class PlaygroundStreamingClient(ABC):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        self._set_span_attributes = set_span_attributes
-
-    @abstractmethod
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        # a yield statement is needed to satisfy the type-checker
-        # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
-        yield TextChunk(content="")
-
-
-@register_llm_client(GenerativeProviderKey.OPENAI)
-class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        from openai import AsyncOpenAI
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        self.client = AsyncOpenAI(api_key=api_key)
-        self.model_name = model.name
-
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        from openai import NOT_GIVEN
-        from openai.types.chat import ChatCompletionStreamOptionsParam
-
-        # Convert standard messages to OpenAI messages
-        openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-        tool_call_ids: dict[int, str] = {}
-        token_usage: Optional["CompletionUsage"] = None
-        async for chunk in await self.client.chat.completions.create(
-            messages=openai_messages,
-            model=self.model_name,
-            stream=True,
-            stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
-            tools=tools or NOT_GIVEN,
-            **invocation_parameters,
-        ):
-            if (usage := chunk.usage) is not None:
-                token_usage = usage
-                continue
-            choice = chunk.choices[0]
-            delta = choice.delta
-            if choice.finish_reason is None:
-                if isinstance(chunk_content := delta.content, str):
-                    text_chunk = TextChunk(content=chunk_content)
-                    yield text_chunk
-                if (tool_calls := delta.tool_calls) is not None:
-                    for tool_call_index, tool_call in enumerate(tool_calls):
-                        tool_call_id = (
-                            tool_call.id
-                            if tool_call.id is not None
-                            else tool_call_ids[tool_call_index]
-                        )
-                        tool_call_ids[tool_call_index] = tool_call_id
-                        if (function := tool_call.function) is not None:
-                            tool_call_chunk = ToolCallChunk(
-                                id=tool_call_id,
-                                function=FunctionCallChunk(
-                                    name=function.name or "",
-                                    arguments=function.arguments or "",
-                                ),
-                            )
-                            yield tool_call_chunk
-        if token_usage is not None and self._set_span_attributes:
-            self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
-
-    def to_openai_chat_completion_param(
-        self,
-        role: ChatCompletionMessageRole,
-        content: JSONScalarType,
-        tool_call_id: Optional[str] = None,
-        tool_calls: Optional[list[JSONScalarType]] = None,
-    ) -> "ChatCompletionMessageParam":
-        from openai.types.chat import (
-            ChatCompletionAssistantMessageParam,
-            ChatCompletionSystemMessageParam,
-            ChatCompletionToolMessageParam,
-            ChatCompletionUserMessageParam,
-        )
-
-        if role is ChatCompletionMessageRole.USER:
-            return ChatCompletionUserMessageParam(
-                {
-                    "content": content,
-                    "role": "user",
-                }
-            )
-        if role is ChatCompletionMessageRole.SYSTEM:
-            return ChatCompletionSystemMessageParam(
-                {
-                    "content": content,
-                    "role": "system",
-                }
-            )
-        if role is ChatCompletionMessageRole.AI:
-            if tool_calls is None:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                    }
-                )
-            else:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                        "tool_calls": [
-                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
-                        ],
-                    }
-                )
-        if role is ChatCompletionMessageRole.TOOL:
-            if tool_call_id is None:
-                raise ValueError("tool_call_id is required for tool messages")
-            return ChatCompletionToolMessageParam(
-                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
-            )
-        assert_never(role)
-
-    def to_openai_tool_call_param(
-        self,
-        tool_call: JSONScalarType,
-    ) -> "ChatCompletionMessageToolCallParam":
-        from openai.types.chat import ChatCompletionMessageToolCallParam
-
-        return ChatCompletionMessageToolCallParam(
-            id=tool_call.get("id", ""),
-            function={
-                "name": tool_call.get("function", {}).get("name", ""),
-                "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
-            },
-            type="function",
-        )
-
-    @staticmethod
-    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
-        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
-        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
-
-
-@register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
-class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ):
-        from openai import AsyncAzureOpenAI
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        if model.endpoint is None or model.api_version is None:
-            raise ValueError("endpoint and api_version are required for Azure OpenAI models")
-        self.client = AsyncAzureOpenAI(
-            api_key=api_key,
-            azure_endpoint=model.endpoint,
-            api_version=model.api_version,
-        )
-
-
-@register_llm_client(GenerativeProviderKey.ANTHROPIC)
-class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        import anthropic
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        self.client = anthropic.AsyncAnthropic(api_key=api_key)
-        self.model_name = model.name
-
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        import anthropic.lib.streaming as anthropic_streaming
-        import anthropic.types as anthropic_types
-
-        anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
-
-        anthropic_params = {
-            "messages": anthropic_messages,
-            "model": self.model_name,
-            "system": system_prompt,
-            "max_tokens": 1024,
-            **invocation_parameters,
-        }
-        async with self.client.messages.stream(**anthropic_params) as stream:
-            async for event in stream:
-                if isinstance(event, anthropic_types.RawMessageStartEvent):
-                    if self._set_span_attributes:
-                        self._set_span_attributes(
-                            {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
-                        )
-                elif isinstance(event, anthropic_streaming.TextEvent):
-                    yield TextChunk(content=event.text)
-                elif isinstance(event, anthropic_streaming.MessageStopEvent):
-                    if self._set_span_attributes:
-                        self._set_span_attributes(
-                            {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
-                        )
-                elif isinstance(
-                    event,
-                    (
-                        anthropic_types.RawContentBlockStartEvent,
-                        anthropic_types.RawContentBlockDeltaEvent,
-                        anthropic_types.RawMessageDeltaEvent,
-                        anthropic_streaming.ContentBlockStopEvent,
-                    ),
-                ):
-                    # event types emitted by the stream that don't contain useful information
-                    pass
-                elif isinstance(event, anthropic_streaming.InputJsonEvent):
-                    raise NotImplementedError
-                else:
-                    assert_never(event)
-
-    def _build_anthropic_messages(
-        self,
-        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    ) -> tuple[list["MessageParam"], str]:
-        anthropic_messages: list["MessageParam"] = []
-        system_prompt = ""
-        for role, content, _tool_call_id, _tool_calls in messages:
-            if role == ChatCompletionMessageRole.USER:
-                anthropic_messages.append({"role": "user", "content": content})
-            elif role == ChatCompletionMessageRole.AI:
-                anthropic_messages.append({"role": "assistant", "content": content})
-            elif role == ChatCompletionMessageRole.SYSTEM:
-                system_prompt += content + "\n"
-            elif role == ChatCompletionMessageRole.TOOL:
-                raise NotImplementedError
-            else:
-                assert_never(role)
-
-        return anthropic_messages, system_prompt
+DatasetExampleID: TypeAlias = GlobalID
+PLAYGROUND_PROJECT_NAME = "playground"
 
 
 @strawberry.type
@@ -422,15 +76,15 @@ class Subscription:
     async def chat_completion(
         self, info: Info[Context, None], input: ChatCompletionInput
    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-        # Determine which LLM client to use based on provider_key
         provider_key = input.model.provider_key
-        if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None:
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
             raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
         llm_client = llm_client_class(
             model=input.model,
             api_key=input.api_key,
-            set_span_attributes=lambda attrs: attributes.update(attrs),
         )
+
         messages = [
             (
                 message.role,
@@ -441,69 +95,29 @@ class Subscription:
             for message in input.messages
         ]
         if template_options := input.template:
-            messages = list(_formatted_messages(messages, template_options))
-        invocation_parameters = jsonify(input.invocation_parameters)
-        attributes = dict(
-            chain(
-                _llm_span_kind(),
-                _llm_model_name(input.model.name),
-                _llm_tools(input.tools or []),
-                _llm_input_messages(messages),
-                _llm_invocation_parameters(invocation_parameters),
-                _input_value_and_mime_type(input),
+            messages = list(
+                _formatted_messages(
+                    messages=messages,
+                    template_language=template_options.language,
+                    template_variables=template_options.variables,
+                )
             )
+        invocation_parameters = llm_client.construct_invocation_parameters(
+            input.invocation_parameters
         )
-        status_code: StatusCode
-        status_message = ""
-        events: list[SpanEvent] = []
-        start_time: datetime
-        end_time: datetime
-        response_chunks = []
-        text_chunks: list[TextChunk] = []
-        tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
-        try:
-            start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+        async with streaming_llm_span(
+            input=input,
+            messages=messages,
+            invocation_parameters=invocation_parameters,
+        ) as span:
             async for chunk in llm_client.chat_completion_create(
-                messages=messages,
-                tools=input.tools or [],
-                **invocation_parameters,
+                messages=messages, tools=input.tools or [], **invocation_parameters
             ):
-                response_chunks.append(chunk)
-                if isinstance(chunk, TextChunk):
-                    yield chunk
-                    text_chunks.append(chunk)
-                elif isinstance(chunk, ToolCallChunk):
-                    yield chunk
-                    tool_call_chunks[chunk.id].append(chunk)
-                else:
-                    assert_never(chunk)
-            status_code = StatusCode.OK
-        except Exception as error:
-            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
-            status_code = StatusCode.ERROR
-            status_message = str(error)
-            events.append(
-                SpanException(
-                    timestamp=end_time,
-                    message=status_message,
-                    exception_type=type(error).__name__,
-                    exception_escaped=False,
-                    exception_stacktrace=format_exc(),
-                )
-            )
-            yield ChatCompletionSubscriptionError(message=status_message)
-        else:
-            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
-            attributes.update(
-                chain(
-                    _output_value_and_mime_type(response_chunks),
-                    _llm_output_messages(text_chunks, tool_call_chunks),
-                )
-            )
-        prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-        completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-        trace_id = _generate_trace_id()
-        span_id = _generate_span_id()
+                span.add_response_chunk(chunk)
+                yield chunk
+            span.set_attributes(llm_client.attributes)
+        if span.error_message is not None:
+            yield ChatCompletionSubscriptionError(message=span.error_message)
         async with info.context.db() as session:
             if (
                 playground_project_id := await session.scalar(
@@ -518,138 +132,262 @@ class Subscription:
                         description="Traces from prompt playground",
                     )
                 )
-            playground_trace = models.Trace(
-                project_rowid=playground_project_id,
-                trace_id=trace_id,
-                start_time=start_time,
-                end_time=end_time,
-            )
-            playground_span = models.Span(
-                trace_rowid=playground_trace.id,
-                span_id=span_id,
-                parent_id=None,
-                name="ChatCompletion",
-                span_kind=LLM,
-                start_time=start_time,
-                end_time=end_time,
-                attributes=unflatten(attributes.items()),
-                events=[_serialize_event(event) for event in events],
-                status_code=status_code.name,
-                status_message=status_message,
-                cumulative_error_count=int(status_code is StatusCode.ERROR),
-                cumulative_llm_token_count_prompt=prompt_tokens,
-                cumulative_llm_token_count_completion=completion_tokens,
-                llm_token_count_prompt=prompt_tokens,
-                llm_token_count_completion=completion_tokens,
-                trace=playground_trace,
-            )
-            session.add(playground_trace)
-            session.add(playground_span)
+            db_span = span.add_to_session(session, playground_project_id)
             await session.flush()
-        yield FinishedChatCompletion(span=to_gql_span(playground_span))
+        yield FinishedChatCompletion(span=to_gql_span(db_span))
         info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
 
+    @strawberry.subscription
+    async def chat_completion_over_dataset(
+        self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
+    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+        provider_key = input.model.provider_key
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
 
-def _llm_span_kind() -> Iterator[tuple[str, Any]]:
-    yield OPENINFERENCE_SPAN_KIND, LLM
-
-
-def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
-    yield LLM_MODEL_NAME, model_name
-
-
-def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
-    yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
-
-
-def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
-    for tool_index, tool in enumerate(tools):
-        yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
-
-
-def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
-    assert (api_key := "api_key") in (input_data := jsonify(input))
-    input_data = {k: v for k, v in input_data.items() if k != api_key}
-    assert api_key not in input_data
-    yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, safe_json_dumps(input_data)
-
-
-def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
-    yield OUTPUT_MIME_TYPE, JSON
-    yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
-
+        dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+            if input.dataset_version_id
+            else None
+        )
+        revision_ids = (
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(models.DatasetExample)
+            .where(models.DatasetExample.dataset_id == dataset_id)
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
+        )
+        if version_id:
+            version_id_subquery = (
+                select(models.DatasetVersion.id)
+                .where(models.DatasetVersion.dataset_id == dataset_id)
+                .where(models.DatasetVersion.id == version_id)
+                .scalar_subquery()
+            )
+            revision_ids = revision_ids.where(
+                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+            )
+        query = (
+            select(models.DatasetExampleRevision)
+            .where(
+                and_(
+                    models.DatasetExampleRevision.id.in_(revision_ids),
+                    models.DatasetExampleRevision.revision_kind != "DELETE",
+                )
+            )
+            .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
+            .options(
+                load_only(
+                    models.DatasetExampleRevision.dataset_example_id,
+                    models.DatasetExampleRevision.input,
+                )
+            )
+        )
+        async with info.context.db() as session:
+            revisions = [revision async for revision in await session.stream_scalars(query)]
+            if not revisions:
+                raise BadRequest("No examples found for the given dataset and version")
+
+        spans: dict[DatasetExampleID, streaming_llm_span] = {}
+        async for payload in _merge_iterators(
+            [
+                _stream_chat_completion_over_dataset_example(
+                    input=input,
+                    llm_client_class=llm_client_class,
+                    revision=revision,
+                    spans=spans,
+                )
+                for revision in revisions
+            ]
+        ):
+            yield payload
 
-def _llm_input_messages(
-    messages: Iterable[
-        tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-    ],
-) -> Iterator[tuple[str, Any]]:
-    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
-        if tool_calls is not None:
-            for tool_call_index, tool_call in enumerate(tool_calls):
-                yield (
-                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                    tool_call["function"]["name"],
+        async with info.context.db() as session:
+            if (
+                playground_project_id := await session.scalar(
+                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                 )
-                if arguments := tool_call["function"]["arguments"]:
-                    yield (
-                        f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                        safe_json_dumps(jsonify(arguments)),
+            ) is None:
+                playground_project_id = await session.scalar(
+                    insert(models.Project)
+                    .returning(models.Project.id)
+                    .values(
+                        name=PLAYGROUND_PROJECT_NAME,
+                        description="Traces from prompt playground",
                     )
-
-
-def _llm_output_messages(
-    text_chunks: list[TextChunk],
-    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
-) -> Iterator[tuple[str, Any]]:
-    yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
-    if content := "".join(chunk.content for chunk in text_chunks):
-        yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
-    for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
-        if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
-            yield (
-                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                name,
+                )
+            db_spans = {
+                example_id: span.add_to_session(session, playground_project_id)
+                for example_id, span in spans.items()
+            }
+            assert (
+                dataset_name := await session.scalar(
+                    select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+                )
+            ) is not None
+            if version_id is None:
+                resolved_version_id = await session.scalar(
+                    select(models.DatasetVersion.id)
+                    .where(models.DatasetVersion.dataset_id == dataset_id)
+                    .order_by(models.DatasetVersion.id.desc())
+                    .limit(1)
+                )
+            else:
+                resolved_version_id = await session.scalar(
+                    select(models.DatasetVersion.id).where(
+                        and_(
+                            models.DatasetVersion.dataset_id == dataset_id,
+                            models.DatasetVersion.id == version_id,
+                        )
+                    )
+                )
+            assert resolved_version_id is not None
+            resolved_version_node_id = GlobalID(DatasetVersion.__name__, str(resolved_version_id))
+            experiment = models.Experiment(
+                dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
+                dataset_version_id=resolved_version_id,
+                name=input.experiment_name or _DEFAULT_PLAYGROUND_EXPERIMENT_NAME,
+                description=input.experiment_description
+                or _default_playground_experiment_description(dataset_name=dataset_name),
+                repetitions=1,
+                metadata_=input.experiment_metadata
+                or _default_playground_experiment_metadata(
+                    dataset_name=dataset_name,
+                    dataset_id=input.dataset_id,
+                    version_id=resolved_version_node_id,
+                ),
+                project_name=PLAYGROUND_PROJECT_NAME,
+            )
+            session.add(experiment)
+            await session.flush()
+            runs = [
+                models.ExperimentRun(
+                    experiment_id=experiment.id,
+                    dataset_example_id=from_global_id_with_expected_type(
+                        example_id, DatasetExample.__name__
+                    ),
+                    trace_id=span.trace_id,
+                    output=models.ExperimentRunOutput(
+                        task_output=_get_playground_experiment_task_output(span)
+                    ),
+                    repetition_number=1,
+                    start_time=span.start_time,
+                    end_time=span.end_time,
+                    error=error_message
+                    if (error_message := span.error_message) is not None
+                    else None,
+                    prompt_token_count=get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT),
+                    completion_token_count=get_attribute_value(
+                        span.attributes, LLM_TOKEN_COUNT_COMPLETION
+                    ),
+                )
+                for example_id, span in spans.items()
+            ]
+            session.add_all(runs)
+            await session.flush()
+        for example_id in spans:
+            yield FinishedChatCompletion(
+                span=to_gql_span(db_spans[example_id]),
+                dataset_example_id=example_id,
             )
-        if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
-            yield (
-                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                arguments,
+        yield ChatCompletionOverDatasetSubscriptionResult(experiment=to_gql_experiment(experiment))
+
+
+async def _stream_chat_completion_over_dataset_example(
+    *,
+    input: ChatCompletionOverDatasetInput,
+    llm_client_class: type["PlaygroundStreamingClient"],
+    revision: models.DatasetExampleRevision,
+    spans: dict[DatasetExampleID, streaming_llm_span],
+) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
+    llm_client = llm_client_class(
+        model=input.model,
+        api_key=input.api_key,
+    )
+    invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
+    messages = [
+        (
+            message.role,
+            message.content,
+            message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+            message.tool_calls if isinstance(message.tool_calls, list) else None,
+        )
+        for message in input.messages
+    ]
+    try:
+        messages = list(
+            _formatted_messages(
+                messages=messages,
+                template_language=input.template_language,
+                template_variables=revision.input,
             )
+        )
+    except TemplateFormatterError as error:
+        yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
+        return
+    span = streaming_llm_span(
+        input=input,
+        messages=messages,
+        invocation_parameters=invocation_parameters,
+    )
+    spans[example_id] = span
+    async with span:
+        async for chunk in llm_client.chat_completion_create(
+            messages=messages, tools=input.tools or [], **invocation_parameters
+        ):
+            span.add_response_chunk(chunk)
+            chunk.dataset_example_id = example_id
+            yield chunk
+        span.set_attributes(llm_client.attributes)
+    if span.error_message is not None:
+        yield ChatCompletionSubscriptionError(
+            message=span.error_message, dataset_example_id=example_id
+        )
 
 
-def _generate_trace_id() -> str:
-    """
-    Generates a random trace ID in hexadecimal format.
-    """
-    return _hex(DefaultOTelIDGenerator().generate_trace_id())
+async def _merge_iterators(
+    iterators: Collection[AsyncIterator[GenericType]],
+) -> AsyncIterator[GenericType]:
+    tasks: dict[AsyncIterator[GenericType], Task[GenericType]] = {
+        iterable: _as_task(iterable) for iterable in iterators
+    }
+    while tasks:
+        completed_tasks, _ = await wait(tasks.values(), return_when=FIRST_COMPLETED)
+        for task in completed_tasks:
+            iterator = next(it for it, t in tasks.items() if t == task)
+            try:
+                yield task.result()
+            except StopAsyncIteration:
+                del tasks[iterator]
+            except Exception as error:
+                del tasks[iterator]
+                logger.exception(error)
+            else:
+                tasks[iterator] = _as_task(iterator)
 
 
-def _generate_span_id() -> str:
-    """
-    Generates a random span ID in hexadecimal format.
-    """
-    return _hex(DefaultOTelIDGenerator().generate_span_id())
+def _as_task(iterable: AsyncIterator[GenericType]) -> Task[GenericType]:
+    return create_task(_as_coroutine(iterable))
 
 
-def _hex(number: int) -> str:
-    """
-    Converts an integer to a hexadecimal string.
-    """
-    return hex(number)[2:]
+async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
    return await iterable.__anext__()
 
 
 def _formatted_messages(
-    messages: Iterable[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    template_options: TemplateOptions,
+    *,
+    messages: Iterable[ChatCompletionMessage],
+    template_language: TemplateLanguage,
+    template_variables: Mapping[str, Any],
 ) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
-    template_formatter = _template_formatter(template_language=template_options.language)
+    template_formatter = _template_formatter(template_language=template_language)
     (
         roles,
         templates,
@@ -657,7 +395,7 @@ def _formatted_messages(
         tool_calls,
     ) = zip(*messages)
     formatted_templates = map(
-        lambda template: template_formatter.format(template, **template_options.variables),
+        lambda template: template_formatter.format(template, **template_variables),
         templates,
     )
     formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
@@ -675,36 +413,29 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
     assert_never(template_language)
 
 
-def _serialize_event(event: SpanEvent) -> dict[str, Any]:
-    """
-    Serializes a SpanEvent to a dictionary.
-    """
-    return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+def _get_playground_experiment_task_output(
+    span: streaming_llm_span,
+) -> Any:
+    return get_attribute_value(span.attributes, LLM_OUTPUT_MESSAGES)
 
 
-JSON = OpenInferenceMimeTypeValues.JSON.value
+_DEFAULT_PLAYGROUND_EXPERIMENT_NAME = "playground-experiment"
 
-LLM = OpenInferenceSpanKindValues.LLM.value
 
-OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
-INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
-INPUT_VALUE = SpanAttributes.INPUT_VALUE
-OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
-OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
-LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
-LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
-LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
-LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
-LLM_TOOLS = SpanAttributes.LLM_TOOLS
-LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
-LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
-LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+def _default_playground_experiment_description(dataset_name: str) -> str:
+    return f'Playground experiment for dataset "{dataset_name}"'
 
-MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
-MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
-MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
 
-TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
-TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
+def _default_playground_experiment_metadata(
+    dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
+) -> dict[str, Any]:
+    return {
+        "dataset_name": dataset_name,
+        "dataset_id": str(dataset_id),
+        "dataset_version_id": str(version_id),
+    }
 
-TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
+
+LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
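
A note on the client lookup change above: 5.7.0 moves the provider clients out of subscriptions.py into phoenix/server/api/helpers/playground_clients.py and replaces the module-level PLAYGROUND_STREAMING_CLIENT_REGISTRY dict with a PLAYGROUND_CLIENT_REGISTRY object whose get_client lookup takes the model name in addition to the provider key, so a provider can route specific models to specialized clients. The following is a minimal sketch of such a decorator-driven registry; the class, method, and model names here are hypothetical stand-ins for illustration, not the actual API of the new helper module.

from typing import Callable, Optional


class PlaygroundClientRegistry:
    """Maps (provider_key, model_name) to a client class, with a per-provider default."""

    def __init__(self) -> None:
        self._registry: dict[str, dict[Optional[str], type]] = {}

    def register(self, provider_key: str, model_name: Optional[str] = None) -> Callable[[type], type]:
        # Usage mirrors the removed @register_llm_client decorator, but is keyed by model too.
        def decorator(cls: type) -> type:
            self._registry.setdefault(provider_key, {})[model_name] = cls
            return cls

        return decorator

    def get_client(self, provider_key: str, model_name: str) -> Optional[type]:
        clients_for_provider = self._registry.get(provider_key, {})
        # Fall back to the provider-wide default registered under None.
        return clients_for_provider.get(model_name, clients_for_provider.get(None))


REGISTRY = PlaygroundClientRegistry()


@REGISTRY.register("OPENAI")
class OpenAIStreamingClient:
    pass


@REGISTRY.register("OPENAI", model_name="o1-mini")  # hypothetical model-specific client
class OpenAIO1StreamingClient:
    pass


assert REGISTRY.get_client("OPENAI", "gpt-4o") is OpenAIStreamingClient
assert REGISTRY.get_client("OPENAI", "o1-mini") is OpenAIO1StreamingClient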
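The new _merge_iterators/_as_task/_as_coroutine helpers in the diff fan in one async stream per dataset example so that chunks are yielded in completion order rather than one example at a time. A self-contained sketch of the same pattern, using only the standard library (the ticker producer is purely hypothetical demo code):

import asyncio
from asyncio import FIRST_COMPLETED, Task, create_task, wait
from collections.abc import AsyncIterator
from typing import TypeVar

T = TypeVar("T")


async def _next_item(iterator: AsyncIterator[T]) -> T:
    # __anext__() on an async generator returns an awaitable that is not a coroutine,
    # so wrap it in a real coroutine before handing it to create_task — this is the
    # role played by _as_coroutine in the diff.
    return await iterator.__anext__()


async def merge_iterators(iterators: list[AsyncIterator[T]]) -> AsyncIterator[T]:
    # Keep one pending "next item" task per iterator; yield whichever finishes first.
    tasks: dict[AsyncIterator[T], Task[T]] = {it: create_task(_next_item(it)) for it in iterators}
    while tasks:
        done, _ = await wait(tasks.values(), return_when=FIRST_COMPLETED)
        for task in done:
            iterator = next(it for it, t in tasks.items() if t is task)
            try:
                yield task.result()
            except StopAsyncIteration:
                del tasks[iterator]  # this source is exhausted
            else:
                tasks[iterator] = create_task(_next_item(iterator))  # schedule its next item


async def ticker(name: str, delay: float) -> AsyncIterator[str]:
    # Hypothetical producer used only for the demo.
    for i in range(3):
        await asyncio.sleep(delay)
        yield f"{name}-{i}"


async def main() -> None:
    async for item in merge_iterators([ticker("a", 0.10), ticker("b", 0.15)]):
        print(item)  # items interleave in completion order, not source order


asyncio.run(main())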
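Relatedly, _template_formatter (visible in the hunk context above) selects between FStringTemplateFormatter and MustacheTemplateFormatter based on the TemplateLanguage enum, and the newly imported TemplateFormatterError is what lets _stream_chat_completion_over_dataset_example surface a bad template as a ChatCompletionSubscriptionError for that one example instead of aborting the whole stream. A rough standalone sketch of that contract — the class and error names come from the diff, but these implementations are simplified stand-ins, not Phoenix's:

import re


class TemplateFormatterError(Exception):
    pass


class FStringTemplateFormatter:
    # Simplified stand-in: delegates to str.format.
    def format(self, template: str, **variables: str) -> str:
        try:
            return template.format(**variables)
        except KeyError as error:
            raise TemplateFormatterError(f"Missing template variable: {error}") from error


class MustacheTemplateFormatter:
    # Simplified stand-in: replaces {{ name }} placeholders.
    def format(self, template: str, **variables: str) -> str:
        def replace(match: re.Match) -> str:
            name = match.group(1)
            if name not in variables:
                raise TemplateFormatterError(f"Missing template variable: {name}")
            return str(variables[name])

        return re.sub(r"\{\{\s*(\w+)\s*\}\}", replace, template)


assert FStringTemplateFormatter().format("Hi {name}", name="Ada") == "Hi Ada"
assert MustacheTemplateFormatter().format("Hi {{ name }}", name="Ada") == "Hi Ada"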