arize-phoenix 5.7.0__py3-none-any.whl → 5.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of arize-phoenix has been flagged as potentially problematic; consult the registry's advisory page for details.

Files changed (32):
  1. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/METADATA +3 -5
  2. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/RECORD +31 -31
  3. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +19 -3
  5. phoenix/db/helpers.py +55 -1
  6. phoenix/server/api/helpers/playground_clients.py +283 -44
  7. phoenix/server/api/helpers/playground_spans.py +173 -76
  8. phoenix/server/api/input_types/InvocationParameters.py +7 -8
  9. phoenix/server/api/mutations/chat_mutations.py +244 -76
  10. phoenix/server/api/queries.py +5 -1
  11. phoenix/server/api/routers/v1/spans.py +25 -1
  12. phoenix/server/api/subscriptions.py +210 -158
  13. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
  14. phoenix/server/api/types/ExperimentRun.py +38 -1
  15. phoenix/server/api/types/GenerativeProvider.py +2 -1
  16. phoenix/server/app.py +21 -2
  17. phoenix/server/grpc_server.py +3 -1
  18. phoenix/server/static/.vite/manifest.json +32 -32
  19. phoenix/server/static/assets/{components-Csu8UKOs.js → components-DU-8CYbi.js} +370 -329
  20. phoenix/server/static/assets/{index-Bk5C9EA7.js → index-D9E16vvV.js} +2 -2
  21. phoenix/server/static/assets/pages-t09OI1rC.js +3966 -0
  22. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-D04tenE6.js} +181 -181
  23. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-D3NxMQw0.js} +2 -2
  24. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-XTiZSlqq.js} +5 -5
  25. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-p0L0neVs.js} +1 -1
  26. phoenix/session/client.py +27 -7
  27. phoenix/utilities/json.py +31 -1
  28. phoenix/version.py +1 -1
  29. phoenix/server/static/assets/pages-UeWaKXNs.js +0 -3737
  30. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/entry_points.txt +0 -0
  31. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  32. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,16 @@
1
+ import asyncio
1
2
  import logging
2
- from asyncio import FIRST_COMPLETED, Task, create_task, wait
3
- from collections.abc import Iterator
3
+ from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for
4
+ from collections.abc import AsyncIterator, Iterator
5
+ from datetime import datetime, timezone
4
6
  from typing import (
5
7
  Any,
6
- AsyncIterator,
7
- Collection,
8
8
  Iterable,
9
9
  Mapping,
10
10
  Optional,
11
+ Sequence,
11
12
  TypeVar,
13
+ cast,
12
14
  )
13
15
 
14
16
  import strawberry
@@ -19,9 +21,10 @@ from strawberry.relay.types import GlobalID
19
21
  from strawberry.types import Info
20
22
  from typing_extensions import TypeAlias, assert_never
21
23
 
24
+ from phoenix.datetime_utils import local_now, normalize_datetime
22
25
  from phoenix.db import models
23
26
  from phoenix.server.api.context import Context
