arize-phoenix 5.7.0__py3-none-any.whl → 5.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.
Files changed (32)
  1. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/METADATA +3 -5
  2. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/RECORD +31 -31
  3. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +19 -3
  5. phoenix/db/helpers.py +55 -1
  6. phoenix/server/api/helpers/playground_clients.py +283 -44
  7. phoenix/server/api/helpers/playground_spans.py +173 -76
  8. phoenix/server/api/input_types/InvocationParameters.py +7 -8
  9. phoenix/server/api/mutations/chat_mutations.py +244 -76
  10. phoenix/server/api/queries.py +5 -1
  11. phoenix/server/api/routers/v1/spans.py +25 -1
  12. phoenix/server/api/subscriptions.py +210 -158
  13. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
  14. phoenix/server/api/types/ExperimentRun.py +38 -1
  15. phoenix/server/api/types/GenerativeProvider.py +2 -1
  16. phoenix/server/app.py +21 -2
  17. phoenix/server/grpc_server.py +3 -1
  18. phoenix/server/static/.vite/manifest.json +32 -32
  19. phoenix/server/static/assets/{components-Csu8UKOs.js → components-DU-8CYbi.js} +370 -329
  20. phoenix/server/static/assets/{index-Bk5C9EA7.js → index-D9E16vvV.js} +2 -2
  21. phoenix/server/static/assets/pages-t09OI1rC.js +3966 -0
  22. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-D04tenE6.js} +181 -181
  23. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-D3NxMQw0.js} +2 -2
  24. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-XTiZSlqq.js} +5 -5
  25. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-p0L0neVs.js} +1 -1
  26. phoenix/session/client.py +27 -7
  27. phoenix/utilities/json.py +31 -1
  28. phoenix/version.py +1 -1
  29. phoenix/server/static/assets/pages-UeWaKXNs.js +0 -3737
  30. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/entry_points.txt +0 -0
  31. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  32. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/chat_mutations.py
@@ -1,11 +1,12 @@
-import json
-from dataclasses import asdict
+import asyncio
+from dataclasses import asdict, field
 from datetime import datetime, timezone
 from itertools import chain
 from traceback import format_exc
-from typing import Any, Iterable, Iterator, List, Optional
+from typing import Any, Iterable, Iterator, List, Optional, Union
 
 import strawberry
+from openinference.instrumentation import safe_json_dumps
 from openinference.semconv.trace import (
     MessageAttributes,
     OpenInferenceMimeTypeValues,
@@ -17,27 +18,52 @@ from openinference.semconv.trace import (
 from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
 from sqlalchemy import insert, select
+from strawberry.relay import GlobalID
 from strawberry.types import Info
 from typing_extensions import assert_never
 
 from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
+from phoenix.db.helpers import get_dataset_example_revisions
 from phoenix.server.api.context import Context
-from phoenix.server.api.exceptions import BadRequest
-from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
+from phoenix.server.api.exceptions import BadRequest, NotFound
+from phoenix.server.api.helpers.playground_clients import (
+    PlaygroundStreamingClient,
+    initialize_playground_clients,
+)
 from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
-from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
+from phoenix.server.api.helpers.playground_spans import (
+    input_value_and_mime_type,
+    llm_input_messages,
+    llm_invocation_parameters,
+    llm_model_name,
+    llm_span_kind,
+    llm_tools,
+)
+from phoenix.server.api.input_types.ChatCompletionInput import (
+    ChatCompletionInput,
+    ChatCompletionOverDatasetInput,
+)
 from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
+from phoenix.server.api.subscriptions import (
+    _default_playground_experiment_description,
+    _default_playground_experiment_metadata,
+    _default_playground_experiment_name,
+)
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
 from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
     TextChunk,
     ToolCallChunk,
 )
+from phoenix.server.api.types.Dataset import Dataset
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.Span import Span, to_gql_span
 from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
 from phoenix.server.dml_event import SpanInsertEvent
-from phoenix.trace.attributes import unflatten
+from phoenix.trace.attributes import get_attribute_value, unflatten
 from phoenix.trace.schemas import SpanException
+from phoenix.utilities.json import jsonify
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
     MustacheTemplateFormatter,
@@ -69,21 +95,192 @@ class ChatCompletionMutationPayload:
     error_message: Optional[str]
 
 
+@strawberry.type
+class ChatCompletionMutationError:
+    message: str
+
+
+@strawberry.type
+class ChatCompletionOverDatasetMutationExamplePayload:
+    dataset_example_id: GlobalID
+    experiment_run_id: GlobalID
+    result: Union[ChatCompletionMutationPayload, ChatCompletionMutationError]
+
+
+@strawberry.type
+class ChatCompletionOverDatasetMutationPayload:
+    dataset_id: GlobalID
+    dataset_version_id: GlobalID
+    experiment_id: GlobalID
+    examples: list[ChatCompletionOverDatasetMutationExamplePayload] = field(default_factory=list)
+
+
 @strawberry.type
 class ChatCompletionMutationMixin:
     @strawberry.mutation
+    @classmethod
+    async def chat_completion_over_dataset(
+        cls,
+        info: Info[Context, None],
+        input: ChatCompletionOverDatasetInput,
+    ) -> ChatCompletionOverDatasetMutationPayload:
+        provider_key = input.model.provider_key
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
+        llm_client = llm_client_class(
+            model=input.model,
+            api_key=input.api_key,
+        )
+        dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
+        dataset_version_id = (
+            from_global_id_with_expected_type(
+                global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+            if input.dataset_version_id
+            else None
+        )
+        async with info.context.db() as session:
+            dataset = await session.scalar(select(models.Dataset).filter_by(id=dataset_id))
+            if dataset is None:
+                raise NotFound("Dataset not found")
+            if dataset_version_id is None:
+                resolved_version_id = await session.scalar(
+                    select(models.DatasetVersion.id)
+                    .filter_by(dataset_id=dataset_id)
+                    .order_by(models.DatasetVersion.id.desc())
+                    .limit(1)
+                )
+                if resolved_version_id is None:
+                    raise NotFound("No versions found for the given dataset")
+            else:
+                resolved_version_id = dataset_version_id
+            revisions = [
+                revision
+                async for revision in await session.stream_scalars(
+                    get_dataset_example_revisions(resolved_version_id)
+                )
+            ]
+            if not revisions:
+                raise NotFound("No examples found for the given dataset and version")
+            experiment = models.Experiment(
+                dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
+                dataset_version_id=resolved_version_id,
+                name=input.experiment_name or _default_playground_experiment_name(),
+                description=input.experiment_description
+                or _default_playground_experiment_description(dataset_name=dataset.name),
+                repetitions=1,
+                metadata_=input.experiment_metadata
+                or _default_playground_experiment_metadata(
+                    dataset_name=dataset.name,
+                    dataset_id=input.dataset_id,
+                    version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
+                ),
+                project_name=PLAYGROUND_PROJECT_NAME,
+            )
+            session.add(experiment)
+            await session.flush()
+
+        start_time = datetime.now(timezone.utc)
+        results = await asyncio.gather(
+            *(
+                cls._chat_completion(
+                    info,
+                    llm_client,
+                    ChatCompletionInput(
+                        model=input.model,
+                        api_key=input.api_key,
+                        messages=input.messages,
+                        tools=input.tools,
+                        invocation_parameters=input.invocation_parameters,
+                        template=TemplateOptions(
+                            language=input.template_language,
+                            variables=revision.input,
+                        ),
+                    ),
+                )
+                for revision in revisions
+            ),
+            return_exceptions=True,
+        )
+
+        payload = ChatCompletionOverDatasetMutationPayload(
+            dataset_id=GlobalID(models.Dataset.__name__, str(dataset.id)),
+            dataset_version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
+            experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
+        )
+        experiment_runs = []
+        for revision, result in zip(revisions, results):
+            if isinstance(result, BaseException):
+                experiment_run = models.ExperimentRun(
+                    experiment_id=experiment.id,
+                    dataset_example_id=revision.dataset_example_id,
+                    output={},
+                    repetition_number=1,
+                    start_time=start_time,
+                    end_time=start_time,
+                    error=str(result),
+                )
+            else:
+                db_span = result.span.db_span
+                experiment_run = models.ExperimentRun(
+                    experiment_id=experiment.id,
+                    dataset_example_id=revision.dataset_example_id,
+                    trace_id=str(result.span.context.trace_id),
+                    output=models.ExperimentRunOutput(
+                        task_output=get_attribute_value(db_span.attributes, LLM_OUTPUT_MESSAGES),
+                    ),
+                    prompt_token_count=db_span.cumulative_llm_token_count_prompt,
+                    completion_token_count=db_span.cumulative_llm_token_count_completion,
+                    repetition_number=1,
+                    start_time=result.span.start_time,
+                    end_time=result.span.end_time,
+                    error=str(result.error_message) if result.error_message else None,
+                )
+            experiment_runs.append(experiment_run)
+
+        async with info.context.db() as session:
+            session.add_all(experiment_runs)
+            await session.flush()
+
+        for revision, experiment_run, result in zip(revisions, experiment_runs, results):
+            dataset_example_id = GlobalID(
+                models.DatasetExample.__name__, str(revision.dataset_example_id)
+            )
+            experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
+            example_payload = ChatCompletionOverDatasetMutationExamplePayload(
+                dataset_example_id=dataset_example_id,
+                experiment_run_id=experiment_run_id,
+                result=result
+                if isinstance(result, ChatCompletionMutationPayload)
+                else ChatCompletionMutationError(message=str(result)),
+            )
+            payload.examples.append(example_payload)
+        return payload
+
+    @strawberry.mutation
+    @classmethod
     async def chat_completion(
-        self, info: Info[Context, None], input: ChatCompletionInput
+        cls, info: Info[Context, None], input: ChatCompletionInput
     ) -> ChatCompletionMutationPayload:
         provider_key = input.model.provider_key
         llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
         if llm_client_class is None:
             raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
-        attributes: dict[str, Any] = {}
         llm_client = llm_client_class(
             model=input.model,
             api_key=input.api_key,
         )
+        return await cls._chat_completion(info, llm_client, input)
+
+    @classmethod
+    async def _chat_completion(
+        cls,
+        info: Info[Context, None],
+        llm_client: PlaygroundStreamingClient,
+        input: ChatCompletionInput,
+    ) -> ChatCompletionMutationPayload:
+        attributes: dict[str, Any] = {}
 
         messages = [
             (
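Note on the fan-out above: chat_completion_over_dataset runs one _chat_completion coroutine per dataset example with asyncio.gather(..., return_exceptions=True), then zips the results back with their revisions and records a BaseException result as a failed experiment run instead of failing the whole batch. A minimal, self-contained sketch of that pattern (the names below are illustrative, not Phoenix APIs):

import asyncio


async def run_one(example: dict) -> str:
    # Stand-in for a single playground chat completion over one dataset example.
    if example.get("fail"):
        raise RuntimeError(f"bad example {example['id']}")
    await asyncio.sleep(0)  # simulate provider I/O
    return f"output for example {example['id']}"


async def run_all(examples: list) -> None:
    # return_exceptions=True keeps per-example failures in the result list
    # instead of cancelling the remaining tasks.
    results = await asyncio.gather(
        *(run_one(ex) for ex in examples),
        return_exceptions=True,
    )
    for ex, result in zip(examples, results):
        if isinstance(result, BaseException):
            print(f"example {ex['id']} failed: {result}")
        else:
            print(f"example {ex['id']} succeeded: {result}")


asyncio.run(run_all([{"id": 1}, {"id": 2, "fail": True}, {"id": 3}]))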
@@ -94,7 +291,6 @@ class ChatCompletionMutationMixin:
             )
             for message in input.messages
         ]
-
         if template_options := input.template:
             messages = list(_formatted_messages(messages, template_options))
 
@@ -103,17 +299,16 @@
         )
 
         text_content = ""
-        tool_calls = []
+        tool_calls: dict[str, ChatCompletionToolCall] = {}
         events = []
         attributes.update(
             chain(
-                _llm_span_kind(),
-                _llm_model_name(input.model.name),
-                _llm_tools(input.tools or []),
-                _llm_input_messages(messages),
-                _llm_invocation_parameters(invocation_parameters),
-                _input_value_and_mime_type(input),
-                **llm_client.attributes,
+                llm_span_kind(),
+                llm_model_name(input.model.name),
+                llm_tools(input.tools or []),
+                llm_input_messages(messages),
+                llm_invocation_parameters(invocation_parameters),
+                input_value_and_mime_type(input),
             )
         )
 
@@ -128,14 +323,16 @@
                 if isinstance(chunk, TextChunk):
                     text_content += chunk.content
                 elif isinstance(chunk, ToolCallChunk):
-                    tool_call = ChatCompletionToolCall(
-                        id=chunk.id,
-                        function=ChatCompletionFunctionCall(
-                            name=chunk.function.name,
-                            arguments=chunk.function.arguments,
-                        ),
-                    )
-                    tool_calls.append(tool_call)
+                    if chunk.id not in tool_calls:
+                        tool_calls[chunk.id] = ChatCompletionToolCall(
+                            id=chunk.id,
+                            function=ChatCompletionFunctionCall(
+                                name=chunk.function.name,
+                                arguments=chunk.function.arguments,
+                            ),
+                        )
+                    else:
+                        tool_calls[chunk.id].function.arguments += chunk.function.arguments
                 else:
                     assert_never(chunk)
         except Exception as e:
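The switch from a list of tool calls to a dict keyed by chunk.id reflects how streaming providers deliver tool calls: the first chunk for an id carries the function name, and later chunks with the same id carry only argument fragments that have to be concatenated. A small stand-alone sketch of that accumulation (the chunk classes are simplified stand-ins, not the Phoenix types):

from dataclasses import dataclass


@dataclass
class FunctionChunk:
    name: str
    arguments: str


@dataclass
class ToolCallChunk:
    id: str
    function: FunctionChunk


# Fragments as a provider might stream them: one tool call whose JSON
# arguments arrive split across chunks that share the same id.
chunks = [
    ToolCallChunk("call_1", FunctionChunk("get_weather", '{"city": ')),
    ToolCallChunk("call_1", FunctionChunk("", '"Paris"}')),
]

tool_calls: dict[str, ToolCallChunk] = {}
for chunk in chunks:
    if chunk.id not in tool_calls:
        tool_calls[chunk.id] = ToolCallChunk(
            id=chunk.id,
            function=FunctionChunk(chunk.function.name, chunk.function.arguments),
        )
    else:
        # Later fragments only extend the accumulated argument string.
        tool_calls[chunk.id].function.arguments += chunk.function.arguments

print(tool_calls["call_1"].function.arguments)  # {"city": "Paris"}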
@@ -156,10 +353,11 @@
         else:
             end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
 
+        attributes.update(llm_client.attributes)
         if text_content or tool_calls:
             attributes.update(
                 chain(
-                    _output_value_and_mime_type({"text": text_content, "tool_calls": tool_calls}),
+                    _output_value_and_mime_type(text_content, tool_calls),
                     _llm_output_messages(text_content, tool_calls),
                 )
             )
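The helpers chained together here and in the earlier hunk (llm_span_kind, llm_model_name, llm_input_messages, and so on, now imported from playground_spans) each yield flat (attribute_key, value) pairs, and dict.update(chain(...)) merges them into a single flattened span-attribute dict. A toy version of the pattern, with attribute keys that mirror the OpenInference naming but are hard-coded here purely for illustration:

from itertools import chain
from typing import Any, Iterator


def span_kind() -> Iterator[tuple[str, Any]]:
    yield "openinference.span.kind", "LLM"


def model_name(name: str) -> Iterator[tuple[str, Any]]:
    yield "llm.model_name", name


def input_messages(messages: list) -> Iterator[tuple[str, Any]]:
    # Flatten each (role, content) pair into indexed, dotted keys.
    for i, (role, content) in enumerate(messages):
        yield f"llm.input_messages.{i}.message.role", role
        yield f"llm.input_messages.{i}.message.content", content


attributes: dict = {}
attributes.update(
    chain(
        span_kind(),
        model_name("gpt-4o-mini"),
        input_messages([("system", "You are helpful."), ("user", "Hi!")]),
    )
)
print(attributes)

dict.update accepts any iterable of key/value pairs, so chaining generators keeps the span-building code declarative while producing the flat dotted keys that unflatten can later turn back into nested structures.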
@@ -225,7 +423,7 @@
         else:
             return ChatCompletionMutationPayload(
                 content=text_content if text_content else None,
-                tool_calls=tool_calls,
+                tool_calls=list(tool_calls.values()),
                 span=gql_span,
                 error_message=None,
             )
@@ -264,61 +462,30 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
     assert_never(template_language)
 
 
-def _llm_span_kind() -> Iterator[tuple[str, Any]]:
-    yield OPENINFERENCE_SPAN_KIND, LLM
-
-
-def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
-    yield LLM_MODEL_NAME, model_name
-
-
-def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
-    yield LLM_INVOCATION_PARAMETERS, json.dumps(invocation_parameters)
-
-
-def _llm_tools(tools: List[Any]) -> Iterator[tuple[str, Any]]:
-    for tool_index, tool in enumerate(tools):
-        yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
-
-
-def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
-    input_data = input.__dict__.copy()
-    input_data.pop("api_key", None)
-    yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, json.dumps(input_data)
-
-
-def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
-    yield OUTPUT_MIME_TYPE, JSON
-    yield OUTPUT_VALUE, json.dumps(output)
-
-
-def _llm_input_messages(
-    messages: Iterable[ChatCompletionMessage],
+def _output_value_and_mime_type(
+    text: str, tool_calls: dict[str, ChatCompletionToolCall]
 ) -> Iterator[tuple[str, Any]]:
-    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
-        if tool_calls:
-            for tool_call_index, tool_call in enumerate(tool_calls):
-                yield (
-                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                    tool_call["function"]["name"],
-                )
-                if arguments := tool_call["function"]["arguments"]:
-                    yield (
-                        f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                        json.dumps(arguments),
-                    )
+    if text and tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield (
+            OUTPUT_VALUE,
+            safe_json_dumps({"content": text, "tool_calls": jsonify(list(tool_calls.values()))}),
+        )
+    elif tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield OUTPUT_VALUE, safe_json_dumps(jsonify(list(tool_calls.values())))
+    elif text:
+        yield OUTPUT_MIME_TYPE, TEXT
+        yield OUTPUT_VALUE, text
 
 
 def _llm_output_messages(
-    text_content: str, tool_calls: List[ChatCompletionToolCall]
+    text_content: str, tool_calls: dict[str, ChatCompletionToolCall]
 ) -> Iterator[tuple[str, Any]]:
     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
     if text_content:
         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", text_content
-    for tool_call_index, tool_call in enumerate(tool_calls):
+    for tool_call_index, tool_call in enumerate(tool_calls.values()):
         yield (
             f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
             tool_call.function.name,
@@ -326,7 +493,7 @@ def _llm_output_messages(
         if arguments := tool_call.function.arguments:
             yield (
                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                json.dumps(arguments),
+                arguments,
             )
 
 
@@ -347,6 +514,7 @@ def _serialize_event(event: SpanException) -> dict[str, Any]:
 
 
 JSON = OpenInferenceMimeTypeValues.JSON.value
+TEXT = OpenInferenceMimeTypeValues.TEXT.value
 LLM = OpenInferenceSpanKindValues.LLM.value
 
 OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
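The rewritten _output_value_and_mime_type (together with the new TEXT constant above) now derives the output mime type from what the completion actually produced: JSON whenever tool calls are present, plain text when only text came back, and no output attributes at all when both are empty. A reduced sketch of that branching, using plain dicts and literal mime strings in place of the Phoenix types and OpenInference constants:

import json
from typing import Any, Iterator

OUTPUT_MIME_TYPE = "output.mime_type"
OUTPUT_VALUE = "output.value"


def output_value_and_mime_type(
    text: str, tool_calls: dict
) -> Iterator[tuple[str, Any]]:
    if text and tool_calls:
        # Mixed output: wrap both pieces in one JSON document.
        yield OUTPUT_MIME_TYPE, "application/json"
        yield OUTPUT_VALUE, json.dumps(
            {"content": text, "tool_calls": list(tool_calls.values())}
        )
    elif tool_calls:
        yield OUTPUT_MIME_TYPE, "application/json"
        yield OUTPUT_VALUE, json.dumps(list(tool_calls.values()))
    elif text:
        # Text-only output stays plain text instead of being wrapped in JSON.
        yield OUTPUT_MIME_TYPE, "text/plain"
        yield OUTPUT_VALUE, text


print(dict(output_value_and_mime_type("Hello!", {})))
print(dict(output_value_and_mime_type("", {"call_1": {"name": "get_weather"}})))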
phoenix/server/api/queries.py
@@ -48,6 +48,7 @@ from phoenix.server.api.input_types.DatasetSort import DatasetSort
 from phoenix.server.api.input_types.InvocationParameters import (
     InvocationParameter,
 )
+from phoenix.server.api.subscriptions import PLAYGROUND_PROJECT_NAME
 from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
 from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
 from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -237,7 +238,10 @@ class Query:
             select(models.Project)
             .outerjoin(
                 models.Experiment,
-                models.Project.name == models.Experiment.project_name,
+                and_(
+                    models.Project.name == models.Experiment.project_name,
+                    models.Experiment.project_name != PLAYGROUND_PROJECT_NAME,
+                ),
             )
             .where(models.Experiment.project_name.is_(None))
             .order_by(models.Project.id)
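Moving the playground exclusion into the join's ON clause via and_ (rather than into the WHERE clause) is what keeps this anti-join correct: a project whose only experiments belong to the playground project still joins to no experiment row, so the NULL-extended row survives the .is_(None) filter and the project is still treated as experiment-free. A self-contained sketch of the same pattern on toy models, assuming SQLAlchemy 2.x (these are not Phoenix's actual models):

from sqlalchemy import and_, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Project(Base):
    __tablename__ = "projects"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


class Experiment(Base):
    __tablename__ = "experiments"
    id: Mapped[int] = mapped_column(primary_key=True)
    project_name: Mapped[str]


PLAYGROUND_PROJECT_NAME = "playground"  # illustrative value

# Projects with no experiments, where playground experiments do not count:
# the extra condition lives in the ON clause, so a playground-only project
# still produces a NULL-extended row that the IS NULL filter keeps.
stmt = (
    select(Project)
    .outerjoin(
        Experiment,
        and_(
            Project.name == Experiment.project_name,
            Experiment.project_name != PLAYGROUND_PROJECT_NAME,
        ),
    )
    .where(Experiment.project_name.is_(None))
    .order_by(Project.id)
)
print(stmt)  # renders the LEFT OUTER JOIN ... ON ... AND ... IS NULL query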
phoenix/server/api/routers/v1/spans.py
@@ -1,8 +1,11 @@
+from asyncio import get_running_loop
 from collections.abc import AsyncIterator
 from datetime import datetime, timezone
+from secrets import token_urlsafe
 from typing import Any, Literal, Optional
 
-from fastapi import APIRouter, HTTPException, Query
+import pandas as pd
+from fastapi import APIRouter, Header, HTTPException, Query
 from pydantic import Field
 from sqlalchemy import select
 from starlette.requests import Request
@@ -19,6 +22,7 @@ from phoenix.db.insertion.types import Precursors
 from phoenix.server.api.routers.utils import df_to_bytes
 from phoenix.server.dml_event import SpanAnnotationInsertEvent
 from phoenix.trace.dsl import SpanQuery as SpanQuery_
+from phoenix.utilities.json import encode_df_as_json_string
 
 from .pydantic_compat import V1RoutesBaseModel
 from .utils import RequestBody, ResponseBody, add_errors_to_responses
@@ -72,6 +76,7 @@ class QuerySpansRequestBody(V1RoutesBaseModel):
 async def query_spans_handler(
     request: Request,
     request_body: QuerySpansRequestBody,
+    accept: Optional[str] = Header(None),
     project_name: Optional[str] = Query(
         default=None, description="The project name to get evaluations from"
     ),
@@ -116,6 +121,13 @@
     if not results:
         raise HTTPException(status_code=HTTP_404_NOT_FOUND)
 
+    if accept == "application/json":
+        boundary_token = token_urlsafe(64)
+        return StreamingResponse(
+            content=_json_multipart(results, boundary_token),
+            media_type=f"multipart/mixed; boundary={boundary_token}",
+        )
+
     async def content() -> AsyncIterator[bytes]:
         for result in results:
             yield df_to_bytes(result)
@@ -126,6 +138,18 @@
     )
 
 
+async def _json_multipart(
+    results: list[pd.DataFrame],
+    boundary_token: str,
+) -> AsyncIterator[str]:
+    for df in results:
+        yield f"--{boundary_token}\r\n"
+        yield "Content-Type: application/json\r\n\r\n"
+        yield await get_running_loop().run_in_executor(None, encode_df_as_json_string, df)
+        yield "\r\n"
+    yield f"--{boundary_token}--\r\n"
+
+
 @router.get("/spans", include_in_schema=False, deprecated=True)
 async def get_spans_handler(
     request: Request,
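On the client side, a request that sends accept: application/json now gets back one JSON document per queried DataFrame, framed with the multipart/mixed boundary emitted by _json_multipart above (the updated phoenix/session/client.py presumably consumes this, but its code is not shown here). A hedged sketch of how such a response could be parsed; the endpoint path, port, and query payload are placeholders, not taken from this diff:

import json

import requests

resp = requests.post(
    "http://localhost:6006/v1/spans",  # hypothetical route for query_spans_handler
    params={"project_name": "default"},
    json={"queries": []},  # placeholder request body
    headers={"accept": "application/json"},
)
resp.raise_for_status()

# Content-Type looks like "multipart/mixed; boundary=<token>"; recover the token.
boundary = resp.headers["content-type"].split("boundary=")[1].strip()

documents = []
for part in resp.text.split(f"--{boundary}"):
    part = part.strip()
    if not part or part == "--":  # preamble or the closing "--boundary--" delimiter
        continue
    # Each part is a "Content-Type: application/json" header block, a blank
    # line, then one JSON document produced by encode_df_as_json_string.
    _headers, _, body = part.partition("\r\n\r\n")
    documents.append(json.loads(body))

print(f"received {len(documents)} JSON document(s)")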