arize-phoenix 5.5.2__py3-none-any.whl → 5.6.0__py3-none-any.whl

This diff shows the content of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (172)
  1. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/METADATA +3 -6
  2. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/RECORD +171 -171
  3. phoenix/config.py +8 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  72. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  73. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  74. phoenix/server/api/input_types/ClusterInput.py +2 -2
  75. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  76. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  77. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  78. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  79. phoenix/server/api/input_types/Granularity.py +1 -1
  80. phoenix/server/api/input_types/InvocationParameters.py +2 -2
  81. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  82. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  83. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  84. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  85. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  86. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  87. phoenix/server/api/mutations/user_mutations.py +4 -4
  88. phoenix/server/api/openapi/schema.py +2 -2
  89. phoenix/server/api/queries.py +20 -20
  90. phoenix/server/api/routers/oauth2.py +4 -4
  91. phoenix/server/api/routers/v1/datasets.py +22 -36
  92. phoenix/server/api/routers/v1/evaluations.py +6 -5
  93. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  94. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  95. phoenix/server/api/routers/v1/experiments.py +4 -4
  96. phoenix/server/api/routers/v1/spans.py +13 -12
  97. phoenix/server/api/routers/v1/traces.py +5 -5
  98. phoenix/server/api/routers/v1/utils.py +5 -5
  99. phoenix/server/api/subscriptions.py +284 -162
  100. phoenix/server/api/types/AnnotationSummary.py +3 -3
  101. phoenix/server/api/types/Cluster.py +8 -7
  102. phoenix/server/api/types/Dataset.py +5 -4
  103. phoenix/server/api/types/Dimension.py +3 -3
  104. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  105. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  106. phoenix/server/api/types/EvaluationSummary.py +3 -3
  107. phoenix/server/api/types/Event.py +7 -7
  108. phoenix/server/api/types/Experiment.py +3 -3
  109. phoenix/server/api/types/ExperimentComparison.py +2 -4
  110. phoenix/server/api/types/Inferences.py +9 -8
  111. phoenix/server/api/types/InferencesRole.py +2 -2
  112. phoenix/server/api/types/Model.py +2 -2
  113. phoenix/server/api/types/Project.py +11 -18
  114. phoenix/server/api/types/Segments.py +3 -3
  115. phoenix/server/api/types/Span.py +8 -7
  116. phoenix/server/api/types/TimeSeries.py +8 -7
  117. phoenix/server/api/types/Trace.py +2 -2
  118. phoenix/server/api/types/UMAPPoints.py +6 -6
  119. phoenix/server/api/types/User.py +3 -3
  120. phoenix/server/api/types/node.py +1 -3
  121. phoenix/server/api/types/pagination.py +4 -4
  122. phoenix/server/api/utils.py +2 -4
  123. phoenix/server/app.py +16 -25
  124. phoenix/server/bearer_auth.py +4 -10
  125. phoenix/server/dml_event.py +3 -3
  126. phoenix/server/dml_event_handler.py +10 -24
  127. phoenix/server/grpc_server.py +3 -2
  128. phoenix/server/jwt_store.py +22 -21
  129. phoenix/server/main.py +3 -3
  130. phoenix/server/oauth2.py +3 -2
  131. phoenix/server/rate_limiters.py +5 -8
  132. phoenix/server/static/.vite/manifest.json +31 -31
  133. phoenix/server/static/assets/components-C70HJiXz.js +1612 -0
  134. phoenix/server/static/assets/{index-DCzakdJq.js → index-DLe1Oo3l.js} +2 -2
  135. phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-C8-Sl7JI.js} +269 -434
  136. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  137. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  138. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  139. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  140. phoenix/server/thread_server.py +1 -1
  141. phoenix/server/types.py +17 -29
  142. phoenix/services.py +4 -3
  143. phoenix/session/client.py +12 -24
  144. phoenix/session/data_extractor.py +3 -3
  145. phoenix/session/evaluation.py +1 -2
  146. phoenix/session/session.py +11 -20
  147. phoenix/trace/attributes.py +16 -28
  148. phoenix/trace/dsl/filter.py +17 -21
  149. phoenix/trace/dsl/helpers.py +3 -3
  150. phoenix/trace/dsl/query.py +13 -22
  151. phoenix/trace/fixtures.py +11 -17
  152. phoenix/trace/otel.py +5 -15
  153. phoenix/trace/projects.py +3 -2
  154. phoenix/trace/schemas.py +2 -2
  155. phoenix/trace/span_evaluations.py +9 -8
  156. phoenix/trace/span_json_decoder.py +3 -3
  157. phoenix/trace/span_json_encoder.py +2 -2
  158. phoenix/trace/trace_dataset.py +6 -5
  159. phoenix/trace/utils.py +6 -6
  160. phoenix/utilities/deprecation.py +3 -2
  161. phoenix/utilities/error_handling.py +3 -2
  162. phoenix/utilities/json.py +2 -1
  163. phoenix/utilities/logging.py +2 -2
  164. phoenix/utilities/project.py +1 -1
  165. phoenix/utilities/re.py +3 -4
  166. phoenix/utilities/template_formatters.py +5 -4
  167. phoenix/version.py +1 -1
  168. phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
  169. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/WHEEL +0 -0
  170. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/entry_points.txt +0 -0
  171. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/IP_NOTICE +0 -0
  172. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,25 +1,13 @@
 import json
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Mapping
+from dataclasses import asdict
 from datetime import datetime, timezone
 from enum import Enum
 from itertools import chain
-from typing import (
-    TYPE_CHECKING,
-    Annotated,
-    Any,
-    AsyncIterator,
-    Callable,
-    DefaultDict,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-)
+from traceback import format_exc
+from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
 
 import strawberry
 from openinference.instrumentation import safe_json_dumps
@@ -31,9 +19,7 @@ from openinference.semconv.trace import (
     ToolAttributes,
     ToolCallAttributes,
 )
-from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
 from sqlalchemy import insert, select
 from strawberry import UNSET
@@ -41,8 +27,10 @@ from strawberry.scalars import JSON as JSONScalarType
 from strawberry.types import Info
 from typing_extensions import TypeAlias, assert_never
 
+from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
 from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -50,6 +38,10 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 from phoenix.server.api.types.Span import Span, to_gql_span
 from phoenix.server.dml_event import SpanInsertEvent
 from phoenix.trace.attributes import unflatten
+from phoenix.trace.schemas import (
+    SpanEvent,
+    SpanException,
+)
 from phoenix.utilities.json import jsonify
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
@@ -60,11 +52,15 @@ from phoenix.utilities.template_formatters import (
 if TYPE_CHECKING:
     from anthropic.types import MessageParam
     from openai.types import CompletionUsage
-    from openai.types.chat import ChatCompletionMessageParam
+    from openai.types.chat import (
+        ChatCompletionMessageParam,
+        ChatCompletionMessageToolCallParam,
+    )
 
 PLAYGROUND_PROJECT_NAME = "playground"
 
 ToolCallID: TypeAlias = str
+SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
 
 
 @strawberry.enum
@@ -96,13 +92,20 @@ class ToolCallChunk:
     function: FunctionCallChunk
 
 
+@strawberry.type
+class ChatCompletionSubscriptionError:
+    message: str
+
+
 @strawberry.type
 class FinishedChatCompletion:
     span: Span
 
 
+ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
+
 ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-    Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
+    Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
     strawberry.union("ChatCompletionSubscriptionPayload"),
 ]
 
@@ -120,23 +123,23 @@ class GenerativeModelInput:
 
 @strawberry.input
 class ChatCompletionInput:
-    messages: List[ChatCompletionMessageInput]
+    messages: list[ChatCompletionMessageInput]
     model: GenerativeModelInput
-    invocation_parameters: InvocationParameters
-    tools: Optional[List[JSONScalarType]] = UNSET
+    invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
+    tools: Optional[list[JSONScalarType]] = UNSET
     template: Optional[TemplateOptions] = UNSET
    api_key: Optional[str] = strawberry.field(default=None)
 
 
-PLAYGROUND_STREAMING_CLIENT_REGISTRY: Dict[
-    GenerativeProviderKey, Type["PlaygroundStreamingClient"]
+PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
+    GenerativeProviderKey, type["PlaygroundStreamingClient"]
 ] = {}
 
 
 def register_llm_client(
     provider_key: GenerativeProviderKey,
-) -> Callable[[Type["PlaygroundStreamingClient"]], Type["PlaygroundStreamingClient"]]:
-    def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:
+) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
+    def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
         PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
         return cls
 
@@ -144,45 +147,56 @@ def register_llm_client(
 
 
 class PlaygroundStreamingClient(ABC):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ...
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
+        self._set_span_attributes = set_span_attributes
 
     @abstractmethod
     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         # a yield statement is needed to satisfy the type-checker
         # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
         yield TextChunk(content="")
 
-    @property
-    @abstractmethod
-    def attributes(self) -> Dict[str, Any]: ...
-
 
 @register_llm_client(GenerativeProviderKey.OPENAI)
 class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
        from openai import AsyncOpenAI
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = AsyncOpenAI(api_key=api_key)
         self.model_name = model.name
-        self._attributes: Dict[str, Any] = {}
 
     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         from openai import NOT_GIVEN
         from openai.types.chat import ChatCompletionStreamOptionsParam
 
         # Convert standard messages to OpenAI messages
         openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-        tool_call_ids: Dict[int, str] = {}
+        tool_call_ids: dict[int, str] = {}
         token_usage: Optional["CompletionUsage"] = None
         async for chunk in await self.client.chat.completions.create(
             messages=openai_messages,
@@ -218,15 +232,20 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                 ),
             )
             yield tool_call_chunk
-        if token_usage is not None:
-            self._attributes.update(_llm_token_counts(token_usage))
+        if token_usage is not None and self._set_span_attributes:
+            self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
 
     def to_openai_chat_completion_param(
-        self, role: ChatCompletionMessageRole, content: JSONScalarType
+        self,
+        role: ChatCompletionMessageRole,
+        content: JSONScalarType,
+        tool_call_id: Optional[str] = None,
+        tool_calls: Optional[list[JSONScalarType]] = None,
     ) -> "ChatCompletionMessageParam":
         from openai.types.chat import (
             ChatCompletionAssistantMessageParam,
             ChatCompletionSystemMessageParam,
+            ChatCompletionToolMessageParam,
             ChatCompletionUserMessageParam,
         )
 
@@ -245,26 +264,64 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                 }
             )
         if role is ChatCompletionMessageRole.AI:
-            return ChatCompletionAssistantMessageParam(
-                {
-                    "content": content,
-                    "role": "assistant",
-                }
-            )
+            if tool_calls is None:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                    }
+                )
+            else:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                        "tool_calls": [
+                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                        ],
+                    }
+                )
         if role is ChatCompletionMessageRole.TOOL:
-            raise NotImplementedError
+            if tool_call_id is None:
+                raise ValueError("tool_call_id is required for tool messages")
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
         assert_never(role)
 
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return self._attributes
+    def to_openai_tool_call_param(
+        self,
+        tool_call: JSONScalarType,
+    ) -> "ChatCompletionMessageToolCallParam":
+        from openai.types.chat import ChatCompletionMessageToolCallParam
+
+        return ChatCompletionMessageToolCallParam(
+            id=tool_call.get("id", ""),
+            function={
+                "name": tool_call.get("function", {}).get("name", ""),
+                "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
+            },
+            type="function",
+        )
+
+    @staticmethod
+    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
+        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
 
 
 @register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
 class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ):
         from openai import AsyncAzureOpenAI
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         if model.endpoint is None or model.api_version is None:
             raise ValueError("endpoint and api_version are required for Azure OpenAI models")
         self.client = AsyncAzureOpenAI(
@@ -276,18 +333,29 @@ class AzureOpenAIStreamingClient(OpenAIStreamingClient):
 
 @register_llm_client(GenerativeProviderKey.ANTHROPIC)
 class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
         import anthropic
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = anthropic.AsyncAnthropic(api_key=api_key)
         self.model_name = model.name
 
     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        import anthropic.lib.streaming as anthropic_streaming
+        import anthropic.types as anthropic_types
+
         anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
 
         anthropic_params = {
@@ -297,17 +365,43 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
             "max_tokens": 1024,
             **invocation_parameters,
         }
-
         async with self.client.messages.stream(**anthropic_params) as stream:
-            async for text in stream.text_stream:
-                yield TextChunk(content=text)
+            async for event in stream:
+                if isinstance(event, anthropic_types.RawMessageStartEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
+                        )
+                elif isinstance(event, anthropic_streaming.TextEvent):
+                    yield TextChunk(content=event.text)
+                elif isinstance(event, anthropic_streaming.MessageStopEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
+                        )
+                elif isinstance(
+                    event,
+                    (
+                        anthropic_types.RawContentBlockStartEvent,
+                        anthropic_types.RawContentBlockDeltaEvent,
+                        anthropic_types.RawMessageDeltaEvent,
+                        anthropic_streaming.ContentBlockStopEvent,
+                    ),
+                ):
+                    # event types emitted by the stream that don't contain useful information
+                    pass
+                elif isinstance(event, anthropic_streaming.InputJsonEvent):
+                    raise NotImplementedError
+                else:
+                    assert_never(event)
 
     def _build_anthropic_messages(
-        self, messages: List[Tuple[ChatCompletionMessageRole, str]]
-    ) -> Tuple[List["MessageParam"], str]:
-        anthropic_messages: List["MessageParam"] = []
+        self,
+        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    ) -> tuple[list["MessageParam"], str]:
+        anthropic_messages: list["MessageParam"] = []
         system_prompt = ""
-        for role, content in messages:
+        for role, content, _tool_call_id, _tool_calls in messages:
             if role == ChatCompletionMessageRole.USER:
                 anthropic_messages.append({"role": "user", "content": content})
             elif role == ChatCompletionMessageRole.AI:
@@ -321,10 +415,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
 
         return anthropic_messages, system_prompt
 
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return dict()
-
 
 @strawberry.type
 class Subscription:
@@ -334,43 +424,45 @@ class Subscription:
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
         # Determine which LLM client to use based on provider_key
         provider_key = input.model.provider_key
-        llm_client_class = PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)
-        if llm_client_class is None:
-            raise ValueError(f"No LLM client registered for provider '{provider_key}'")
-
-        llm_client = llm_client_class(model=input.model, api_key=input.api_key)
-
-        messages = [(message.role, message.content) for message in input.messages]
-
+        if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
+        llm_client = llm_client_class(
+            model=input.model,
+            api_key=input.api_key,
+            set_span_attributes=lambda attrs: attributes.update(attrs),
+        )
+        messages = [
+            (
+                message.role,
+                message.content,
+                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+                message.tool_calls if isinstance(message.tool_calls, list) else None,
+            )
+            for message in input.messages
+        ]
         if template_options := input.template:
             messages = list(_formatted_messages(messages, template_options))
-
         invocation_parameters = jsonify(input.invocation_parameters)
-
-        in_memory_span_exporter = InMemorySpanExporter()
-        tracer_provider = TracerProvider()
-        tracer_provider.add_span_processor(
-            span_processor=SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
+        attributes = dict(
+            chain(
+                _llm_span_kind(),
+                _llm_model_name(input.model.name),
+                _llm_tools(input.tools or []),
+                _llm_input_messages(messages),
+                _llm_invocation_parameters(invocation_parameters),
+                _input_value_and_mime_type(input),
+            )
         )
-        tracer = tracer_provider.get_tracer(__name__)
-        span_name = "ChatCompletion"
-        with tracer.start_span(
-            span_name,
-            attributes=dict(
-                chain(
-                    _llm_span_kind(),
-                    _llm_model_name(input.model.name),
-                    _llm_tools(input.tools or []),
-                    _llm_input_messages(messages),
-                    _llm_invocation_parameters(invocation_parameters),
-                    _input_value_and_mime_type(input),
-                )
-            ),
-        ) as span:
-            response_chunks = []
-            text_chunks: List[TextChunk] = []
-            tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
-
+        status_code: StatusCode
+        status_message = ""
+        events: list[SpanEvent] = []
+        start_time: datetime
+        end_time: datetime
+        response_chunks = []
+        text_chunks: list[TextChunk] = []
+        tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
+        try:
+            start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
             async for chunk in llm_client.chat_completion_create(
                 messages=messages,
                 tools=input.tools or [],
@@ -383,31 +475,35 @@ class Subscription:
                 elif isinstance(chunk, ToolCallChunk):
                     yield chunk
                     tool_call_chunks[chunk.id].append(chunk)
-
-            span.set_status(StatusCode.OK)
-            llm_client_attributes = llm_client.attributes
-
-            span.set_attributes(
-                dict(
-                    chain(
-                        _output_value_and_mime_type(response_chunks),
-                        llm_client_attributes.items(),
-                        _llm_output_messages(text_chunks, tool_call_chunks),
-                    )
+                else:
+                    assert_never(chunk)
+            status_code = StatusCode.OK
+        except Exception as error:
+            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+            status_code = StatusCode.ERROR
+            status_message = str(error)
+            events.append(
+                SpanException(
+                    timestamp=end_time,
+                    message=status_message,
+                    exception_type=type(error).__name__,
+                    exception_escaped=False,
+                    exception_stacktrace=format_exc(),
                 )
             )
-            assert len(spans := in_memory_span_exporter.get_finished_spans()) == 1
-            finished_span = spans[0]
-            assert finished_span.start_time is not None
-            assert finished_span.end_time is not None
-            assert (attributes := finished_span.attributes) is not None
-            start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
-            end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
-            prompt_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-            completion_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-            trace_id = _hex(finished_span.context.trace_id)
-            span_id = _hex(finished_span.context.span_id)
-            status = finished_span.status
+            yield ChatCompletionSubscriptionError(message=status_message)
+        else:
+            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+        attributes.update(
+            chain(
+                _output_value_and_mime_type(response_chunks),
+                _llm_output_messages(text_chunks, tool_call_chunks),
+            )
+        )
+        prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
+        completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
+        trace_id = _generate_trace_id()
+        span_id = _generate_span_id()
         async with info.context.db() as session:
             if (
                 playground_project_id := await session.scalar(
@@ -432,15 +528,15 @@
                 trace_rowid=playground_trace.id,
                 span_id=span_id,
                 parent_id=None,
-                name=span_name,
+                name="ChatCompletion",
                 span_kind=LLM,
                 start_time=start_time,
                 end_time=end_time,
                 attributes=unflatten(attributes.items()),
-                events=finished_span.events,
-                status_code=status.status_code.name,
-                status_message=status.description or "",
-                cumulative_error_count=int(not status.is_ok),
+                events=[_serialize_event(event) for event in events],
+                status_code=status_code.name,
+                status_message=status_message,
+                cumulative_error_count=int(status_code is StatusCode.ERROR),
                 cumulative_llm_token_count_prompt=prompt_tokens,
                 cumulative_llm_token_count_completion=completion_tokens,
                 llm_token_count_prompt=prompt_tokens,
@@ -454,30 +550,24 @@
             info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
 
 
-def _llm_span_kind() -> Iterator[Tuple[str, Any]]:
+def _llm_span_kind() -> Iterator[tuple[str, Any]]:
     yield OPENINFERENCE_SPAN_KIND, LLM
 
 
-def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
+def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
     yield LLM_MODEL_NAME, model_name
 
 
-def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
+def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
     yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
 
 
-def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
+def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
     for tool_index, tool in enumerate(tools):
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
 
 
-def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
-    yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
-    yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-    yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
-
-
-def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
+def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
     assert (api_key := "api_key") in (input_data := jsonify(input))
     input_data = {k: v for k, v in input_data.items() if k != api_key}
     assert api_key not in input_data
@@ -485,27 +575,40 @@ def _input_value_and_mime_type(input: ChatCompletionInput
     yield INPUT_VALUE, safe_json_dumps(input_data)
 
 
-def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
+def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
     yield OUTPUT_MIME_TYPE, JSON
     yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
 
 
 def _llm_input_messages(
-    messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
-) -> Iterator[Tuple[str, Any]]:
-    for i, (role, content) in enumerate(messages):
+    messages: Iterable[
+        tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+    ],
+) -> Iterator[tuple[str, Any]]:
+    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
+        if tool_calls is not None:
+            for tool_call_index, tool_call in enumerate(tool_calls):
+                yield (
+                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
+                    tool_call["function"]["name"],
+                )
+                if arguments := tool_call["function"]["arguments"]:
+                    yield (
+                        f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                        safe_json_dumps(jsonify(arguments)),
                    )
 
 
 def _llm_output_messages(
-    text_chunks: List[TextChunk],
-    tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]],
-) -> Iterator[Tuple[str, Any]]:
+    text_chunks: list[TextChunk],
+    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+) -> Iterator[tuple[str, Any]]:
     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
     if content := "".join(chunk.content for chunk in text_chunks):
         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
-    for tool_call_index, tool_call_chunks_ in tool_call_chunks.items():
+    for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
         if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
             yield (
                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
@@ -518,34 +621,46 @@ def _llm_output_messages(
             )
 
 
-def _hex(number: int) -> str:
+def _generate_trace_id() -> str:
     """
-    Converts an integer to a hexadecimal string.
+    Generates a random trace ID in hexadecimal format.
     """
-    return hex(number)[2:]
+    return _hex(DefaultOTelIDGenerator().generate_trace_id())
 
 
-def _datetime(*, epoch_nanoseconds: float) -> datetime:
+def _generate_span_id() -> str:
     """
-    Converts a Unix epoch timestamp in nanoseconds to a datetime.
+    Generates a random span ID in hexadecimal format.
     """
-    epoch_seconds = epoch_nanoseconds / 1e9
-    return datetime.fromtimestamp(epoch_seconds).replace(tzinfo=timezone.utc)
+    return _hex(DefaultOTelIDGenerator().generate_span_id())
+
+
+def _hex(number: int) -> str:
+    """
+    Converts an integer to a hexadecimal string.
+    """
+    return hex(number)[2:]
 
 
 def _formatted_messages(
-    messages: Iterable[Tuple[ChatCompletionMessageRole, str]], template_options: TemplateOptions
-) -> Iterator[Tuple[ChatCompletionMessageRole, str]]:
+    messages: Iterable[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    template_options: TemplateOptions,
+) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
     template_formatter = _template_formatter(template_language=template_options.language)
-    roles, templates = zip(*messages)
+    (
+        roles,
+        templates,
+        tool_call_id,
+        tool_calls,
+    ) = zip(*messages)
     formatted_templates = map(
         lambda template: template_formatter.format(template, **template_options.variables),
         templates,
     )
-    formatted_messages = zip(roles, formatted_templates)
+    formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
     return formatted_messages
 
 
@@ -560,6 +675,13 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
     assert_never(template_language)
 
 
+def _serialize_event(event: SpanEvent) -> dict[str, Any]:
+    """
+    Serializes a SpanEvent to a dictionary.
+    """
+    return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+
+
 JSON = OpenInferenceMimeTypeValues.JSON.value
 
 LLM = OpenInferenceSpanKindValues.LLM.value
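
Note on the core refactor above: streaming clients no longer expose a polled attributes property; instead the subscription injects a set_span_attributes callback, so clients can push token counts onto the playground span's attribute dict as the provider reports them. A minimal standalone sketch of that callback pattern follows; the StreamingClient and TextChunk classes below are simplified stand-ins for illustration, not Phoenix's actual types, and only the names borrowed from the diff (set_span_attributes, SetSpanAttributesFn) reflect the real code.

import asyncio
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass
from typing import Any, Optional

# Mirrors the SetSpanAttributesFn alias introduced in this diff.
SetSpanAttributesFn = Callable[[Mapping[str, Any]], None]


@dataclass
class TextChunk:
    content: str


class StreamingClient:
    # Stand-in for PlaygroundStreamingClient: the constructor stores an
    # optional callback instead of exposing an `attributes` property.
    def __init__(self, set_span_attributes: Optional[SetSpanAttributesFn] = None) -> None:
        self._set_span_attributes = set_span_attributes

    async def chat_completion_create(self) -> AsyncIterator[TextChunk]:
        for token in ("Hello", ", ", "world"):
            yield TextChunk(content=token)
        # Push usage onto the span as soon as the stream reports it,
        # rather than having the caller poll for it after the stream ends.
        if self._set_span_attributes:
            self._set_span_attributes({"llm.token_count.completion": 3})


async def main() -> None:
    attributes: dict[str, Any] = {}  # stands in for the playground span's attributes
    client = StreamingClient(set_span_attributes=attributes.update)
    async for chunk in client.chat_completion_create():
        print(chunk.content, end="")
    print()
    print(attributes)  # {'llm.token_count.completion': 3}


asyncio.run(main())

One consequence of this design, visible in the Subscription.chat_completion mutation above, is that a single dict can accumulate attributes from both the subscription (input messages, tools, invocation parameters) and the client (token counts) without the in-memory OTel exporter the old code used.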