24
- from phoenix.server.api.exceptions import BadRequest
27
+ from phoenix.server.api.exceptions import BadRequest, NotFound
25
28
  from phoenix.server.api.helpers.playground_clients import (
26
29
  PlaygroundStreamingClient,
27
30
  initialize_playground_clients,
@@ -29,27 +32,33 @@ from phoenix.server.api.helpers.playground_clients import (
29
32
  from phoenix.server.api.helpers.playground_registry import (
30
33
  PLAYGROUND_CLIENT_REGISTRY,
31
34
  )
32
- from phoenix.server.api.helpers.playground_spans import streaming_llm_span
35
+ from phoenix.server.api.helpers.playground_spans import (
36
+ get_db_experiment_run,
37
+ get_db_span,
38
+ get_db_trace,
39
+ streaming_llm_span,
40
+ )
33
41
  from phoenix.server.api.input_types.ChatCompletionInput import (
34
42
  ChatCompletionInput,
35
43
  ChatCompletionOverDatasetInput,
36
44
  )
37
45
  from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
38
46
  from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
39
- ChatCompletionOverDatasetSubscriptionResult,
40
47
  ChatCompletionSubscriptionError,
48
+ ChatCompletionSubscriptionExperiment,
41
49
  ChatCompletionSubscriptionPayload,
42
- FinishedChatCompletion,
50
+ ChatCompletionSubscriptionResult,
43
51
  )
44
52
  from phoenix.server.api.types.Dataset import Dataset
45
53
  from phoenix.server.api.types.DatasetExample import DatasetExample
46
54
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
47
55
  from phoenix.server.api.types.Experiment import to_gql_experiment
56
+ from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
48
57
  from phoenix.server.api.types.node import from_global_id_with_expected_type
49
58
  from phoenix.server.api.types.Span import to_gql_span
50
59
  from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
51
60
  from phoenix.server.dml_event import SpanInsertEvent
52
- from phoenix.trace.attributes import get_attribute_value
61
+ from phoenix.server.types import DbSessionFactory
53
62
  from phoenix.utilities.template_formatters import (
54
63
  FStringTemplateFormatter,
55
64
  MustacheTemplateFormatter,
@@ -67,6 +76,9 @@ ChatCompletionMessage: TypeAlias = tuple[
67
76
  ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
68
77
  ]
69
78
  DatasetExampleID: TypeAlias = GlobalID
79
+ ChatCompletionResult: TypeAlias = tuple[
80
+ DatasetExampleID, Optional[models.Span], models.ExperimentRun
81
+ ]
70
82
  PLAYGROUND_PROJECT_NAME = "playground"
71
83
 
72
84
 
@@ -116,8 +128,8 @@ class Subscription:
116
128
  span.add_response_chunk(chunk)
117
129
  yield chunk
118
130
  span.set_attributes(llm_client.attributes)
119
- if span.error_message is not None:
120
- yield ChatCompletionSubscriptionError(message=span.error_message)
131
+ if span.status_message is not None:
132
+ yield ChatCompletionSubscriptionError(message=span.status_message)
121
133
  async with info.context.db() as session:
122
134
  if (
123
135
  playground_project_id := await session.scalar(
@@ -132,10 +144,12 @@ class Subscription:
132
144
  description="Traces from prompt playground",
133
145
  )
134
146
  )
135
- db_span = span.add_to_session(session, playground_project_id)
147
+ db_trace = get_db_trace(span, playground_project_id)
148
+ db_span = get_db_span(span, db_trace)
149
+ session.add(db_span)
136
150
  await session.flush()
137
- yield FinishedChatCompletion(span=to_gql_span(db_span))
138
151
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
152
+ yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))
139
153
 
140
154
  @strawberry.subscription
141
155
  async def chat_completion_over_dataset(
@@ -154,58 +168,68 @@ class Subscription:
154
168
  if input.dataset_version_id
155
169
  else None
156
170
  )
