arize-phoenix 5.5.1__py3-none-any.whl → 5.6.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

This version of arize-phoenix has been flagged as potentially problematic.

Files changed (172)
  1. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/METADATA +8 -11
  2. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/RECORD +171 -171
  3. phoenix/config.py +8 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  72. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  73. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  74. phoenix/server/api/input_types/ClusterInput.py +2 -2
  75. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  76. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  77. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  78. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  79. phoenix/server/api/input_types/Granularity.py +1 -1
  80. phoenix/server/api/input_types/InvocationParameters.py +2 -2
  81. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  82. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  83. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  84. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  85. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  86. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  87. phoenix/server/api/mutations/user_mutations.py +4 -4
  88. phoenix/server/api/openapi/schema.py +2 -2
  89. phoenix/server/api/queries.py +20 -20
  90. phoenix/server/api/routers/oauth2.py +4 -4
  91. phoenix/server/api/routers/v1/datasets.py +22 -36
  92. phoenix/server/api/routers/v1/evaluations.py +6 -5
  93. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  94. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  95. phoenix/server/api/routers/v1/experiments.py +4 -4
  96. phoenix/server/api/routers/v1/spans.py +13 -12
  97. phoenix/server/api/routers/v1/traces.py +5 -5
  98. phoenix/server/api/routers/v1/utils.py +5 -5
  99. phoenix/server/api/subscriptions.py +289 -167
  100. phoenix/server/api/types/AnnotationSummary.py +3 -3
  101. phoenix/server/api/types/Cluster.py +8 -7
  102. phoenix/server/api/types/Dataset.py +5 -4
  103. phoenix/server/api/types/Dimension.py +3 -3
  104. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  105. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  106. phoenix/server/api/types/EvaluationSummary.py +3 -3
  107. phoenix/server/api/types/Event.py +7 -7
  108. phoenix/server/api/types/Experiment.py +3 -3
  109. phoenix/server/api/types/ExperimentComparison.py +2 -4
  110. phoenix/server/api/types/Inferences.py +9 -8
  111. phoenix/server/api/types/InferencesRole.py +2 -2
  112. phoenix/server/api/types/Model.py +2 -2
  113. phoenix/server/api/types/Project.py +11 -18
  114. phoenix/server/api/types/Segments.py +3 -3
  115. phoenix/server/api/types/Span.py +8 -7
  116. phoenix/server/api/types/TimeSeries.py +8 -7
  117. phoenix/server/api/types/Trace.py +2 -2
  118. phoenix/server/api/types/UMAPPoints.py +6 -6
  119. phoenix/server/api/types/User.py +3 -3
  120. phoenix/server/api/types/node.py +1 -3
  121. phoenix/server/api/types/pagination.py +4 -4
  122. phoenix/server/api/utils.py +2 -4
  123. phoenix/server/app.py +16 -25
  124. phoenix/server/bearer_auth.py +4 -10
  125. phoenix/server/dml_event.py +3 -3
  126. phoenix/server/dml_event_handler.py +10 -24
  127. phoenix/server/grpc_server.py +3 -2
  128. phoenix/server/jwt_store.py +22 -21
  129. phoenix/server/main.py +3 -3
  130. phoenix/server/oauth2.py +3 -2
  131. phoenix/server/rate_limiters.py +5 -8
  132. phoenix/server/static/.vite/manifest.json +31 -31
  133. phoenix/server/static/assets/components-C70HJiXz.js +1612 -0
  134. phoenix/server/static/assets/{index-BHfTZ6x_.js → index-DLe1Oo3l.js} +2 -2
  135. phoenix/server/static/assets/{pages-aAez_Ntk.js → pages-C8-Sl7JI.js} +269 -434
  136. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  137. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  138. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  139. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  140. phoenix/server/thread_server.py +1 -1
  141. phoenix/server/types.py +17 -29
  142. phoenix/services.py +4 -3
  143. phoenix/session/client.py +12 -24
  144. phoenix/session/data_extractor.py +3 -3
  145. phoenix/session/evaluation.py +1 -2
  146. phoenix/session/session.py +11 -20
  147. phoenix/trace/attributes.py +16 -28
  148. phoenix/trace/dsl/filter.py +17 -21
  149. phoenix/trace/dsl/helpers.py +3 -3
  150. phoenix/trace/dsl/query.py +13 -22
  151. phoenix/trace/fixtures.py +11 -17
  152. phoenix/trace/otel.py +5 -15
  153. phoenix/trace/projects.py +3 -2
  154. phoenix/trace/schemas.py +2 -2
  155. phoenix/trace/span_evaluations.py +9 -8
  156. phoenix/trace/span_json_decoder.py +3 -3
  157. phoenix/trace/span_json_encoder.py +2 -2
  158. phoenix/trace/trace_dataset.py +6 -5
  159. phoenix/trace/utils.py +6 -6
  160. phoenix/utilities/deprecation.py +3 -2
  161. phoenix/utilities/error_handling.py +3 -2
  162. phoenix/utilities/json.py +2 -1
  163. phoenix/utilities/logging.py +2 -2
  164. phoenix/utilities/project.py +1 -1
  165. phoenix/utilities/re.py +3 -4
  166. phoenix/utilities/template_formatters.py +5 -4
  167. phoenix/version.py +1 -1
  168. phoenix/server/static/assets/components-mVBxvljU.js +0 -1428
  169. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/WHEEL +0 -0
  170. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/entry_points.txt +0 -0
  171. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/IP_NOTICE +0 -0
  172. {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/LICENSE +0 -0
--- a/phoenix/server/api/subscriptions.py
+++ b/phoenix/server/api/subscriptions.py
@@ -1,26 +1,13 @@
 import json
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import fields
-from datetime import datetime
+from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Mapping
+from dataclasses import asdict
+from datetime import datetime, timezone
 from enum import Enum
 from itertools import chain
-from typing import (
-    TYPE_CHECKING,
-    Annotated,
-    Any,
-    AsyncIterator,
-    Callable,
-    DefaultDict,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-)
+from traceback import format_exc
+from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast

 import strawberry
 from openinference.instrumentation import safe_json_dumps
@@ -32,9 +19,7 @@ from openinference.semconv.trace import (
     ToolAttributes,
     ToolCallAttributes,
 )
-from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
 from sqlalchemy import insert, select
 from strawberry import UNSET
@@ -42,8 +27,10 @@ from strawberry.scalars import JSON as JSONScalarType
 from strawberry.types import Info
 from typing_extensions import TypeAlias, assert_never

+from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
 from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -51,6 +38,10 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 from phoenix.server.api.types.Span import Span, to_gql_span
 from phoenix.server.dml_event import SpanInsertEvent
 from phoenix.trace.attributes import unflatten
+from phoenix.trace.schemas import (
+    SpanEvent,
+    SpanException,
+)
 from phoenix.utilities.json import jsonify
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
@@ -61,11 +52,15 @@ from phoenix.utilities.template_formatters import (
 if TYPE_CHECKING:
     from anthropic.types import MessageParam
     from openai.types import CompletionUsage
-    from openai.types.chat import ChatCompletionMessageParam
+    from openai.types.chat import (
+        ChatCompletionMessageParam,
+        ChatCompletionMessageToolCallParam,
+    )

 PLAYGROUND_PROJECT_NAME = "playground"

 ToolCallID: TypeAlias = str
+SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]


 @strawberry.enum
@@ -97,13 +92,20 @@ class ToolCallChunk:
     function: FunctionCallChunk


+@strawberry.type
+class ChatCompletionSubscriptionError:
+    message: str
+
+
 @strawberry.type
 class FinishedChatCompletion:
     span: Span


+ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
+
 ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-    Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
+    Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
     strawberry.union("ChatCompletionSubscriptionPayload"),
 ]

@@ -121,23 +123,23 @@ class GenerativeModelInput:

 @strawberry.input
 class ChatCompletionInput:
-    messages: List[ChatCompletionMessageInput]
+    messages: list[ChatCompletionMessageInput]
     model: GenerativeModelInput
-    invocation_parameters: InvocationParameters
-    tools: Optional[List[JSONScalarType]] = UNSET
+    invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
+    tools: Optional[list[JSONScalarType]] = UNSET
     template: Optional[TemplateOptions] = UNSET
     api_key: Optional[str] = strawberry.field(default=None)


-PLAYGROUND_STREAMING_CLIENT_REGISTRY: Dict[
-    GenerativeProviderKey, Type["PlaygroundStreamingClient"]
+PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
+    GenerativeProviderKey, type["PlaygroundStreamingClient"]
 ] = {}


 def register_llm_client(
     provider_key: GenerativeProviderKey,
-) -> Callable[[Type["PlaygroundStreamingClient"]], Type["PlaygroundStreamingClient"]]:
-    def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:
+) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
+    def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
         PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
         return cls

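Aside: register_llm_client above is a decorator factory — applying @register_llm_client(provider_key) to a class records it in PLAYGROUND_STREAMING_CLIENT_REGISTRY and hands the class back unchanged, so the subscription can look clients up by provider at request time. A minimal, self-contained sketch of the same pattern (generic names, not the phoenix types):

    from typing import Callable

    REGISTRY: dict[str, type] = {}

    def register(key: str) -> Callable[[type], type]:
        # a decorator factory: returns a class decorator that records
        # cls under key and returns the class unchanged
        def decorator(cls: type) -> type:
            REGISTRY[key] = cls
            return cls
        return decorator

    @register("openai")
    class OpenAIClient:
        pass

    assert REGISTRY["openai"] is OpenAIClient  # lookup happens at request time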
@@ -145,45 +147,56 @@ def register_llm_client(


 class PlaygroundStreamingClient(ABC):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ...
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
+        self._set_span_attributes = set_span_attributes

     @abstractmethod
     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         # a yield statement is needed to satisfy the type-checker
         # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
         yield TextChunk(content="")

-    @property
-    @abstractmethod
-    def attributes(self) -> Dict[str, Any]: ...
-

 @register_llm_client(GenerativeProviderKey.OPENAI)
 class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
         from openai import AsyncOpenAI

+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = AsyncOpenAI(api_key=api_key)
         self.model_name = model.name
-        self._attributes: Dict[str, Any] = {}

     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         from openai import NOT_GIVEN
         from openai.types.chat import ChatCompletionStreamOptionsParam

         # Convert standard messages to OpenAI messages
         openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-        tool_call_ids: Dict[int, str] = {}
+        tool_call_ids: dict[int, str] = {}
         token_usage: Optional["CompletionUsage"] = None
         async for chunk in await self.client.chat.completions.create(
             messages=openai_messages,
@@ -219,15 +232,20 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                                 ),
                             )
                             yield tool_call_chunk
-        if token_usage is not None:
-            self._attributes.update(_llm_token_counts(token_usage))
+        if token_usage is not None and self._set_span_attributes:
+            self._set_span_attributes(dict(self._llm_token_counts(token_usage)))

     def to_openai_chat_completion_param(
-        self, role: ChatCompletionMessageRole, content: JSONScalarType
+        self,
+        role: ChatCompletionMessageRole,
+        content: JSONScalarType,
+        tool_call_id: Optional[str] = None,
+        tool_calls: Optional[list[JSONScalarType]] = None,
     ) -> "ChatCompletionMessageParam":
         from openai.types.chat import (
             ChatCompletionAssistantMessageParam,
             ChatCompletionSystemMessageParam,
+            ChatCompletionToolMessageParam,
             ChatCompletionUserMessageParam,
         )

@@ -246,26 +264,64 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                 }
             )
         if role is ChatCompletionMessageRole.AI:
-            return ChatCompletionAssistantMessageParam(
-                {
-                    "content": content,
-                    "role": "assistant",
-                }
-            )
+            if tool_calls is None:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                    }
+                )
+            else:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                        "tool_calls": [
+                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                        ],
+                    }
+                )
         if role is ChatCompletionMessageRole.TOOL:
-            raise NotImplementedError
+            if tool_call_id is None:
+                raise ValueError("tool_call_id is required for tool messages")
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
         assert_never(role)

-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return self._attributes
+    def to_openai_tool_call_param(
+        self,
+        tool_call: JSONScalarType,
+    ) -> "ChatCompletionMessageToolCallParam":
+        from openai.types.chat import ChatCompletionMessageToolCallParam
+
+        return ChatCompletionMessageToolCallParam(
+            id=tool_call.get("id", ""),
+            function={
+                "name": tool_call.get("function", {}).get("name", ""),
+                "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
+            },
+            type="function",
+        )
+
+    @staticmethod
+    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
+        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens


 @register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
 class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ):
         from openai import AsyncAzureOpenAI

+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         if model.endpoint is None or model.api_version is None:
             raise ValueError("endpoint and api_version are required for Azure OpenAI models")
         self.client = AsyncAzureOpenAI(
@@ -277,18 +333,29 @@ class AzureOpenAIStreamingClient(OpenAIStreamingClient):

 @register_llm_client(GenerativeProviderKey.ANTHROPIC)
 class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
         import anthropic

+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = anthropic.AsyncAnthropic(api_key=api_key)
         self.model_name = model.name

     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        import anthropic.lib.streaming as anthropic_streaming
+        import anthropic.types as anthropic_types
+
         anthropic_messages, system_prompt = self._build_anthropic_messages(messages)

         anthropic_params = {
@@ -298,17 +365,43 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
             "max_tokens": 1024,
             **invocation_parameters,
         }
-
         async with self.client.messages.stream(**anthropic_params) as stream:
-            async for text in stream.text_stream:
-                yield TextChunk(content=text)
+            async for event in stream:
+                if isinstance(event, anthropic_types.RawMessageStartEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
+                        )
+                elif isinstance(event, anthropic_streaming.TextEvent):
+                    yield TextChunk(content=event.text)
+                elif isinstance(event, anthropic_streaming.MessageStopEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
+                        )
+                elif isinstance(
+                    event,
+                    (
+                        anthropic_types.RawContentBlockStartEvent,
+                        anthropic_types.RawContentBlockDeltaEvent,
+                        anthropic_types.RawMessageDeltaEvent,
+                        anthropic_streaming.ContentBlockStopEvent,
+                    ),
+                ):
+                    # event types emitted by the stream that don't contain useful information
+                    pass
+                elif isinstance(event, anthropic_streaming.InputJsonEvent):
+                    raise NotImplementedError
+                else:
+                    assert_never(event)

     def _build_anthropic_messages(
-        self, messages: List[Tuple[ChatCompletionMessageRole, str]]
-    ) -> Tuple[List["MessageParam"], str]:
-        anthropic_messages: List["MessageParam"] = []
+        self,
+        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    ) -> tuple[list["MessageParam"], str]:
+        anthropic_messages: list["MessageParam"] = []
         system_prompt = ""
-        for role, content in messages:
+        for role, content, _tool_call_id, _tool_calls in messages:
             if role == ChatCompletionMessageRole.USER:
                 anthropic_messages.append({"role": "user", "content": content})
             elif role == ChatCompletionMessageRole.AI:
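Aside: the Anthropic client above now consumes the raw event stream and dispatches on event type, ending the chain with assert_never so a type checker flags any event class the branches forget to cover. A reduced sketch of that exhaustiveness idiom (toy event types, not the anthropic SDK's):

    from dataclasses import dataclass
    from typing import Union

    from typing_extensions import assert_never

    @dataclass
    class TextEvent:
        text: str

    @dataclass
    class StopEvent:
        output_tokens: int

    Event = Union[TextEvent, StopEvent]

    def handle(event: Event) -> str:
        if isinstance(event, TextEvent):
            return event.text
        elif isinstance(event, StopEvent):
            return f"[stopped after {event.output_tokens} tokens]"
        else:
            # unreachable while the branches above are exhaustive; mypy
            # errors here if a new Event member is added but not handled
            assert_never(event)

    print(handle(TextEvent("hello")))  # hello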
@@ -322,10 +415,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):

         return anthropic_messages, system_prompt

-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return dict()
-

 @strawberry.type
 class Subscription:
@@ -335,44 +424,45 @@ class Subscription:
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
         # Determine which LLM client to use based on provider_key
         provider_key = input.model.provider_key
-        llm_client_class = PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)
-        if llm_client_class is None:
-            raise ValueError(f"No LLM client registered for provider '{provider_key}'")
-
-        llm_client = llm_client_class(model=input.model, api_key=input.api_key)
-
-        messages = [(message.role, message.content) for message in input.messages]
-
+        if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
+        llm_client = llm_client_class(
+            model=input.model,
+            api_key=input.api_key,
+            set_span_attributes=lambda attrs: attributes.update(attrs),
+        )
+        messages = [
+            (
+                message.role,
+                message.content,
+                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+                message.tool_calls if isinstance(message.tool_calls, list) else None,
+            )
+            for message in input.messages
+        ]
         if template_options := input.template:
             messages = list(_formatted_messages(messages, template_options))
-
         invocation_parameters = jsonify(input.invocation_parameters)
-
-        in_memory_span_exporter = InMemorySpanExporter()
-        tracer_provider = TracerProvider()
-        tracer_provider.add_span_processor(
-            span_processor=SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
+        attributes = dict(
+            chain(
+                _llm_span_kind(),
+                _llm_model_name(input.model.name),
+                _llm_tools(input.tools or []),
+                _llm_input_messages(messages),
+                _llm_invocation_parameters(invocation_parameters),
+                _input_value_and_mime_type(input),
+            )
         )
-        tracer = tracer_provider.get_tracer(__name__)
-        span_name = "ChatCompletion"
-
-        with tracer.start_span(
-            span_name,
-            attributes=dict(
-                chain(
-                    _llm_span_kind(),
-                    _llm_model_name(input.model.name),
-                    _llm_tools(input.tools or []),
-                    _llm_input_messages(messages),
-                    _llm_invocation_parameters(invocation_parameters),
-                    _input_value_and_mime_type(input),
-                )
-            ),
-        ) as span:
-            response_chunks = []
-            text_chunks: List[TextChunk] = []
-            tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
-
+        status_code: StatusCode
+        status_message = ""
+        events: list[SpanEvent] = []
+        start_time: datetime
+        end_time: datetime
+        response_chunks = []
+        text_chunks: list[TextChunk] = []
+        tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
+        try:
+            start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
             async for chunk in llm_client.chat_completion_create(
                 messages=messages,
                 tools=input.tools or [],
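Aside: with the in-memory OTel tracer gone, the subscription timestamps the span itself via phoenix.datetime_utils. Judging from the names and the tz=timezone.utc call site, local_now/normalize_datetime behave roughly like this stdlib-only approximation (an assumption about those helpers, not their actual source):

    from datetime import datetime, timezone

    def utc_now() -> datetime:
        # hypothetical stand-in: current local wall-clock time,
        # converted to an aware UTC datetime
        return datetime.now().astimezone(timezone.utc)

    start_time = utc_now()
    # ... stream the chat completion ...
    end_time = utc_now()
    assert end_time.tzinfo is timezone.utc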
@@ -385,31 +475,35 @@ class Subscription:
                 elif isinstance(chunk, ToolCallChunk):
                     yield chunk
                     tool_call_chunks[chunk.id].append(chunk)
-
-            span.set_status(StatusCode.OK)
-            llm_client_attributes = llm_client.attributes
-
-            span.set_attributes(
-                dict(
-                    chain(
-                        _output_value_and_mime_type(response_chunks),
-                        llm_client_attributes.items(),
-                        _llm_output_messages(text_chunks, tool_call_chunks),
-                    )
+                else:
+                    assert_never(chunk)
+            status_code = StatusCode.OK
+        except Exception as error:
+            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+            status_code = StatusCode.ERROR
+            status_message = str(error)
+            events.append(
+                SpanException(
+                    timestamp=end_time,
+                    message=status_message,
+                    exception_type=type(error).__name__,
+                    exception_escaped=False,
+                    exception_stacktrace=format_exc(),
                 )
             )
-            assert len(spans := in_memory_span_exporter.get_finished_spans()) == 1
-            finished_span = spans[0]
-            assert finished_span.start_time is not None
-            assert finished_span.end_time is not None
-            assert (attributes := finished_span.attributes) is not None
-            start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
-            end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
-            prompt_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-            completion_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-            trace_id = _hex(finished_span.context.trace_id)
-            span_id = _hex(finished_span.context.span_id)
-            status = finished_span.status
+            yield ChatCompletionSubscriptionError(message=status_message)
+        else:
+            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+            attributes.update(
+                chain(
+                    _output_value_and_mime_type(response_chunks),
+                    _llm_output_messages(text_chunks, tool_call_chunks),
+                )
+            )
+        prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
+        completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
+        trace_id = _generate_trace_id()
+        span_id = _generate_span_id()
         async with info.context.db() as session:
             if (
                 playground_project_id := await session.scalar(
@@ -434,15 +528,15 @@ class Subscription:
                 trace_rowid=playground_trace.id,
                 span_id=span_id,
                 parent_id=None,
-                name=span_name,
+                name="ChatCompletion",
                 span_kind=LLM,
                 start_time=start_time,
                 end_time=end_time,
                 attributes=unflatten(attributes.items()),
-                events=finished_span.events,
-                status_code=status.status_code.name,
-                status_message=status.description or "",
-                cumulative_error_count=int(not status.is_ok),
+                events=[_serialize_event(event) for event in events],
+                status_code=status_code.name,
+                status_message=status_message,
+                cumulative_error_count=int(status_code is StatusCode.ERROR),
                 cumulative_llm_token_count_prompt=prompt_tokens,
                 cumulative_llm_token_count_completion=completion_tokens,
                 llm_token_count_prompt=prompt_tokens,
@@ -456,56 +550,65 @@ class Subscription:
             info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))


-def _llm_span_kind() -> Iterator[Tuple[str, Any]]:
+def _llm_span_kind() -> Iterator[tuple[str, Any]]:
     yield OPENINFERENCE_SPAN_KIND, LLM


-def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
+def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
     yield LLM_MODEL_NAME, model_name


-def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
+def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
     yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)


-def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
+def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
     for tool_index, tool in enumerate(tools):
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)


-def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
-    yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
-    yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-    yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
-
-
-def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
-    assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput))
+def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
+    assert (api_key := "api_key") in (input_data := jsonify(input))
+    input_data = {k: v for k, v in input_data.items() if k != api_key}
+    assert api_key not in input_data
     yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, safe_json_dumps({k: v for k, v in jsonify(input).items() if k != api_key})
+    yield INPUT_VALUE, safe_json_dumps(input_data)


-def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
+def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
     yield OUTPUT_MIME_TYPE, JSON
     yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))


 def _llm_input_messages(
-    messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
-) -> Iterator[Tuple[str, Any]]:
-    for i, (role, content) in enumerate(messages):
+    messages: Iterable[
+        tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+    ],
+) -> Iterator[tuple[str, Any]]:
+    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
+        if tool_calls is not None:
+            for tool_call_index, tool_call in enumerate(tool_calls):
+                yield (
+                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
+                    tool_call["function"]["name"],
+                )
+                if arguments := tool_call["function"]["arguments"]:
+                    yield (
+                        f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                        safe_json_dumps(jsonify(arguments)),
+                    )


 def _llm_output_messages(
-    text_chunks: List[TextChunk],
-    tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]],
-) -> Iterator[Tuple[str, Any]]:
+    text_chunks: list[TextChunk],
+    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+) -> Iterator[tuple[str, Any]]:
     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
     if content := "".join(chunk.content for chunk in text_chunks):
         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
-    for tool_call_index, tool_call_chunks_ in tool_call_chunks.items():
+    for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
         if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
             yield (
                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
@@ -518,34 +621,46 @@ def _llm_output_messages(
             )


-def _hex(number: int) -> str:
+def _generate_trace_id() -> str:
     """
-    Converts an integer to a hexadecimal string.
+    Generates a random trace ID in hexadecimal format.
     """
-    return hex(number)[2:]
+    return _hex(DefaultOTelIDGenerator().generate_trace_id())


-def _datetime(*, epoch_nanoseconds: float) -> datetime:
+def _generate_span_id() -> str:
     """
-    Converts a Unix epoch timestamp in nanoseconds to a datetime.
+    Generates a random span ID in hexadecimal format.
     """
-    epoch_seconds = epoch_nanoseconds / 1e9
-    return datetime.fromtimestamp(epoch_seconds)
+    return _hex(DefaultOTelIDGenerator().generate_span_id())
+
+
+def _hex(number: int) -> str:
+    """
+    Converts an integer to a hexadecimal string.
+    """
+    return hex(number)[2:]


 def _formatted_messages(
-    messages: Iterable[Tuple[ChatCompletionMessageRole, str]], template_options: TemplateOptions
-) -> Iterator[Tuple[ChatCompletionMessageRole, str]]:
+    messages: Iterable[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    template_options: TemplateOptions,
+) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
     template_formatter = _template_formatter(template_language=template_options.language)
-    roles, templates = zip(*messages)
+    (
+        roles,
+        templates,
+        tool_call_id,
+        tool_calls,
+    ) = zip(*messages)
     formatted_templates = map(
         lambda template: template_formatter.format(template, **template_options.variables),
         templates,
     )
-    formatted_messages = zip(roles, formatted_templates)
+    formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
     return formatted_messages

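Aside: _generate_trace_id and _generate_span_id above mint IDs with the OTel SDK's RandomIdGenerator instead of reading them off a finished span. A small sketch of what the helpers produce, assuming opentelemetry-sdk is installed:

    from opentelemetry.sdk.trace.id_generator import RandomIdGenerator

    generator = RandomIdGenerator()
    trace_id = generator.generate_trace_id()  # random non-zero 128-bit int
    span_id = generator.generate_span_id()    # random non-zero 64-bit int

    # mirrors the diff's _hex: strip the "0x" prefix, no zero-padding
    print(hex(trace_id)[2:], hex(span_id)[2:])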
@@ -560,6 +675,13 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
     assert_never(template_language)


+def _serialize_event(event: SpanEvent) -> dict[str, Any]:
+    """
+    Serializes a SpanEvent to a dictionary.
+    """
+    return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+
+
 JSON = OpenInferenceMimeTypeValues.JSON.value

 LLM = OpenInferenceSpanKindValues.LLM.value
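Aside: _serialize_event converts the span events collected during streaming into plain dicts before they are stored on the models.Span row. A runnable sketch of its behavior, using a stand-in dataclass whose fields mirror the SpanException(...) call in the diff (the real class lives in phoenix.trace.schemas):

    from dataclasses import asdict, dataclass
    from datetime import datetime, timezone
    from typing import Any

    @dataclass
    class SpanException:  # stand-in for phoenix.trace.schemas.SpanException
        timestamp: datetime
        message: str
        exception_type: str
        exception_escaped: bool
        exception_stacktrace: str

    def serialize_event(event: Any) -> dict[str, Any]:
        # datetimes become ISO-8601 strings; everything else passes through
        return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}

    event = SpanException(
        timestamp=datetime(2024, 10, 31, tzinfo=timezone.utc),
        message="boom",
        exception_type="ValueError",
        exception_escaped=False,
        exception_stacktrace="Traceback (most recent call last): ...",
    )
    print(serialize_event(event)["timestamp"])  # 2024-10-31T00:00:00+00:00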