arize-phoenix 5.6.0__py3-none-any.whl → 5.8.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

This version of arize-phoenix has been flagged as potentially problematic.

Files changed (39)
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/METADATA +4 -6
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/RECORD +39 -30
  3. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +58 -0
  5. phoenix/server/api/helpers/playground_clients.py +758 -0
  6. phoenix/server/api/helpers/playground_registry.py +70 -0
  7. phoenix/server/api/helpers/playground_spans.py +422 -0
  8. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  9. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  10. phoenix/server/api/input_types/InvocationParameters.py +155 -13
  11. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  12. phoenix/server/api/mutations/__init__.py +4 -0
  13. phoenix/server/api/mutations/chat_mutations.py +355 -0
  14. phoenix/server/api/queries.py +41 -52
  15. phoenix/server/api/schema.py +42 -10
  16. phoenix/server/api/subscriptions.py +378 -595
  17. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  18. phoenix/server/api/types/GenerativeProvider.py +27 -3
  19. phoenix/server/api/types/Span.py +37 -0
  20. phoenix/server/api/types/TemplateLanguage.py +9 -0
  21. phoenix/server/app.py +75 -13
  22. phoenix/server/grpc_server.py +3 -1
  23. phoenix/server/main.py +14 -1
  24. phoenix/server/static/.vite/manifest.json +31 -31
  25. phoenix/server/static/assets/{components-C70HJiXz.js → components-MllbfxfJ.js} +168 -150
  26. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-BVO2YcT1.js} +2 -2
  27. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-BHfC6jnL.js} +464 -310
  28. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-BEuNhfwH.js} +1 -1
  29. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-Bskhzyjm.js} +1 -1
  30. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-DLlXCf0x.js} +1 -1
  31. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-CRqhvLYg.js} +1 -1
  32. phoenix/server/templates/index.html +1 -0
  33. phoenix/services.py +4 -0
  34. phoenix/session/session.py +15 -1
  35. phoenix/utilities/template_formatters.py +11 -1
  36. phoenix/version.py +1 -1
  37. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/entry_points.txt +0 -0
  38. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/IP_NOTICE +0 -0
  39. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/LICENSE +0 -0
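
The largest change is in phoenix/server/api/subscriptions.py (file 16 above): the inline provider clients (OpenAI, Azure OpenAI, Anthropic) move out of the subscriptions module into phoenix/server/api/helpers/playground_clients.py, client lookup goes through a registry in phoenix/server/api/helpers/playground_registry.py, and a new chat_completion_over_dataset subscription runs the playground prompt over every example of a dataset as an experiment. The diff below only shows the call site, PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name); a minimal sketch of what such a registry could look like (the helper module itself is not part of this diff, so the class name, register_client, and the fallback behavior are assumptions):

```python
from typing import Optional


class PlaygroundClientRegistry:
    """Hypothetical sketch of PLAYGROUND_CLIENT_REGISTRY; the real class lives in
    phoenix/server/api/helpers/playground_registry.py, which this diff does not include."""

    def __init__(self) -> None:
        # provider key -> {model name (or None for the provider default) -> client class}
        self._registry: dict[str, dict[Optional[str], type]] = {}

    def register_client(
        self, provider_key: str, model_name: Optional[str], client_class: type
    ) -> None:
        self._registry.setdefault(provider_key, {})[model_name] = client_class

    def get_client(self, provider_key: str, model_name: str) -> Optional[type]:
        # fall back to the provider-wide default when no model-specific
        # client has been registered
        clients_by_model = self._registry.get(provider_key, {})
        return clients_by_model.get(model_name, clients_by_model.get(None))
```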
phoenix/server/api/subscriptions.py
@@ -1,419 +1,85 @@
- import json
- from abc import ABC, abstractmethod
- from collections import defaultdict
- from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Mapping
- from dataclasses import asdict
+ import asyncio
+ import logging
+ from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for
+ from collections.abc import AsyncIterator, Iterator
  from datetime import datetime, timezone
- from enum import Enum
- from itertools import chain
- from traceback import format_exc
- from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
+ from typing import (
+     Any,
+     Iterable,
+     Mapping,
+     Optional,
+     Sequence,
+     TypeVar,
+     cast,
+ )

  import strawberry
- from openinference.instrumentation import safe_json_dumps
- from openinference.semconv.trace import (
-     MessageAttributes,
-     OpenInferenceMimeTypeValues,
-     OpenInferenceSpanKindValues,
-     SpanAttributes,
-     ToolAttributes,
-     ToolCallAttributes,
- )
- from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
- from opentelemetry.trace import StatusCode
- from sqlalchemy import insert, select
- from strawberry import UNSET
- from strawberry.scalars import JSON as JSONScalarType
+ from openinference.semconv.trace import SpanAttributes
+ from sqlalchemy import and_, func, insert, select
+ from sqlalchemy.orm import load_only
+ from strawberry.relay.types import GlobalID
  from strawberry.types import Info
  from typing_extensions import TypeAlias, assert_never

  from phoenix.datetime_utils import local_now, normalize_datetime
  from phoenix.db import models
  from phoenix.server.api.context import Context
- from phoenix.server.api.exceptions import BadRequest
- from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
- from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
+ from phoenix.server.api.exceptions import BadRequest, NotFound
+ from phoenix.server.api.helpers.playground_clients import (
+     PlaygroundStreamingClient,
+     initialize_playground_clients,
+ )
+ from phoenix.server.api.helpers.playground_registry import (
+     PLAYGROUND_CLIENT_REGISTRY,
+ )
+ from phoenix.server.api.helpers.playground_spans import (
+     get_db_experiment_run,
+     get_db_span,
+     get_db_trace,
+     streaming_llm_span,
+ )
+ from phoenix.server.api.input_types.ChatCompletionInput import (
+     ChatCompletionInput,
+     ChatCompletionOverDatasetInput,
+ )
  from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
- from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
- from phoenix.server.api.types.Span import Span, to_gql_span
- from phoenix.server.dml_event import SpanInsertEvent
- from phoenix.trace.attributes import unflatten
- from phoenix.trace.schemas import (
-     SpanEvent,
-     SpanException,
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+     ChatCompletionSubscriptionError,
+     ChatCompletionSubscriptionExperiment,
+     ChatCompletionSubscriptionPayload,
+     ChatCompletionSubscriptionResult,
  )
- from phoenix.utilities.json import jsonify
+ from phoenix.server.api.types.Dataset import Dataset
+ from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
+ from phoenix.server.api.types.Experiment import to_gql_experiment
+ from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+ from phoenix.server.api.types.Span import to_gql_span
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
+ from phoenix.server.dml_event import SpanInsertEvent
+ from phoenix.server.types import DbSessionFactory
  from phoenix.utilities.template_formatters import (
      FStringTemplateFormatter,
      MustacheTemplateFormatter,
      TemplateFormatter,
+     TemplateFormatterError,
  )

- if TYPE_CHECKING:
-     from anthropic.types import MessageParam
-     from openai.types import CompletionUsage
-     from openai.types.chat import (
-         ChatCompletionMessageParam,
-         ChatCompletionMessageToolCallParam,
-     )
-
- PLAYGROUND_PROJECT_NAME = "playground"
-
- ToolCallID: TypeAlias = str
- SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
-
-
- @strawberry.enum
- class TemplateLanguage(Enum):
-     MUSTACHE = "MUSTACHE"
-     F_STRING = "F_STRING"
-
-
- @strawberry.input
- class TemplateOptions:
-     variables: JSONScalarType
-     language: TemplateLanguage
-
-
- @strawberry.type
- class TextChunk:
-     content: str
-
-
- @strawberry.type
- class FunctionCallChunk:
-     name: str
-     arguments: str
-
-
- @strawberry.type
- class ToolCallChunk:
-     id: str
-     function: FunctionCallChunk
-
+ GenericType = TypeVar("GenericType")

- @strawberry.type
- class ChatCompletionSubscriptionError:
-     message: str
+ logger = logging.getLogger(__name__)

+ initialize_playground_clients()

- @strawberry.type
- class FinishedChatCompletion:
-     span: Span
-
-
- ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
-
- ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-     Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
-     strawberry.union("ChatCompletionSubscriptionPayload"),
+ ChatCompletionMessage: TypeAlias = tuple[
+     ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
  ]
-
-
- @strawberry.input
- class GenerativeModelInput:
-     provider_key: GenerativeProviderKey
-     name: str
-     """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
-     endpoint: Optional[str] = UNSET
-     """ The endpoint to use for the model. Only required for Azure OpenAI models. """
-     api_version: Optional[str] = UNSET
-     """ The API version to use for the model. """
-
-
- @strawberry.input
- class ChatCompletionInput:
-     messages: list[ChatCompletionMessageInput]
-     model: GenerativeModelInput
-     invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
-     tools: Optional[list[JSONScalarType]] = UNSET
-     template: Optional[TemplateOptions] = UNSET
-     api_key: Optional[str] = strawberry.field(default=None)
-
-
- PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
-     GenerativeProviderKey, type["PlaygroundStreamingClient"]
- ] = {}
-
-
- def register_llm_client(
-     provider_key: GenerativeProviderKey,
- ) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
-     def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
-         PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
-         return cls
-
-     return decorator
-
-
- class PlaygroundStreamingClient(ABC):
-     def __init__(
-         self,
-         model: GenerativeModelInput,
-         api_key: Optional[str] = None,
-         set_span_attributes: Optional[SetSpanAttributesFn] = None,
-     ) -> None:
-         self._set_span_attributes = set_span_attributes
-
-     @abstractmethod
-     async def chat_completion_create(
-         self,
-         messages: list[
-             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-         ],
-         tools: list[JSONScalarType],
-         **invocation_parameters: Any,
-     ) -> AsyncIterator[ChatCompletionChunk]:
-         # a yield statement is needed to satisfy the type-checker
-         # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
-         yield TextChunk(content="")
-
-
- @register_llm_client(GenerativeProviderKey.OPENAI)
- class OpenAIStreamingClient(PlaygroundStreamingClient):
-     def __init__(
-         self,
-         model: GenerativeModelInput,
-         api_key: Optional[str] = None,
-         set_span_attributes: Optional[SetSpanAttributesFn] = None,
-     ) -> None:
-         from openai import AsyncOpenAI
-
-         super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-         self.client = AsyncOpenAI(api_key=api_key)
-         self.model_name = model.name
-
-     async def chat_completion_create(
-         self,
-         messages: list[
-             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-         ],
-         tools: list[JSONScalarType],
-         **invocation_parameters: Any,
-     ) -> AsyncIterator[ChatCompletionChunk]:
-         from openai import NOT_GIVEN
-         from openai.types.chat import ChatCompletionStreamOptionsParam
-
-         # Convert standard messages to OpenAI messages
-         openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-         tool_call_ids: dict[int, str] = {}
-         token_usage: Optional["CompletionUsage"] = None
-         async for chunk in await self.client.chat.completions.create(
-             messages=openai_messages,
-             model=self.model_name,
-             stream=True,
-             stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
-             tools=tools or NOT_GIVEN,
-             **invocation_parameters,
-         ):
-             if (usage := chunk.usage) is not None:
-                 token_usage = usage
-                 continue
-             choice = chunk.choices[0]
-             delta = choice.delta
-             if choice.finish_reason is None:
-                 if isinstance(chunk_content := delta.content, str):
-                     text_chunk = TextChunk(content=chunk_content)
-                     yield text_chunk
-                 if (tool_calls := delta.tool_calls) is not None:
-                     for tool_call_index, tool_call in enumerate(tool_calls):
-                         tool_call_id = (
-                             tool_call.id
-                             if tool_call.id is not None
-                             else tool_call_ids[tool_call_index]
-                         )
-                         tool_call_ids[tool_call_index] = tool_call_id
-                         if (function := tool_call.function) is not None:
-                             tool_call_chunk = ToolCallChunk(
-                                 id=tool_call_id,
-                                 function=FunctionCallChunk(
-                                     name=function.name or "",
-                                     arguments=function.arguments or "",
-                                 ),
-                             )
-                             yield tool_call_chunk
-         if token_usage is not None and self._set_span_attributes:
-             self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
-
-     def to_openai_chat_completion_param(
-         self,
-         role: ChatCompletionMessageRole,
-         content: JSONScalarType,
-         tool_call_id: Optional[str] = None,
-         tool_calls: Optional[list[JSONScalarType]] = None,
-     ) -> "ChatCompletionMessageParam":
-         from openai.types.chat import (
-             ChatCompletionAssistantMessageParam,
-             ChatCompletionSystemMessageParam,
-             ChatCompletionToolMessageParam,
-             ChatCompletionUserMessageParam,
-         )
-
-         if role is ChatCompletionMessageRole.USER:
-             return ChatCompletionUserMessageParam(
-                 {
-                     "content": content,
-                     "role": "user",
-                 }
-             )
-         if role is ChatCompletionMessageRole.SYSTEM:
-             return ChatCompletionSystemMessageParam(
-                 {
-                     "content": content,
-                     "role": "system",
-                 }
-             )
-         if role is ChatCompletionMessageRole.AI:
-             if tool_calls is None:
-                 return ChatCompletionAssistantMessageParam(
-                     {
-                         "content": content,
-                         "role": "assistant",
-                     }
-                 )
-             else:
-                 return ChatCompletionAssistantMessageParam(
-                     {
-                         "content": content,
-                         "role": "assistant",
-                         "tool_calls": [
-                             self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
-                         ],
-                     }
-                 )
-         if role is ChatCompletionMessageRole.TOOL:
-             if tool_call_id is None:
-                 raise ValueError("tool_call_id is required for tool messages")
-             return ChatCompletionToolMessageParam(
-                 {"content": content, "role": "tool", "tool_call_id": tool_call_id}
-             )
-         assert_never(role)
-
-     def to_openai_tool_call_param(
-         self,
-         tool_call: JSONScalarType,
-     ) -> "ChatCompletionMessageToolCallParam":
-         from openai.types.chat import ChatCompletionMessageToolCallParam
-
-         return ChatCompletionMessageToolCallParam(
-             id=tool_call.get("id", ""),
-             function={
-                 "name": tool_call.get("function", {}).get("name", ""),
-                 "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
-             },
-             type="function",
-         )
-
-     @staticmethod
-     def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
-         yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
-         yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-         yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
-
-
- @register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
- class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-     def __init__(
-         self,
-         model: GenerativeModelInput,
-         api_key: Optional[str] = None,
-         set_span_attributes: Optional[SetSpanAttributesFn] = None,
-     ):
-         from openai import AsyncAzureOpenAI
-
-         super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-         if model.endpoint is None or model.api_version is None:
-             raise ValueError("endpoint and api_version are required for Azure OpenAI models")
-         self.client = AsyncAzureOpenAI(
-             api_key=api_key,
-             azure_endpoint=model.endpoint,
-             api_version=model.api_version,
-         )
-
-
- @register_llm_client(GenerativeProviderKey.ANTHROPIC)
- class AnthropicStreamingClient(PlaygroundStreamingClient):
-     def __init__(
-         self,
-         model: GenerativeModelInput,
-         api_key: Optional[str] = None,
-         set_span_attributes: Optional[SetSpanAttributesFn] = None,
-     ) -> None:
-         import anthropic
-
-         super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-         self.client = anthropic.AsyncAnthropic(api_key=api_key)
-         self.model_name = model.name
-
-     async def chat_completion_create(
-         self,
-         messages: list[
-             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-         ],
-         tools: list[JSONScalarType],
-         **invocation_parameters: Any,
-     ) -> AsyncIterator[ChatCompletionChunk]:
-         import anthropic.lib.streaming as anthropic_streaming
-         import anthropic.types as anthropic_types
-
-         anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
-
-         anthropic_params = {
-             "messages": anthropic_messages,
-             "model": self.model_name,
-             "system": system_prompt,
-             "max_tokens": 1024,
-             **invocation_parameters,
-         }
-         async with self.client.messages.stream(**anthropic_params) as stream:
-             async for event in stream:
-                 if isinstance(event, anthropic_types.RawMessageStartEvent):
-                     if self._set_span_attributes:
-                         self._set_span_attributes(
-                             {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
-                         )
-                 elif isinstance(event, anthropic_streaming.TextEvent):
-                     yield TextChunk(content=event.text)
-                 elif isinstance(event, anthropic_streaming.MessageStopEvent):
-                     if self._set_span_attributes:
-                         self._set_span_attributes(
-                             {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
-                         )
-                 elif isinstance(
-                     event,
-                     (
-                         anthropic_types.RawContentBlockStartEvent,
-                         anthropic_types.RawContentBlockDeltaEvent,
-                         anthropic_types.RawMessageDeltaEvent,
-                         anthropic_streaming.ContentBlockStopEvent,
-                     ),
-                 ):
-                     # event types emitted by the stream that don't contain useful information
-                     pass
-                 elif isinstance(event, anthropic_streaming.InputJsonEvent):
-                     raise NotImplementedError
-                 else:
-                     assert_never(event)
-
-     def _build_anthropic_messages(
-         self,
-         messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-     ) -> tuple[list["MessageParam"], str]:
-         anthropic_messages: list["MessageParam"] = []
-         system_prompt = ""
-         for role, content, _tool_call_id, _tool_calls in messages:
-             if role == ChatCompletionMessageRole.USER:
-                 anthropic_messages.append({"role": "user", "content": content})
-             elif role == ChatCompletionMessageRole.AI:
-                 anthropic_messages.append({"role": "assistant", "content": content})
-             elif role == ChatCompletionMessageRole.SYSTEM:
-                 system_prompt += content + "\n"
-             elif role == ChatCompletionMessageRole.TOOL:
-                 raise NotImplementedError
-             else:
-                 assert_never(role)
-
-         return anthropic_messages, system_prompt
+ DatasetExampleID: TypeAlias = GlobalID
+ ChatCompletionResult: TypeAlias = tuple[
+     DatasetExampleID, Optional[models.Span], models.ExperimentRun
+ ]
+ PLAYGROUND_PROJECT_NAME = "playground"


  @strawberry.type
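
Note on the hunk above: the module header is rebuilt around the extracted helpers, and span construction now goes through streaming_llm_span from playground_spans.py. The hunks below read span.status_message only after the async-with block exits, which implies the context manager catches streaming errors and records them on the span rather than re-raising; a minimal sketch under that assumption (the names here are illustrative, the real helper is in playground_spans.py, which this diff does not show):

```python
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Mapping, Optional


class SpanSketch:
    """Stand-in for the object yielded by streaming_llm_span."""

    def __init__(self) -> None:
        self.status_message: Optional[str] = None
        self.response_chunks: list[Any] = []

    def add_response_chunk(self, chunk: Any) -> None:
        self.response_chunks.append(chunk)

    def set_attributes(self, attributes: Mapping[str, Any]) -> None:
        pass  # the real span would record OpenInference attributes here


@asynccontextmanager
async def streaming_llm_span_sketch(**_: Any) -> AsyncIterator[SpanSketch]:
    span = SpanSketch()
    try:
        yield span
    except Exception as error:
        # chat_completion inspects span.status_message *after* the
        # async-with block, so the helper presumably records streaming
        # errors on the span instead of letting them propagate
        span.status_message = str(error)
```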
@@ -422,15 +88,15 @@ class Subscription:
      async def chat_completion(
          self, info: Info[Context, None], input: ChatCompletionInput
      ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-         # Determine which LLM client to use based on provider_key
          provider_key = input.model.provider_key
-         if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None:
+         llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+         if llm_client_class is None:
              raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
          llm_client = llm_client_class(
              model=input.model,
              api_key=input.api_key,
-             set_span_attributes=lambda attrs: attributes.update(attrs),
          )
+
          messages = [
              (
                  message.role,
@@ -441,70 +107,129 @@ class Subscription:
              for message in input.messages
          ]
          if template_options := input.template:
-             messages = list(_formatted_messages(messages, template_options))
-         invocation_parameters = jsonify(input.invocation_parameters)
-         attributes = dict(
-             chain(
-                 _llm_span_kind(),
-                 _llm_model_name(input.model.name),
-                 _llm_tools(input.tools or []),
-                 _llm_input_messages(messages),
-                 _llm_invocation_parameters(invocation_parameters),
-                 _input_value_and_mime_type(input),
+             messages = list(
+                 _formatted_messages(
+                     messages=messages,
+                     template_language=template_options.language,
+                     template_variables=template_options.variables,
+                 )
              )
+         invocation_parameters = llm_client.construct_invocation_parameters(
+             input.invocation_parameters
          )
-         status_code: StatusCode
-         status_message = ""
-         events: list[SpanEvent] = []
-         start_time: datetime
-         end_time: datetime
-         response_chunks = []
-         text_chunks: list[TextChunk] = []
-         tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
-         try:
-             start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+         async with streaming_llm_span(
+             input=input,
+             messages=messages,
+             invocation_parameters=invocation_parameters,
+         ) as span:
              async for chunk in llm_client.chat_completion_create(
-                 messages=messages,
-                 tools=input.tools or [],
-                 **invocation_parameters,
+                 messages=messages, tools=input.tools or [], **invocation_parameters
              ):
-                 response_chunks.append(chunk)
-                 if isinstance(chunk, TextChunk):
-                     yield chunk
-                     text_chunks.append(chunk)
-                 elif isinstance(chunk, ToolCallChunk):
-                     yield chunk
-                     tool_call_chunks[chunk.id].append(chunk)
-                 else:
-                     assert_never(chunk)
-             status_code = StatusCode.OK
-         except Exception as error:
-             end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
-             status_code = StatusCode.ERROR
-             status_message = str(error)
-             events.append(
-                 SpanException(
-                     timestamp=end_time,
-                     message=status_message,
-                     exception_type=type(error).__name__,
-                     exception_escaped=False,
-                     exception_stacktrace=format_exc(),
+                 span.add_response_chunk(chunk)
+                 yield chunk
+             span.set_attributes(llm_client.attributes)
+         if span.status_message is not None:
+             yield ChatCompletionSubscriptionError(message=span.status_message)
+         async with info.context.db() as session:
+             if (
+                 playground_project_id := await session.scalar(
+                     select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                  )
-             )
-             yield ChatCompletionSubscriptionError(message=status_message)
-         else:
-             end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
-             attributes.update(
-                 chain(
-                     _output_value_and_mime_type(response_chunks),
-                     _llm_output_messages(text_chunks, tool_call_chunks),
+             ) is None:
+                 playground_project_id = await session.scalar(
+                     insert(models.Project)
+                     .returning(models.Project.id)
+                     .values(
+                         name=PLAYGROUND_PROJECT_NAME,
+                         description="Traces from prompt playground",
+                     )
                  )
+             db_trace = get_db_trace(span, playground_project_id)
+             db_span = get_db_span(span, db_trace)
+             session.add(db_span)
+             await session.flush()
+         info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
+         yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))
+
+     @strawberry.subscription
+     async def chat_completion_over_dataset(
+         self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
+     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+         provider_key = input.model.provider_key
+         llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+         if llm_client_class is None:
+             raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
+
+         dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
+         version_id = (
+             from_global_id_with_expected_type(
+                 global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
              )
-         prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-         completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-         trace_id = _generate_trace_id()
-         span_id = _generate_span_id()
+             if input.dataset_version_id
+             else None
+         )
          async with info.context.db() as session:
+             if (
+                 dataset := await session.scalar(
+                     select(models.Dataset).where(models.Dataset.id == dataset_id)
+                 )
+             ) is None:
+                 raise NotFound(f"Could not find dataset with ID {dataset_id}")
+             if version_id is None:
+                 if (
+                     resolved_version_id := await session.scalar(
+                         select(models.DatasetVersion.id)
+                         .where(models.DatasetVersion.dataset_id == dataset_id)
+                         .order_by(models.DatasetVersion.id.desc())
+                         .limit(1)
+                     )
+                 ) is None:
+                     raise NotFound(f"No versions found for dataset with ID {dataset_id}")
+             else:
+                 if (
+                     resolved_version_id := await session.scalar(
+                         select(models.DatasetVersion.id).where(
+                             and_(
+                                 models.DatasetVersion.dataset_id == dataset_id,
+                                 models.DatasetVersion.id == version_id,
+                             )
+                         )
+                     )
+                 ) is None:
+                     raise NotFound(f"Could not find dataset version with ID {version_id}")
+             revision_ids = (
+                 select(func.max(models.DatasetExampleRevision.id))
+                 .join(models.DatasetExample)
+                 .where(
+                     and_(
+                         models.DatasetExample.dataset_id == dataset_id,
+                         models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
+                     )
+                 )
+                 .group_by(models.DatasetExampleRevision.dataset_example_id)
+             )
+             if not (
+                 revisions := [
+                     rev
+                     async for rev in await session.stream_scalars(
+                         select(models.DatasetExampleRevision)
+                         .where(
+                             and_(
+                                 models.DatasetExampleRevision.id.in_(revision_ids),
+                                 models.DatasetExampleRevision.revision_kind != "DELETE",
+                             )
+                         )
+                         .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
+                         .options(
+                             load_only(
+                                 models.DatasetExampleRevision.dataset_example_id,
+                                 models.DatasetExampleRevision.input,
+                             )
+                         )
+                     )
+                 ]
+             ):
+                 raise NotFound("No examples found for the given dataset and version")
              if (
                  playground_project_id := await session.scalar(
                      select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
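
The revision_ids subquery in the hunk above is the usual groupwise-max pattern: for each dataset example, take the highest revision id at or below the resolved version, then let the outer query drop examples whose latest revision is a DELETE. A pure-Python restatement of that selection logic, for clarity only (the tuple shape is a hypothetical stand-in, not phoenix code):

```python
def latest_revision_ids(
    revisions: list[tuple[int, int, int, str]],  # (revision_id, example_id, version_id, kind)
    max_version_id: int,
) -> set[int]:
    """For each example, keep only its highest revision id at or below
    max_version_id, then exclude examples whose latest revision is a DELETE."""
    latest: dict[int, tuple[int, str]] = {}
    for revision_id, example_id, version_id, kind in revisions:
        if version_id > max_version_id:
            continue  # revision is newer than the requested dataset version
        if example_id not in latest or revision_id > latest[example_id][0]:
            latest[example_id] = (revision_id, kind)
    return {rev_id for rev_id, kind in latest.values() if kind != "DELETE"}
```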
@@ -518,138 +243,208 @@ class Subscription:
                          description="Traces from prompt playground",
                      )
                  )
-             playground_trace = models.Trace(
-                 project_rowid=playground_project_id,
-                 trace_id=trace_id,
-                 start_time=start_time,
-                 end_time=end_time,
+             experiment = models.Experiment(
+                 dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
+                 dataset_version_id=resolved_version_id,
+                 name=input.experiment_name or _default_playground_experiment_name(),
+                 description=input.experiment_description
+                 or _default_playground_experiment_description(dataset_name=dataset.name),
+                 repetitions=1,
+                 metadata_=input.experiment_metadata
+                 or _default_playground_experiment_metadata(
+                     dataset_name=dataset.name,
+                     dataset_id=input.dataset_id,
+                     version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
+                 ),
+                 project_name=PLAYGROUND_PROJECT_NAME,
              )
-             playground_span = models.Span(
-                 trace_rowid=playground_trace.id,
-                 span_id=span_id,
-                 parent_id=None,
-                 name="ChatCompletion",
-                 span_kind=LLM,
-                 start_time=start_time,
-                 end_time=end_time,
-                 attributes=unflatten(attributes.items()),
-                 events=[_serialize_event(event) for event in events],
-                 status_code=status_code.name,
-                 status_message=status_message,
-                 cumulative_error_count=int(status_code is StatusCode.ERROR),
-                 cumulative_llm_token_count_prompt=prompt_tokens,
-                 cumulative_llm_token_count_completion=completion_tokens,
-                 llm_token_count_prompt=prompt_tokens,
-                 llm_token_count_completion=completion_tokens,
-                 trace=playground_trace,
-             )
-             session.add(playground_trace)
-             session.add(playground_span)
+             session.add(experiment)
              await session.flush()
-             yield FinishedChatCompletion(span=to_gql_span(playground_span))
-             info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
-
-
- def _llm_span_kind() -> Iterator[tuple[str, Any]]:
-     yield OPENINFERENCE_SPAN_KIND, LLM
-
-
- def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
-     yield LLM_MODEL_NAME, model_name
-
-
- def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
-     yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
-
-
- def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
-     for tool_index, tool in enumerate(tools):
-         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
-
-
- def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
-     assert (api_key := "api_key") in (input_data := jsonify(input))
-     input_data = {k: v for k, v in input_data.items() if k != api_key}
-     assert api_key not in input_data
-     yield INPUT_MIME_TYPE, JSON
-     yield INPUT_VALUE, safe_json_dumps(input_data)
-
-
- def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
-     yield OUTPUT_MIME_TYPE, JSON
-     yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
+             yield ChatCompletionSubscriptionExperiment(
+                 experiment=to_gql_experiment(experiment)
+             )  # eagerly yields experiment so it can be linked by consumers of the subscription
+
+         results_queue: Queue[ChatCompletionResult] = Queue()
+         chat_completion_streams = [
+             _stream_chat_completion_over_dataset_example(
+                 input=input,
+                 llm_client_class=llm_client_class,
+                 revision=revision,
+                 results_queue=results_queue,
+                 experiment_id=experiment.id,
+                 project_id=playground_project_id,
+             )
+             for revision in revisions
+         ]
+         stream_to_async_tasks: dict[
+             AsyncIterator[ChatCompletionSubscriptionPayload],
+             Task[ChatCompletionSubscriptionPayload],
+         ] = {iterator: _create_task_with_timeout(iterator) for iterator in chat_completion_streams}
+         batch_size = 10
+         while stream_to_async_tasks:
+             async_tasks_to_run = [task for task in stream_to_async_tasks.values()]
+             completed_tasks, _ = await wait(async_tasks_to_run, return_when=FIRST_COMPLETED)
+             for task in completed_tasks:
+                 iterator = next(it for it, t in stream_to_async_tasks.items() if t == task)
+                 try:
+                     yield task.result()
+                 except (StopAsyncIteration, asyncio.TimeoutError):
+                     del stream_to_async_tasks[iterator]  # removes exhausted iterator
+                 except Exception as error:
+                     del stream_to_async_tasks[iterator]  # removes failed iterator
+                     logger.exception(error)
+                 else:
+                     stream_to_async_tasks[iterator] = _create_task_with_timeout(iterator)
+             if results_queue.qsize() >= batch_size:
+                 result_iterator = _chat_completion_result_payloads(
+                     db=info.context.db, results=_drain_no_wait(results_queue)
+                 )
+                 stream_to_async_tasks[result_iterator] = _create_task_with_timeout(
+                     result_iterator
+                 )
+         if remaining_results := await _drain(results_queue):
+             async for result_payload in _chat_completion_result_payloads(
+                 db=info.context.db, results=remaining_results
+             ):
+                 yield result_payload
+
+
+ async def _stream_chat_completion_over_dataset_example(
+     *,
+     input: ChatCompletionOverDatasetInput,
+     llm_client_class: type["PlaygroundStreamingClient"],
+     revision: models.DatasetExampleRevision,
+     results_queue: Queue[ChatCompletionResult],
+     experiment_id: int,
+     project_id: int,
+ ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+     example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
+     llm_client = llm_client_class(
+         model=input.model,
+         api_key=input.api_key,
+     )
+     invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
+     messages = [
+         (
+             message.role,
+             message.content,
+             message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+             message.tool_calls if isinstance(message.tool_calls, list) else None,
+         )
+         for message in input.messages
+     ]
+     try:
+         format_start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+         messages = list(
+             _formatted_messages(
+                 messages=messages,
+                 template_language=input.template_language,
+                 template_variables=revision.input,
+             )
+         )
+     except TemplateFormatterError as error:
+         format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+         yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
+         await results_queue.put(
+             (
+                 example_id,
+                 None,
+                 models.ExperimentRun(
+                     experiment_id=experiment_id,
+                     dataset_example_id=revision.dataset_example_id,
+                     trace_id=None,
+                     output={},
+                     repetition_number=1,
+                     start_time=format_start_time,
+                     end_time=format_end_time,
+                     error=str(error),
+                     trace=None,
+                 ),
+             )
+         )
+         return
+     async with streaming_llm_span(
+         input=input,
+         messages=messages,
+         invocation_parameters=invocation_parameters,
+     ) as span:
+         async for chunk in llm_client.chat_completion_create(
+             messages=messages, tools=input.tools or [], **invocation_parameters
+         ):
+             span.add_response_chunk(chunk)
+             chunk.dataset_example_id = example_id
+             yield chunk
+         span.set_attributes(llm_client.attributes)
+     db_trace = get_db_trace(span, project_id)
+     db_span = get_db_span(span, db_trace)
+     db_run = get_db_experiment_run(
+         db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
+     )
+     await results_queue.put((example_id, db_span, db_run))
+     if span.status_message is not None:
+         yield ChatCompletionSubscriptionError(
+             message=span.status_message, dataset_example_id=example_id
+         )


- def _llm_input_messages(
-     messages: Iterable[
-         tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-     ],
- ) -> Iterator[tuple[str, Any]]:
-     for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
-         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
-         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
-         if tool_calls is not None:
-             for tool_call_index, tool_call in enumerate(tool_calls):
-                 yield (
-                     f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                     tool_call["function"]["name"],
-                 )
-                 if arguments := tool_call["function"]["arguments"]:
-                     yield (
-                         f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                         safe_json_dumps(jsonify(arguments)),
-                     )
+ async def _chat_completion_result_payloads(
+     *,
+     db: DbSessionFactory,
+     results: Sequence[ChatCompletionResult],
+ ) -> AsyncIterator[ChatCompletionSubscriptionResult]:
+     if not results:
+         return
+     async with db() as session:
+         for _, span, run in results:
+             if span:
+                 session.add(span)
+             session.add(run)
+         await session.flush()
+     for example_id, span, run in results:
+         yield ChatCompletionSubscriptionResult(
+             span=to_gql_span(span) if span else None,
+             experiment_run=to_gql_experiment_run(run),
+             dataset_example_id=example_id,
+         )


- def _llm_output_messages(
-     text_chunks: list[TextChunk],
-     tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
- ) -> Iterator[tuple[str, Any]]:
-     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
-     if content := "".join(chunk.content for chunk in text_chunks):
-         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
-     for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
-         if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
-             yield (
-                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                 name,
-             )
-         if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
-             yield (
-                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                 arguments,
-             )
+ def _create_task_with_timeout(
+     iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 60
+ ) -> Task[GenericType]:
+     return create_task(wait_for(_as_coroutine(iterable), timeout=timeout_in_seconds))


- def _generate_trace_id() -> str:
-     """
-     Generates a random trace ID in hexadecimal format.
-     """
-     return _hex(DefaultOTelIDGenerator().generate_trace_id())
+ async def _drain(queue: Queue[GenericType]) -> list[GenericType]:
+     values: list[GenericType] = []
+     while not queue.empty():
+         values.append(await queue.get())
+     return values


- def _generate_span_id() -> str:
-     """
-     Generates a random span ID in hexadecimal format.
-     """
-     return _hex(DefaultOTelIDGenerator().generate_span_id())
+ def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]:
+     values: list[GenericType] = []
+     while True:
+         try:
+             values.append(queue.get_nowait())
+         except QueueEmpty:
+             break
+     return values


- def _hex(number: int) -> str:
-     """
-     Converts an integer to a hexadecimal string.
-     """
-     return hex(number)[2:]
+ async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
+     return await iterable.__anext__()


  def _formatted_messages(
-     messages: Iterable[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-     template_options: TemplateOptions,
+     *,
+     messages: Iterable[ChatCompletionMessage],
+     template_language: TemplateLanguage,
+     template_variables: Mapping[str, Any],
  ) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
      """
      Formats the messages using the given template options.
      """
-     template_formatter = _template_formatter(template_language=template_options.language)
+     template_formatter = _template_formatter(template_language=template_language)
      (
          roles,
          templates,
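
The task pool in chat_completion_over_dataset above is a fan-in over async iterators: one in-flight __anext__ task per stream (via _create_task_with_timeout), wait(FIRST_COMPLETED), yield whatever finished, then re-arm the winning stream. Because the per-task wait_for applies to each item, a stream that stalls past the 60-second timeout raises asyncio.TimeoutError and is dropped from the pool like an exhausted one. Stripped of the timeout and the result batching, the core pattern looks like this (a self-contained sketch, not phoenix code):

```python
import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")


async def merge(*sources: AsyncIterator[T]) -> AsyncIterator[T]:
    # one in-flight __anext__ task per source, re-armed after each yield
    tasks: dict["asyncio.Task[T]", AsyncIterator[T]] = {
        asyncio.create_task(source.__anext__()): source for source in sources
    }
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            source = tasks.pop(task)
            try:
                item = task.result()
            except StopAsyncIteration:
                continue  # this source is exhausted; drop it from the pool
            yield item
            tasks[asyncio.create_task(source.__anext__())] = source


async def demo() -> None:
    async def count(tag: str, n: int) -> AsyncIterator[str]:
        for i in range(n):
            await asyncio.sleep(0.01)
            yield f"{tag}-{i}"

    async for item in merge(count("a", 3), count("b", 2)):
        print(item)


asyncio.run(demo())
```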
@@ -657,7 +452,7 @@ def _formatted_messages(
          tool_calls,
      ) = zip(*messages)
      formatted_templates = map(
-         lambda template: template_formatter.format(template, **template_options.variables),
+         lambda template: template_formatter.format(template, **template_variables),
          templates,
      )
      formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
@@ -675,36 +470,24 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
      assert_never(template_language)


- def _serialize_event(event: SpanEvent) -> dict[str, Any]:
-     """
-     Serializes a SpanEvent to a dictionary.
-     """
-     return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+ def _default_playground_experiment_name() -> str:
+     return "playground-experiment"


- JSON = OpenInferenceMimeTypeValues.JSON.value
+ def _default_playground_experiment_description(dataset_name: str) -> str:
+     return f'Playground experiment for dataset "{dataset_name}"'

- LLM = OpenInferenceSpanKindValues.LLM.value

- OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
- INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
- INPUT_VALUE = SpanAttributes.INPUT_VALUE
- OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
- OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
- LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
- LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
- LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
- LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
- LLM_TOOLS = SpanAttributes.LLM_TOOLS
- LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
- LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
- LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+ def _default_playground_experiment_metadata(
+     dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
+ ) -> dict[str, Any]:
+     return {
+         "dataset_name": dataset_name,
+         "dataset_id": str(dataset_id),
+         "dataset_version_id": str(version_id),
+     }

- MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
- MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
- MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

- TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
- TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
-
- TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
+ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+ LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+ LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
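
Template handling also changes: _formatted_messages now takes the language and variables directly, and formatting failures surface as TemplateFormatterError, which file 35 above (phoenix/utilities/template_formatters.py, +11 -1) adds in this release. The only formatter API visible in this diff is formatter.format(template, **variables); a rough illustration of the f-string flavor under that assumption (the real formatter internals are not shown here):

```python
class TemplateFormatterError(Exception):
    """Raised when a template references a variable that was not supplied
    (the real class lives in phoenix/utilities/template_formatters.py)."""


class FStringFormatterSketch:
    def format(self, template: str, **variables: object) -> str:
        # hypothetical stand-in for FStringTemplateFormatter.format
        try:
            return template.format(**variables)
        except KeyError as error:
            raise TemplateFormatterError(f"Missing template variable: {error}") from None


formatter = FStringFormatterSketch()
print(formatter.format("Hello, {name}!", name="world"))  # -> Hello, world!
```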