157
- revision_ids = (
158
- select(func.max(models.DatasetExampleRevision.id))
159
- .join(models.DatasetExample)
160
- .where(models.DatasetExample.dataset_id == dataset_id)
161
- .group_by(models.DatasetExampleRevision.dataset_example_id)
162
- )
163
- if version_id:
164
- version_id_subquery = (
165
- select(models.DatasetVersion.id)
166
- .where(models.DatasetVersion.dataset_id == dataset_id)
167
- .where(models.DatasetVersion.id == version_id)
168
- .scalar_subquery()
169
- )
170
- revision_ids = revision_ids.where(
171
- models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
172
- )
173
- query = (
174
- select(models.DatasetExampleRevision)
175
- .where(
176
- and_(
177
- models.DatasetExampleRevision.id.in_(revision_ids),
178
- models.DatasetExampleRevision.revision_kind != "DELETE",
171
+ async with info.context.db() as session:
172
+ if (
173
+ dataset := await session.scalar(
174
+ select(models.Dataset).where(models.Dataset.id == dataset_id)
179
175
  )
180
- )
181
- .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
182
- .options(
183
- load_only(
184
- models.DatasetExampleRevision.dataset_example_id,
185
- models.DatasetExampleRevision.input,
176
+ ) is None:
177
+ raise NotFound(f"Could not find dataset with ID {dataset_id}")
178
+ if version_id is None:
179
+ if (
180
+ resolved_version_id := await session.scalar(
181
+ select(models.DatasetVersion.id)
182
+ .where(models.DatasetVersion.dataset_id == dataset_id)
183
+ .order_by(models.DatasetVersion.id.desc())
184
+ .limit(1)
185
+ )
186
+ ) is None:
187
+ raise NotFound(f"No versions found for dataset with ID {dataset_id}")
188
+ else:
189
+ if (
190
+ resolved_version_id := await session.scalar(
191
+ select(models.DatasetVersion.id).where(
192
+ and_(
193
+ models.DatasetVersion.dataset_id == dataset_id,
194
+ models.DatasetVersion.id == version_id,
195
+ )
196
+ )
197
+ )
198
+ ) is None:
199
+ raise NotFound(f"Could not find dataset version with ID {version_id}")
200
+ revision_ids = (
201
+ select(func.max(models.DatasetExampleRevision.id))
202
+ .join(models.DatasetExample)
203
+ .where(
204
+ and_(
205
+ models.DatasetExample.dataset_id == dataset_id,
206
+ models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
207
+ )
186
208
  )
209
+ .group_by(models.DatasetExampleRevision.dataset_example_id)
187
210
  )
188
- )
189
- async with info.context.db() as session:
190
- revisions = [revision async for revision in await session.stream_scalars(query)]
191
- if not revisions:
192
- raise BadRequest("No examples found for the given dataset and version")
193
-
194
- spans: dict[DatasetExampleID, streaming_llm_span] = {}
195
- async for payload in _merge_iterators(
196
- [
197
- _stream_chat_completion_over_dataset_example(
198
- input=input,
199
- llm_client_class=llm_client_class,
200
- revision=revision,
201
- spans=spans,
202
- )
203
- for revision in revisions
204
- ]
205
- ):
206
- yield payload
207
-
208
- async with info.context.db() as session:
211
+ if not (
212
+ revisions := [
213
+ rev
214
+ async for rev in await session.stream_scalars(
215
+ select(models.DatasetExampleRevision)
216
+ .where(
217
+ and_(
218
+ models.DatasetExampleRevision.id.in_(revision_ids),
219
+ models.DatasetExampleRevision.revision_kind != "DELETE",
220
+ )
221
+ )
222
+ .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
223
+ .options(
224
+ load_only(
225
+ models.DatasetExampleRevision.dataset_example_id,
226
+ models.DatasetExampleRevision.input,
227
+ )
228
+ )
229
+ )
230
+ ]
231
+ ):
232
+ raise NotFound("No examples found for the given dataset and version")
209
233
  if (
210
234
  playground_project_id := await session.scalar(
211
235
  select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
@@ -219,81 +243,70 @@ class Subscription:
219
243
  description="Traces from prompt playground",
220
244
  )
221
245
  )
222
- db_spans = {
223
- example_id: span.add_to_session(session, playground_project_id)
224
- for example_id, span in spans.items()
225
- }
226
- assert (
227
- dataset_name := await session.scalar(
228
- select(models.Dataset.name).where(models.Dataset.id == dataset_id)
229
- )
230
- ) is not None
231
- if version_id is None:
232
- resolved_version_id = await session.scalar(
233
- select(models.DatasetVersion.id)
234
- .where(models.DatasetVersion.dataset_id == dataset_id)
235
- .order_by(models.DatasetVersion.id.desc())
236
- .limit(1)
237
- )
238
- else:
239
- resolved_version_id = await session.scalar(
240
- select(models.DatasetVersion.id).where(
241
- and_(
242
- models.DatasetVersion.dataset_id == dataset_id,
243
- models.DatasetVersion.id == version_id,
244
- )
245
- )
246
- )
247
- assert resolved_version_id is not None
248
- resolved_version_node_id = GlobalID(DatasetVersion.__name__, str(resolved_version_id))
249
246
  experiment = models.Experiment(
250
247
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
251
248
  dataset_version_id=resolved_version_id,
252
- name=input.experiment_name or _DEFAULT_PLAYGROUND_EXPERIMENT_NAME,
249
+ name=input.experiment_name or _default_playground_experiment_name(),
253
250
  description=input.experiment_description
254
- or _default_playground_experiment_description(dataset_name=dataset_name),
251
+ or _default_playground_experiment_description(dataset_name=dataset.name),
255
252
  repetitions=1,
256
253
  metadata_=input.experiment_metadata
257
254
  or _default_playground_experiment_metadata(
258
- dataset_name=dataset_name,
255
+ dataset_name=dataset.name,
259
256
  dataset_id=input.dataset_id,
260
- version_id=resolved_version_node_id,
257
+ version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
261
258
  ),
262
259
  project_name=PLAYGROUND_PROJECT_NAME,
263
260
  )
264
261
  session.add(experiment)
265
262
  await session.flush()
266
- runs = [
267
- models.ExperimentRun(
268
- experiment_id=experiment.id,
269
- dataset_example_id=from_global_id_with_expected_type(
270
- example_id, DatasetExample.__name__
271
- ),
272
- trace_id=span.trace_id,
273
- output=models.ExperimentRunOutput(
274
- task_output=_get_playground_experiment_task_output(span)
275
- ),
276
- repetition_number=1,
277
- start_time=span.start_time,
278
- end_time=span.end_time,
279
- error=error_message
280
- if (error_message := span.error_message) is not None
281
- else None,
282
- prompt_token_count=get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT),
283
- completion_token_count=get_attribute_value(
284
- span.attributes, LLM_TOKEN_COUNT_COMPLETION
285
- ),
286
- )
287
- for example_id, span in spans.items()
288
- ]
289
- session.add_all(runs)
290
- await session.flush()
291
- for example_id in spans:
292
- yield FinishedChatCompletion(
293
- span=to_gql_span(db_spans[example_id]),
294
- dataset_example_id=example_id,
263
+ yield ChatCompletionSubscriptionExperiment(
264
+ experiment=to_gql_experiment(experiment)
265
+ ) # eagerly yields experiment so it can be linked by consumers of the subscription
266
+
267
+ results_queue: Queue[ChatCompletionResult] = Queue()
268
+ chat_completion_streams = [
269
+ _stream_chat_completion_over_dataset_example(
270
+ input=input,
271
+ llm_client_class=llm_client_class,
272
+ revision=revision,
273
+ results_queue=results_queue,
274
+ experiment_id=experiment.id,
275
+ project_id=playground_project_id,
295
276
  )
296
- yield ChatCompletionOverDatasetSubscriptionResult(experiment=to_gql_experiment(experiment))
277
+ for revision in revisions
278
+ ]
279
+ stream_to_async_tasks: dict[
280
+ AsyncIterator[ChatCompletionSubscriptionPayload],
281
+ Task[ChatCompletionSubscriptionPayload],
282
+ ] = {iterator: _create_task_with_timeout(iterator) for iterator in chat_completion_streams}
283
+ batch_size = 10
284
+ while stream_to_async_tasks:
285
+ async_tasks_to_run = [task for task in stream_to_async_tasks.values()]
286
+ completed_tasks, _ = await wait(async_tasks_to_run, return_when=FIRST_COMPLETED)
287
+ for task in completed_tasks:
288
+ iterator = next(it for it, t in stream_to_async_tasks.items() if t == task)
289
+ try:
290
+ yield task.result()
291
+ except (StopAsyncIteration, asyncio.TimeoutError):
292
+ del stream_to_async_tasks[iterator] # removes exhausted iterator
293
+ except Exception as error:
294
+ del stream_to_async_tasks[iterator] # removes failed iterator
295
+ logger.exception(error)
296
+ else:
297
+ stream_to_async_tasks[iterator] = _create_task_with_timeout(iterator)
298
+ if results_queue.qsize() >= batch_size:
299
+ result_iterator = _chat_completion_result_payloads(
300
+ db=info.context.db, results=_drain_no_wait(results_queue)
301
+ )
302
+ stream_to_async_tasks[result_iterator] = _create_task_with_timeout(
303
+ result_iterator
304
+ )
305
+ if remaining_results := await _drain(results_queue):
306
+ async for result_payload in _chat_completion_result_payloads(
307
+ db=info.context.db, results=remaining_results
308
+ ):
309
+ yield result_payload
297
310
 
298
311
 
299
312
  async def _stream_chat_completion_over_dataset_example(
@@ -301,7 +314,9 @@ async def _stream_chat_completion_over_dataset_example(
301
314
  input: ChatCompletionOverDatasetInput,
302
315
  llm_client_class: type["PlaygroundStreamingClient"],
303
316
  revision: models.DatasetExampleRevision,
304
- spans: dict[DatasetExampleID, streaming_llm_span],
317
+ results_queue: Queue[ChatCompletionResult],
318
+ experiment_id: int,
319
+ project_id: int,
305
320
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
306
321
  example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
307
322
  llm_client = llm_client_class(
@@ -319,6 +334,7 @@ async def _stream_chat_completion_over_dataset_example(
319
334
  for message in input.messages
320
335
  ]
321
336
  try:
337
+ format_start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
322
338
  messages = list(
323
339
  _formatted_messages(
324
340
  messages=messages,
@@ -327,15 +343,31 @@ async def _stream_chat_completion_over_dataset_example(
327
343
  )
328
344
  )
329
345
  except TemplateFormatterError as error:
346
+ format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
330
347
  yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
348
+ await results_queue.put(
349
+ (
350
+ example_id,
351
+ None,
352
+ models.ExperimentRun(
353
+ experiment_id=experiment_id,
354
+ dataset_example_id=revision.dataset_example_id,
355
+ trace_id=None,
356
+ output={},
357
+ repetition_number=1,
358
+ start_time=format_start_time,
359
+ end_time=format_end_time,
360
+ error=str(error),
361
+ trace=None,
362
+ ),
363
+ )
364
+ )
331
365
  return
332
- span = streaming_llm_span(
366
+ async with streaming_llm_span(
333
367
  input=input,
334
368
  messages=messages,
335
369
  invocation_parameters=invocation_parameters,
336
- )
337
- spans[example_id] = span
338
- async with span:
370
+ ) as span:
339
371
  async for chunk in llm_client.chat_completion_create(
340
372
  messages=messages, tools=input.tools or [], **invocation_parameters
341
373
  ):
@@ -343,35 +375,60 @@ async def _stream_chat_completion_over_dataset_example(
343
375
  chunk.dataset_example_id = example_id
344
376
  yield chunk
345
377
  span.set_attributes(llm_client.attributes)
346
- if span.error_message is not None:
378
+ db_trace = get_db_trace(span, project_id)
379
+ db_span = get_db_span(span, db_trace)
380
+ db_run = get_db_experiment_run(
381
+ db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
382
+ )
383
+ await results_queue.put((example_id, db_span, db_run))
384
+ if span.status_message is not None:
347
385
  yield ChatCompletionSubscriptionError(
348
- message=span.error_message, dataset_example_id=example_id
386
+ message=span.status_message, dataset_example_id=example_id
349
387
  )
350
388
 
351
389
 
352
- async def _merge_iterators(
353
- iterators: Collection[AsyncIterator[GenericType]],
354
- ) -> AsyncIterator[GenericType]:
355
- tasks: dict[AsyncIterator[GenericType], Task[GenericType]] = {
356
- iterable: _as_task(iterable) for iterable in iterators
357
- }
358
- while tasks:
359
- completed_tasks, _ = await wait(tasks.values(), return_when=FIRST_COMPLETED)
360
- for task in completed_tasks:
361
- iterator = next(it for it, t in tasks.items() if t == task)
362
- try:
363
- yield task.result()
364
- except StopAsyncIteration:
365
- del tasks[iterator]
366
- except Exception as error:
367
- del tasks[iterator]
368
- logger.exception(error)
369
- else:
370
- tasks[iterator] = _as_task(iterator)
390
+ async def _chat_completion_result_payloads(
391
+ *,
392
+ db: DbSessionFactory,
393
+ results: Sequence[ChatCompletionResult],
394
+ ) -> AsyncIterator[ChatCompletionSubscriptionResult]:
395
+ if not results:
396
+ return
397
+ async with db() as session:
398
+ for _, span, run in results:
399
+ if span:
400
+ session.add(span)
401
+ session.add(run)
402
+ await session.flush()
403
+ for example_id, span, run in results:
404
+ yield ChatCompletionSubscriptionResult(
405
+ span=to_gql_span(span) if span else None,
406
+ experiment_run=to_gql_experiment_run(run),
407
+ dataset_example_id=example_id,
408
+ )
409
+
410
+
411
+ def _create_task_with_timeout(
412
+ iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 60
413
+ ) -> Task[GenericType]:
414
+ return create_task(wait_for(_as_coroutine(iterable), timeout=timeout_in_seconds))
415
+
371
416
 
417
+ async def _drain(queue: Queue[GenericType]) -> list[GenericType]:
418
+ values: list[GenericType] = []
419
+ while not queue.empty():
420
+ values.append(await queue.get())
421
+ return values
372
422
 
373
- def _as_task(iterable: AsyncIterator[GenericType]) -> Task[GenericType]:
374
- return create_task(_as_coroutine(iterable))
423
+
424
+ def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]:
425
+ values: list[GenericType] = []
426
+ while True:
427
+ try:
428
+ values.append(queue.get_nowait())
429
+ except QueueEmpty:
430
+ break
431
+ return values
375
432
 
376
433
 
377
434
  async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
@@ -413,13 +470,8 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
413
470
  assert_never(template_language)
414
471
 
415
472
 
416
- def _get_playground_experiment_task_output(
417
- span: streaming_llm_span,
418
- ) -> Any:
419
- return get_attribute_value(span.attributes, LLM_OUTPUT_MESSAGES)
420
-
421
-
422
- _DEFAULT_PLAYGROUND_EXPERIMENT_NAME = "playground-experiment"
473
+ def _default_playground_experiment_name() -> str:
474
+ return "playground-experiment"
423
475
 
424
476
 
425
477
  def _default_playground_experiment_description(dataset_name: str) -> str:
@@ -4,6 +4,7 @@ import strawberry
4
4
  from strawberry.relay import GlobalID
5
5
 
6
6
  from .Experiment import Experiment
7
+ from .ExperimentRun import ExperimentRun
7
8
  from .Span import Span
8
9
 
9
10
 
@@ -30,8 +31,9 @@ class ToolCallChunk(ChatCompletionSubscriptionPayload):
30
31
 
31
32
 
32
33
  @strawberry.type
33
- class FinishedChatCompletion(ChatCompletionSubscriptionPayload):
34
- span: Span
34
+ class ChatCompletionSubscriptionResult(ChatCompletionSubscriptionPayload):
35
+ span: Optional[Span] = None
36
+ experiment_run: Optional[ExperimentRun] = None
35
37
 
36
38
 
37
39
  @strawberry.type
@@ -40,5 +42,5 @@ class ChatCompletionSubscriptionError(ChatCompletionSubscriptionPayload):
40
42
 
41
43
 
42
44
  @strawberry.type
43
- class ChatCompletionOverDatasetSubscriptionResult(ChatCompletionSubscriptionPayload):
45
+ class ChatCompletionSubscriptionExperiment(ChatCompletionSubscriptionPayload):
44
46
  experiment: Experiment
@@ -1,7 +1,9 @@
1
1
  from datetime import datetime
2
- from typing import Optional
2
+ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
+ from sqlalchemy import select
6
+ from sqlalchemy.orm import load_only
5
7
  from strawberry import UNSET
6
8
  from strawberry.relay import Connection, GlobalID, Node, NodeID
7
9
  from strawberry.scalars import JSON
@@ -20,6 +22,9 @@ from phoenix.server.api.types.pagination import (
20
22
  )
21
23
  from phoenix.server.api.types.Trace import Trace
22
24
 
25
+ if TYPE_CHECKING:
26
+ from phoenix.server.api.types.DatasetExample import DatasetExample
27
+
23
28
 
24
29
  @strawberry.type
25
30
  class ExperimentRun(Node):
@@ -62,6 +67,38 @@ class ExperimentRun(Node):
62
67
  trace_rowid, project_rowid = trace
63
68
  return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
64
69
 
70
+ @strawberry.field
71
+ async def example(
72
+ self, info: Info
73
+ ) -> Annotated[
74
+ "DatasetExample", strawberry.lazy("phoenix.server.api.types.DatasetExample")
75
+ ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
76
+ from phoenix.server.api.types.DatasetExample import DatasetExample
77
+
78
+ async with info.context.db() as session:
79
+ assert (
80
+ result := await session.execute(
81
+ select(models.DatasetExample, models.Experiment.dataset_version_id)
82
+ .select_from(models.ExperimentRun)
83
+ .join(
84
+ models.DatasetExample,
85
+ models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
86
+ )
87
+ .join(
88
+ models.Experiment,
89
+ models.Experiment.id == models.ExperimentRun.experiment_id,
90
+ )
91
+ .where(models.ExperimentRun.id == self.id_attr)
92
+ .options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
93
+ )
94
+ ) is not None
95
+ example, version_id = result.first()
96
+ return DatasetExample(
97
+ id_attr=example.id,
98
+ created_at=example.created_at,
99
+ version_id=version_id,
100
+ )
101
+
65
102
 
66
103
  def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
67
104
  """
@@ -8,6 +8,7 @@ class GenerativeProviderKey(Enum):
8
8
  OPENAI = "OpenAI"
9
9
  ANTHROPIC = "Anthropic"
10
10
  AZURE_OPENAI = "Azure OpenAI"
11
+ GEMINI = "Google AI Studio"
11
12
 
12
13
 
13
14
  @strawberry.type
@@ -24,7 +25,7 @@ class GenerativeProvider:
24
25
 
25
26
  default_client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
26
27
  if default_client:
27
- return default_client.dependencies()
28
+ return [dependency.name for dependency in default_client.dependencies()]
28
29
  return []
29
30
 
30
31
  @strawberry.field
phoenix/server/app.py CHANGED
@@ -26,6 +26,7 @@ import strawberry
26
26
  from fastapi import APIRouter, Depends, FastAPI
27
27
  from fastapi.middleware.gzip import GZipMiddleware
28
28
  from fastapi.utils import is_body_allowed_for_status_code
29
+ from grpc.aio import ServerInterceptor
29
30
  from sqlalchemy import select
30
31
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
31
32
  from starlette.datastructures import State as StarletteState
@@ -54,6 +55,7 @@ from phoenix.config import (
54
55
  get_env_csrf_trusted_origins,
55
56
  get_env_fastapi_middleware_paths,
56
57
  get_env_gql_extension_paths,
58
+ get_env_grpc_interceptor_paths,
57
59
  get_env_host,
58
60
  get_env_port,
59
61
  server_instrumentation_is_enabled,
@@ -251,7 +253,7 @@ class Static(StaticFiles):
251
253
 
252
254
 
253
255
  class RequestOriginHostnameValidator(BaseHTTPMiddleware):
254
- def __init__(self, trusted_hostnames: list[str], *args: Any, **kwargs: Any) -> None:
256
+ def __init__(self, *args: Any, trusted_hostnames: list[str], **kwargs: Any) -> None:
255
257
  super().__init__(*args, **kwargs)
256
258
  self._trusted_hostnames = trusted_hostnames
257
259
 
@@ -305,6 +307,17 @@ def user_gql_extensions() -> list[Union[type[SchemaExtension], SchemaExtension]]
305
307
  return extensions
306
308
 
307
309
 
310
+ def user_grpc_interceptors() -> list[ServerInterceptor]:
311
+ paths = get_env_grpc_interceptor_paths()
312
+ interceptors = []
313
+ for file_path, object_name in paths:
314
+ interceptor_class = import_object_from_file(file_path, object_name)
315
+ if not issubclass(interceptor_class, ServerInterceptor):
316
+ raise TypeError(f"{interceptor_class} is not a subclass of ServerInterceptor")
317
+ interceptors.append(interceptor_class)
318
+ return interceptors
319
+
320
+
308
321
  ProjectRowId: TypeAlias = int
309
322
 
310
323
 
@@ -479,6 +492,7 @@ def _lifespan(
479
492
  tracer_provider=tracer_provider,
480
493
  enable_prometheus=enable_prometheus,
481
494
  token_store=token_store,
495
+ interceptors=user_grpc_interceptors(),
482
496
  )
483
497
  await stack.enter_async_context(grpc_server)
484
498
  await stack.enter_async_context(dml_event_handler)
@@ -753,7 +767,12 @@ def create_app(
753
767
  middlewares.extend(user_fastapi_middlewares())
754
768
  if origins := get_env_csrf_trusted_origins():
755
769
  trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
756
- middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
770
+ middlewares.append(
771
+ Middleware(
772
+ RequestOriginHostnameValidator,
773
+ trusted_hostnames=trusted_hostnames,
774
+ )
775
+ )
757
776
  elif email_sender or oauth2_client_configs:
758
777
  logger.warning(
759
778
  "CSRF protection can be enabled by listing trusted origins via "