arize-phoenix 5.8.0__py3-none-any.whl → 5.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (27)
  1. {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/METADATA +1 -1
  2. {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/RECORD +27 -27
  3. {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +13 -1
  5. phoenix/db/helpers.py +55 -1
  6. phoenix/server/api/helpers/playground_clients.py +160 -8
  7. phoenix/server/api/mutations/chat_mutations.py +198 -11
  8. phoenix/server/api/queries.py +5 -1
  9. phoenix/server/api/routers/oauth2.py +55 -23
  10. phoenix/server/api/routers/v1/spans.py +25 -1
  11. phoenix/server/api/types/ExperimentRun.py +38 -1
  12. phoenix/server/api/types/GenerativeProvider.py +2 -1
  13. phoenix/server/app.py +7 -2
  14. phoenix/server/static/.vite/manifest.json +32 -32
  15. phoenix/server/static/assets/{components-MllbfxfJ.js → components-BcvRmBnN.js} +320 -297
  16. phoenix/server/static/assets/{index-BVO2YcT1.js → index-BF4RUiOz.js} +2 -2
  17. phoenix/server/static/assets/{pages-BHfC6jnL.js → pages-CM_Zho_x.js} +617 -454
  18. phoenix/server/static/assets/{vendor-BEuNhfwH.js → vendor-Bjm5T3cE.js} +181 -181
  19. phoenix/server/static/assets/{vendor-arizeai-Bskhzyjm.js → vendor-arizeai-CQhWGEdL.js} +2 -2
  20. phoenix/server/static/assets/{vendor-codemirror-DLlXCf0x.js → vendor-codemirror-CdtiO80y.js} +5 -5
  21. phoenix/server/static/assets/{vendor-recharts-CRqhvLYg.js → vendor-recharts-BqWon6Py.js} +1 -1
  22. phoenix/session/client.py +27 -7
  23. phoenix/utilities/json.py +31 -1
  24. phoenix/version.py +1 -1
  25. {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/entry_points.txt +0 -0
  26. {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/licenses/IP_NOTICE +0 -0
  27. {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/chat_mutations.py CHANGED
@@ -1,9 +1,9 @@
1
- import json
2
- from dataclasses import asdict
1
+ import asyncio
2
+ from dataclasses import asdict, field
3
3
  from datetime import datetime, timezone
4
4
  from itertools import chain
5
5
  from traceback import format_exc
6
- from typing import Any, Iterable, Iterator, List, Optional
6
+ from typing import Any, Iterable, Iterator, List, Optional, Union
7
7
 
8
8
  import strawberry
9
9
  from openinference.instrumentation import safe_json_dumps
@@ -18,14 +18,19 @@ from openinference.semconv.trace import (
18
18
  from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
19
19
  from opentelemetry.trace import StatusCode
20
20
  from sqlalchemy import insert, select
21
+ from strawberry.relay import GlobalID
21
22
  from strawberry.types import Info
22
23
  from typing_extensions import assert_never
23
24
 
24
25
  from phoenix.datetime_utils import local_now, normalize_datetime
25
26
  from phoenix.db import models
27
+ from phoenix.db.helpers import get_dataset_example_revisions
26
28
  from phoenix.server.api.context import Context
27
- from phoenix.server.api.exceptions import BadRequest
28
- from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
29
+ from phoenix.server.api.exceptions import BadRequest, NotFound
30
+ from phoenix.server.api.helpers.playground_clients import (
31
+ PlaygroundStreamingClient,
32
+ initialize_playground_clients,
33
+ )
29
34
  from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
30
35
  from phoenix.server.api.helpers.playground_spans import (
31
36
  input_value_and_mime_type,
@@ -35,17 +40,28 @@ from phoenix.server.api.helpers.playground_spans import (
35
40
  llm_span_kind,
36
41
  llm_tools,
37
42
  )
38
- from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
43
+ from phoenix.server.api.input_types.ChatCompletionInput import (
44
+ ChatCompletionInput,
45
+ ChatCompletionOverDatasetInput,
46
+ )
39
47
  from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
48
+ from phoenix.server.api.subscriptions import (
49
+ _default_playground_experiment_description,
50
+ _default_playground_experiment_metadata,
51
+ _default_playground_experiment_name,
52
+ )
40
53
  from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
41
54
  from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
42
55
  TextChunk,
43
56
  ToolCallChunk,
44
57
  )
58
+ from phoenix.server.api.types.Dataset import Dataset
59
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
60
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
45
61
  from phoenix.server.api.types.Span import Span, to_gql_span
46
62
  from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
47
63
  from phoenix.server.dml_event import SpanInsertEvent
48
- from phoenix.trace.attributes import unflatten
64
+ from phoenix.trace.attributes import get_attribute_value, unflatten
49
65
  from phoenix.trace.schemas import SpanException
50
66
  from phoenix.utilities.json import jsonify
51
67
  from phoenix.utilities.template_formatters import (
@@ -79,21 +95,192 @@ class ChatCompletionMutationPayload:
79
95
  error_message: Optional[str]
80
96
 
81
97
 
98
+ @strawberry.type
99
+ class ChatCompletionMutationError:
100
+ message: str
101
+
102
+
103
+ @strawberry.type
104
+ class ChatCompletionOverDatasetMutationExamplePayload:
105
+ dataset_example_id: GlobalID
106
+ experiment_run_id: GlobalID
107
+ result: Union[ChatCompletionMutationPayload, ChatCompletionMutationError]
108
+
109
+
110
+ @strawberry.type
111
+ class ChatCompletionOverDatasetMutationPayload:
112
+ dataset_id: GlobalID
113
+ dataset_version_id: GlobalID
114
+ experiment_id: GlobalID
115
+ examples: list[ChatCompletionOverDatasetMutationExamplePayload] = field(default_factory=list)
116
+
117
+
82
118
  @strawberry.type
83
119
  class ChatCompletionMutationMixin:
84
120
  @strawberry.mutation
121
+ @classmethod
122
+ async def chat_completion_over_dataset(
123
+ cls,
124
+ info: Info[Context, None],
125
+ input: ChatCompletionOverDatasetInput,
126
+ ) -> ChatCompletionOverDatasetMutationPayload:
127
+ provider_key = input.model.provider_key
128
+ llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
129
+ if llm_client_class is None:
130
+ raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
131
+ llm_client = llm_client_class(
132
+ model=input.model,
133
+ api_key=input.api_key,
134
+ )
135
+ dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
136
+ dataset_version_id = (
137
+ from_global_id_with_expected_type(
138
+ global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
139
+ )
140
+ if input.dataset_version_id
141
+ else None
142
+ )
143
+ async with info.context.db() as session:
144
+ dataset = await session.scalar(select(models.Dataset).filter_by(id=dataset_id))
145
+ if dataset is None:
146
+ raise NotFound("Dataset not found")
147
+ if dataset_version_id is None:
148
+ resolved_version_id = await session.scalar(
149
+ select(models.DatasetVersion.id)
150
+ .filter_by(dataset_id=dataset_id)
151
+ .order_by(models.DatasetVersion.id.desc())
152
+ .limit(1)
153
+ )
154
+ if resolved_version_id is None:
155
+ raise NotFound("No versions found for the given dataset")
156
+ else:
157
+ resolved_version_id = dataset_version_id
158
+ revisions = [
159
+ revision
160
+ async for revision in await session.stream_scalars(
161
+ get_dataset_example_revisions(resolved_version_id)
162
+ )
163
+ ]
164
+ if not revisions:
165
+ raise NotFound("No examples found for the given dataset and version")
166
+ experiment = models.Experiment(
167
+ dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
168
+ dataset_version_id=resolved_version_id,
169
+ name=input.experiment_name or _default_playground_experiment_name(),
170
+ description=input.experiment_description
171
+ or _default_playground_experiment_description(dataset_name=dataset.name),
172
+ repetitions=1,
173
+ metadata_=input.experiment_metadata
174
+ or _default_playground_experiment_metadata(
175
+ dataset_name=dataset.name,
176
+ dataset_id=input.dataset_id,
177
+ version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
178
+ ),
179
+ project_name=PLAYGROUND_PROJECT_NAME,
180
+ )
181
+ session.add(experiment)
182
+ await session.flush()
183
+
184
+ start_time = datetime.now(timezone.utc)
185
+ results = await asyncio.gather(
186
+ *(
187
+ cls._chat_completion(
188
+ info,
189
+ llm_client,
190
+ ChatCompletionInput(
191
+ model=input.model,
192
+ api_key=input.api_key,
193
+ messages=input.messages,
194
+ tools=input.tools,
195
+ invocation_parameters=input.invocation_parameters,
196
+ template=TemplateOptions(
197
+ language=input.template_language,
198
+ variables=revision.input,
199
+ ),
200
+ ),
201
+ )
202
+ for revision in revisions
203
+ ),
204
+ return_exceptions=True,
205
+ )
206
+
207
+ payload = ChatCompletionOverDatasetMutationPayload(
208
+ dataset_id=GlobalID(models.Dataset.__name__, str(dataset.id)),
209
+ dataset_version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
210
+ experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
211
+ )
212
+ experiment_runs = []
213
+ for revision, result in zip(revisions, results):
214
+ if isinstance(result, BaseException):
215
+ experiment_run = models.ExperimentRun(
216
+ experiment_id=experiment.id,
217
+ dataset_example_id=revision.dataset_example_id,
218
+ output={},
219
+ repetition_number=1,
220
+ start_time=start_time,
221
+ end_time=start_time,
222
+ error=str(result),
223
+ )
224
+ else:
225
+ db_span = result.span.db_span
226
+ experiment_run = models.ExperimentRun(
227
+ experiment_id=experiment.id,
228
+ dataset_example_id=revision.dataset_example_id,
229
+ trace_id=str(result.span.context.trace_id),
230
+ output=models.ExperimentRunOutput(
231
+ task_output=get_attribute_value(db_span.attributes, LLM_OUTPUT_MESSAGES),
232
+ ),
233
+ prompt_token_count=db_span.cumulative_llm_token_count_prompt,
234
+ completion_token_count=db_span.cumulative_llm_token_count_completion,
235
+ repetition_number=1,
236
+ start_time=result.span.start_time,
237
+ end_time=result.span.end_time,
238
+ error=str(result.error_message) if result.error_message else None,
239
+ )
240
+ experiment_runs.append(experiment_run)
241
+
242
+ async with info.context.db() as session:
243
+ session.add_all(experiment_runs)
244
+ await session.flush()
245
+
246
+ for revision, experiment_run, result in zip(revisions, experiment_runs, results):
247
+ dataset_example_id = GlobalID(
248
+ models.DatasetExample.__name__, str(revision.dataset_example_id)
249
+ )
250
+ experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
251
+ example_payload = ChatCompletionOverDatasetMutationExamplePayload(
252
+ dataset_example_id=dataset_example_id,
253
+ experiment_run_id=experiment_run_id,
254
+ result=result
255
+ if isinstance(result, ChatCompletionMutationPayload)
256
+ else ChatCompletionMutationError(message=str(result)),
257
+ )
258
+ payload.examples.append(example_payload)
259
+ return payload
260
+
261
+ @strawberry.mutation
262
+ @classmethod
85
263
  async def chat_completion(
86
- self, info: Info[Context, None], input: ChatCompletionInput
264
+ cls, info: Info[Context, None], input: ChatCompletionInput
87
265
  ) -> ChatCompletionMutationPayload:
88
266
  provider_key = input.model.provider_key
89
267
  llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
90
268
  if llm_client_class is None:
91
269
  raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
92
- attributes: dict[str, Any] = {}
93
270
  llm_client = llm_client_class(
94
271
  model=input.model,
95
272
  api_key=input.api_key,
96
273
  )
274
+ return await cls._chat_completion(info, llm_client, input)
275
+
276
+ @classmethod
277
+ async def _chat_completion(
278
+ cls,
279
+ info: Info[Context, None],
280
+ llm_client: PlaygroundStreamingClient,
281
+ input: ChatCompletionInput,
282
+ ) -> ChatCompletionMutationPayload:
283
+ attributes: dict[str, Any] = {}
97
284
 
98
285
  messages = [
99
286
  (
@@ -122,7 +309,6 @@ class ChatCompletionMutationMixin:
122
309
  llm_input_messages(messages),
123
310
  llm_invocation_parameters(invocation_parameters),
124
311
  input_value_and_mime_type(input),
125
- **llm_client.attributes,
126
312
  )
127
313
  )
128
314
 
@@ -167,6 +353,7 @@ class ChatCompletionMutationMixin:
167
353
  else:
168
354
  end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
169
355
 
356
+ attributes.update(llm_client.attributes)
170
357
  if text_content or tool_calls:
171
358
  attributes.update(
172
359
  chain(
@@ -306,7 +493,7 @@ def _llm_output_messages(
306
493
  if arguments := tool_call.function.arguments:
307
494
  yield (
308
495
  f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
309
- json.dumps(arguments),
496
+ arguments,
310
497
  )
311
498
 
312
499
 
phoenix/server/api/queries.py CHANGED
@@ -48,6 +48,7 @@ from phoenix.server.api.input_types.DatasetSort import DatasetSort
48
48
  from phoenix.server.api.input_types.InvocationParameters import (
49
49
  InvocationParameter,
50
50
  )
51
+ from phoenix.server.api.subscriptions import PLAYGROUND_PROJECT_NAME
51
52
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
52
53
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
53
54
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -237,7 +238,10 @@ class Query:
237
238
  select(models.Project)
238
239
  .outerjoin(
239
240
  models.Experiment,
240
- models.Project.name == models.Experiment.project_name,
241
+ and_(
242
+ models.Project.name == models.Experiment.project_name,
243
+ models.Experiment.project_name != PLAYGROUND_PROJECT_NAME,
244
+ ),
241
245
  )
242
246
  .where(models.Experiment.project_name.is_(None))
243
247
  .order_by(models.Project.id)
phoenix/server/api/routers/oauth2.py CHANGED
@@ -14,7 +14,7 @@ from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, upd
14
14
  from sqlalchemy.ext.asyncio import AsyncSession
15
15
  from sqlalchemy.orm import joinedload
16
16
  from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
17
- from starlette.datastructures import URL
17
+ from starlette.datastructures import URL, URLPath
18
18
  from starlette.responses import RedirectResponse
19
19
  from starlette.routing import Router
20
20
  from starlette.status import HTTP_302_FOUND
@@ -86,8 +86,16 @@ async def login(
86
86
  if not isinstance(
87
87
  oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
88
88
  ):
89
- return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
90
- origin_url = _get_origin_url(request)
89
+ return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
90
+ if (referer := request.headers.get("referer")) is not None:
91
+ # if the referer header is present, use it as the origin URL
92
+ parsed_url = urlparse(referer)
93
+ origin_url = _append_root_path_if_exists(
94
+ request=request, base_url=f"{parsed_url.scheme}://{parsed_url.netloc}"
95
+ )
96
+ else:
97
+ # fall back to the base url as the origin URL
98
+ origin_url = str(request.base_url)
91
99
  authorization_url_data = await oauth2_client.create_authorization_url(
92
100
  redirect_uri=_get_create_tokens_endpoint(
93
101
  request=request, origin_url=origin_url, idp_name=idp_name
@@ -124,22 +132,22 @@ async def create_tokens(
124
132
  ) -> RedirectResponse:
125
133
  secret = request.app.state.get_secret()
126
134
  if state != stored_state:
127
- return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE)
135
+ return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
128
136
  try:
129
137
  payload = _parse_state_payload(secret=secret, state=state)
130
138
  except JoseError:
131
- return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE)
139
+ return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
132
140
  if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
133
141
  unquote(return_url)
134
142
  ):
135
- return _redirect_to_login(error="Attempting login with unsafe return URL.")
143
+ return _redirect_to_login(request=request, error="Attempting login with unsafe return URL.")
136
144
  assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
137
145
  assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
138
146
  token_store: TokenStore = request.app.state.get_token_store()
139
147
  if not isinstance(
140
148
  oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
141
149
  ):
142
- return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
150
+ return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
143
151
  try:
144
152
  token_data = await oauth2_client.fetch_access_token(
145
153
  state=state,
@@ -149,11 +157,12 @@ async def create_tokens(
149
157
  ),
150
158
  )
151
159
  except OAuthError as error:
152
- return _redirect_to_login(error=str(error))
160
+ return _redirect_to_login(request=request, error=str(error))
153
161
  _validate_token_data(token_data)
154
162
  if "id_token" not in token_data:
155
163
  return _redirect_to_login(
156
- error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect."
164
+ request=request,
165
+ error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect.",
157
166
  )
158
167
  user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
159
168
  user_info = _parse_user_info(user_info)
@@ -165,14 +174,18 @@ async def create_tokens(
165
174
  user_info=user_info,
166
175
  )
167
176
  except EmailAlreadyInUse as error:
168
- return _redirect_to_login(error=str(error))
177
+ return _redirect_to_login(request=request, error=str(error))
169
178
  access_token, refresh_token = await create_access_and_refresh_tokens(
170
179
  user=user,
171
180
  token_store=token_store,
172
181
  access_token_expiry=access_token_expiry,
173
182
  refresh_token_expiry=refresh_token_expiry,
174
183
  )
175
- response = RedirectResponse(url=return_url or "/", status_code=HTTP_302_FOUND)
184
+ redirect_path = _prepend_root_path_if_exists(request=request, path=return_url or "/")
185
+ response = RedirectResponse(
186
+ url=redirect_path,
187
+ status_code=HTTP_302_FOUND,
188
+ )
176
189
  response = set_access_token_cookie(
177
190
  response=response, access_token=access_token, max_age=access_token_expiry
178
191
  )
@@ -352,17 +365,46 @@ class EmailAlreadyInUse(Exception):
352
365
  pass
353
366
 
354
367
 
355
- def _redirect_to_login(*, error: str) -> RedirectResponse:
368
+ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
356
369
  """
357
370
  Creates a RedirectResponse to the login page to display an error message.
358
371
  """
359
- url = URL("/login").include_query_params(error=error)
372
+ login_path = _prepend_root_path_if_exists(request=request, path="/login")
373
+ url = URL(login_path).include_query_params(error=error)
360
374
  response = RedirectResponse(url=url)
361
375
  response = delete_oauth2_state_cookie(response)
362
376
  response = delete_oauth2_nonce_cookie(response)
363
377
  return response
364
378
 
365
379
 
380
+ def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
381
+ """
382
+ If a root path is configured, prepends it to the input path.
383
+ """
384
+ if not path.startswith("/"):
385
+ raise ValueError("path must start with a forward slash")
386
+ root_path = _get_root_path(request=request)
387
+ if root_path.endswith("/"):
388
+ root_path = root_path.rstrip("/")
389
+ return root_path + path
390
+
391
+
392
+ def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
393
+ """
394
+ If a root path is configured, appends it to the input base url.
395
+ """
396
+ if not (root_path := _get_root_path(request=request)):
397
+ return base_url
398
+ return str(URLPath(root_path).make_absolute_url(base_url=base_url))
399
+
400
+
401
+ def _get_root_path(*, request: Request) -> str:
402
+ """
403
+ Gets the root path from the request.
404
+ """
405
+ return str(request.scope.get("root_path", ""))
406
+
407
+
366
408
  def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
367
409
  """
368
410
  Gets the endpoint for create tokens route.
@@ -427,16 +469,6 @@ def _with_random_suffix(string: str) -> str:
427
469
  return f"{string}-{randrange(10_000, 100_000)}"
428
470
 
429
471
 
430
- def _get_origin_url(request: Request) -> str:
431
- """
432
- Infers the origin URL from the request.
433
- """
434
- if (referer := request.headers.get("referer")) is None:
435
- return str(request.base_url)
436
- parsed_url = urlparse(referer)
437
- return f"{parsed_url.scheme}://{parsed_url.netloc}"
438
-
439
-
440
472
  def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2StatePayload]:
441
473
  """
442
474
  Determines whether the given object is an OAuth2 state payload.
phoenix/server/api/routers/v1/spans.py CHANGED
@@ -1,8 +1,11 @@
1
+ from asyncio import get_running_loop
1
2
  from collections.abc import AsyncIterator
2
3
  from datetime import datetime, timezone
4
+ from secrets import token_urlsafe
3
5
  from typing import Any, Literal, Optional
4
6
 
5
- from fastapi import APIRouter, HTTPException, Query
7
+ import pandas as pd
8
+ from fastapi import APIRouter, Header, HTTPException, Query
6
9
  from pydantic import Field
7
10
  from sqlalchemy import select
8
11
  from starlette.requests import Request
@@ -19,6 +22,7 @@ from phoenix.db.insertion.types import Precursors
19
22
  from phoenix.server.api.routers.utils import df_to_bytes
20
23
  from phoenix.server.dml_event import SpanAnnotationInsertEvent
21
24
  from phoenix.trace.dsl import SpanQuery as SpanQuery_
25
+ from phoenix.utilities.json import encode_df_as_json_string
22
26
 
23
27
  from .pydantic_compat import V1RoutesBaseModel
24
28
  from .utils import RequestBody, ResponseBody, add_errors_to_responses
@@ -72,6 +76,7 @@ class QuerySpansRequestBody(V1RoutesBaseModel):
72
76
  async def query_spans_handler(
73
77
  request: Request,
74
78
  request_body: QuerySpansRequestBody,
79
+ accept: Optional[str] = Header(None),
75
80
  project_name: Optional[str] = Query(
76
81
  default=None, description="The project name to get evaluations from"
77
82
  ),
@@ -116,6 +121,13 @@ async def query_spans_handler(
116
121
  if not results:
117
122
  raise HTTPException(status_code=HTTP_404_NOT_FOUND)
118
123
 
124
+ if accept == "application/json":
125
+ boundary_token = token_urlsafe(64)
126
+ return StreamingResponse(
127
+ content=_json_multipart(results, boundary_token),
128
+ media_type=f"multipart/mixed; boundary={boundary_token}",
129
+ )
130
+
119
131
  async def content() -> AsyncIterator[bytes]:
120
132
  for result in results:
121
133
  yield df_to_bytes(result)
@@ -126,6 +138,18 @@ async def query_spans_handler(
126
138
  )
127
139
 
128
140
 
141
+ async def _json_multipart(
142
+ results: list[pd.DataFrame],
143
+ boundary_token: str,
144
+ ) -> AsyncIterator[str]:
145
+ for df in results:
146
+ yield f"--{boundary_token}\r\n"
147
+ yield "Content-Type: application/json\r\n\r\n"
148
+ yield await get_running_loop().run_in_executor(None, encode_df_as_json_string, df)
149
+ yield "\r\n"
150
+ yield f"--{boundary_token}--\r\n"
151
+
152
+
129
153
  @router.get("/spans", include_in_schema=False, deprecated=True)
130
154
  async def get_spans_handler(
131
155
  request: Request,
phoenix/server/api/types/ExperimentRun.py CHANGED
@@ -1,7 +1,9 @@
1
1
  from datetime import datetime
2
- from typing import Optional
2
+ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
+ from sqlalchemy import select
6
+ from sqlalchemy.orm import load_only
5
7
  from strawberry import UNSET
6
8
  from strawberry.relay import Connection, GlobalID, Node, NodeID
7
9
  from strawberry.scalars import JSON
@@ -20,6 +22,9 @@ from phoenix.server.api.types.pagination import (
20
22
  )
21
23
  from phoenix.server.api.types.Trace import Trace
22
24
 
25
+ if TYPE_CHECKING:
26
+ from phoenix.server.api.types.DatasetExample import DatasetExample
27
+
23
28
 
24
29
  @strawberry.type
25
30
  class ExperimentRun(Node):
@@ -62,6 +67,38 @@ class ExperimentRun(Node):
62
67
  trace_rowid, project_rowid = trace
63
68
  return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
64
69
 
70
+ @strawberry.field
71
+ async def example(
72
+ self, info: Info
73
+ ) -> Annotated[
74
+ "DatasetExample", strawberry.lazy("phoenix.server.api.types.DatasetExample")
75
+ ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
76
+ from phoenix.server.api.types.DatasetExample import DatasetExample
77
+
78
+ async with info.context.db() as session:
79
+ assert (
80
+ result := await session.execute(
81
+ select(models.DatasetExample, models.Experiment.dataset_version_id)
82
+ .select_from(models.ExperimentRun)
83
+ .join(
84
+ models.DatasetExample,
85
+ models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
86
+ )
87
+ .join(
88
+ models.Experiment,
89
+ models.Experiment.id == models.ExperimentRun.experiment_id,
90
+ )
91
+ .where(models.ExperimentRun.id == self.id_attr)
92
+ .options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
93
+ )
94
+ ) is not None
95
+ example, version_id = result.first()
96
+ return DatasetExample(
97
+ id_attr=example.id,
98
+ created_at=example.created_at,
99
+ version_id=version_id,
100
+ )
101
+
65
102
 
66
103
  def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
67
104
  """
phoenix/server/api/types/GenerativeProvider.py CHANGED
@@ -8,6 +8,7 @@ class GenerativeProviderKey(Enum):
8
8
  OPENAI = "OpenAI"
9
9
  ANTHROPIC = "Anthropic"
10
10
  AZURE_OPENAI = "Azure OpenAI"
11
+ GEMINI = "Google AI Studio"
11
12
 
12
13
 
13
14
  @strawberry.type
@@ -24,7 +25,7 @@ class GenerativeProvider:
24
25
 
25
26
  default_client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
26
27
  if default_client:
27
- return default_client.dependencies()
28
+ return [dependency.name for dependency in default_client.dependencies()]
28
29
  return []
29
30
 
30
31
  @strawberry.field
phoenix/server/app.py CHANGED
@@ -253,7 +253,7 @@ class Static(StaticFiles):
253
253
 
254
254
 
255
255
  class RequestOriginHostnameValidator(BaseHTTPMiddleware):
256
- def __init__(self, trusted_hostnames: list[str], *args: Any, **kwargs: Any) -> None:
256
+ def __init__(self, *args: Any, trusted_hostnames: list[str], **kwargs: Any) -> None:
257
257
  super().__init__(*args, **kwargs)
258
258
  self._trusted_hostnames = trusted_hostnames
259
259
 
@@ -767,7 +767,12 @@ def create_app(
767
767
  middlewares.extend(user_fastapi_middlewares())
768
768
  if origins := get_env_csrf_trusted_origins():
769
769
  trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
770
- middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
770
+ middlewares.append(
771
+ Middleware(
772
+ RequestOriginHostnameValidator,
773
+ trusted_hostnames=trusted_hostnames,
774
+ )
775
+ )
771
776
  elif email_sender or oauth2_client_configs:
772
777
  logger.warning(
773
778
  "CSRF protection can be enabled by listing trusted origins via "