arize-phoenix 4.4.4rc4__py3-none-any.whl → 4.4.4rc5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Potentially problematic release: this version of arize-phoenix has been flagged as potentially problematic; see the registry's advisory for details.
Files changed (31):
  1. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/METADATA +2 -2
  2. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/RECORD +30 -28
  3. phoenix/datasets/evaluators/code_evaluators.py +25 -53
  4. phoenix/datasets/evaluators/llm_evaluators.py +63 -32
  5. phoenix/datasets/evaluators/utils.py +292 -0
  6. phoenix/datasets/experiments.py +147 -82
  7. phoenix/datasets/tracing.py +19 -0
  8. phoenix/datasets/types.py +18 -52
  9. phoenix/db/insertion/dataset.py +19 -16
  10. phoenix/db/migrations/versions/10460e46d750_datasets.py +2 -2
  11. phoenix/db/models.py +8 -3
  12. phoenix/server/api/context.py +2 -0
  13. phoenix/server/api/dataloaders/__init__.py +2 -0
  14. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  15. phoenix/server/api/helpers/dataset_helpers.py +8 -7
  16. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  17. phoenix/server/api/mutations/project_mutations.py +9 -4
  18. phoenix/server/api/routers/v1/datasets.py +146 -42
  19. phoenix/server/api/routers/v1/experiment_evaluations.py +1 -0
  20. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  21. phoenix/server/api/types/Experiment.py +5 -0
  22. phoenix/server/api/types/ExperimentRun.py +1 -1
  23. phoenix/server/api/types/ExperimentRunAnnotation.py +1 -1
  24. phoenix/server/app.py +2 -0
  25. phoenix/server/static/index.js +610 -564
  26. phoenix/session/client.py +124 -2
  27. phoenix/version.py +1 -1
  28. phoenix/datasets/evaluators/_utils.py +0 -13
  29. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/WHEEL +0 -0
  30. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/licenses/IP_NOTICE +0 -0
  31. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/licenses/LICENSE +0 -0
phoenix/db/migrations/versions/10460e46d750_datasets.py CHANGED
@@ -72,7 +72,7 @@ def upgrade() -> None:
         sa.Column(
             "span_rowid",
             sa.Integer,
-            sa.ForeignKey("spans.id"),
+            sa.ForeignKey("spans.id", ondelete="SET NULL"),
             nullable=True,
             index=True,
         ),
@@ -198,7 +198,7 @@ def upgrade() -> None:
             sa.String,
             nullable=True,
         ),
-        sa.Column("output", JSON_, nullable=True),
+        sa.Column("output", JSON_, nullable=False),
         sa.Column("start_time", sa.TIMESTAMP(timezone=True), nullable=False),
         sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False),
         sa.Column(
phoenix/db/models.py CHANGED
@@ -1,5 +1,5 @@
 from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TypedDict

 from sqlalchemy import (
     JSON,
@@ -91,6 +91,10 @@ class UtcTimeStamp(TypeDecorator[datetime]):
         return normalize_datetime(value, timezone.utc)


+class ExperimentResult(TypedDict, total=False):
+    result: Dict[str, Any]
+
+
 class Base(DeclarativeBase):
     # Enforce best practices for naming constraints
     # https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate
@@ -106,6 +110,7 @@ class Base(DeclarativeBase):
     type_annotation_map = {
         Dict[str, Any]: JsonDict,
         List[Dict[str, Any]]: JsonList,
+        ExperimentResult: JsonDict,
     }


@@ -483,7 +488,7 @@ class DatasetExample(Base):
         index=True,
     )
     span_rowid: Mapped[Optional[int]] = mapped_column(
-        ForeignKey("spans.id"),
+        ForeignKey("spans.id", ondelete="SET NULL"),
         index=True,
         nullable=True,
     )
@@ -556,7 +561,7 @@ class ExperimentRun(Base):
     )
     repetition_number: Mapped[int]
     trace_id: Mapped[Optional[str]]
-    output: Mapped[Optional[Dict[str, Any]]]
+    output: Mapped[ExperimentResult]
     start_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
     end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
     prompt_token_count: Mapped[Optional[int]]
phoenix/server/api/context.py CHANGED
@@ -20,6 +20,7 @@ from phoenix.server.api.dataloaders import (
     EvaluationSummaryDataLoader,
     ExperimentAnnotationSummaryDataLoader,
     ExperimentErrorRatesDataLoader,
+    ExperimentRunCountsDataLoader,
     ExperimentSequenceNumberDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
@@ -44,6 +45,7 @@ class DataLoaders:
     evaluation_summaries: EvaluationSummaryDataLoader
     experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
     experiment_error_rates: ExperimentErrorRatesDataLoader
+    experiment_run_counts: ExperimentRunCountsDataLoader
     experiment_sequence_number: ExperimentSequenceNumberDataLoader
     latency_ms_quantile: LatencyMsQuantileDataLoader
     min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
phoenix/server/api/dataloaders/__init__.py CHANGED
@@ -19,6 +19,7 @@ from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
 from .evaluation_summaries import EvaluationSummaryCache, EvaluationSummaryDataLoader
 from .experiment_annotation_summaries import ExperimentAnnotationSummaryDataLoader
 from .experiment_error_rates import ExperimentErrorRatesDataLoader
+from .experiment_run_counts import ExperimentRunCountsDataLoader
 from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
 from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
 from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
@@ -41,6 +42,7 @@ __all__ = [
     "EvaluationSummaryDataLoader",
     "ExperimentAnnotationSummaryDataLoader",
     "ExperimentErrorRatesDataLoader",
+    "ExperimentRunCountsDataLoader",
     "ExperimentSequenceNumberDataLoader",
     "LatencyMsQuantileDataLoader",
     "MinStartOrMaxEndTimeDataLoader",
phoenix/server/api/dataloaders/experiment_run_counts.py ADDED
@@ -0,0 +1,42 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+)
+
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExperimentID: TypeAlias = int
+RunCount: TypeAlias = int
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = RunCount
+
+
+class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        async with self._db() as session:
+            run_counts = {
+                experiment_id: run_count
+                async for experiment_id, run_count in await session.stream(
+                    select(models.ExperimentRun.experiment_id, func.count())
+                    .where(models.ExperimentRun.experiment_id.in_(set(experiment_ids)))
+                    .group_by(models.ExperimentRun.experiment_id)
+                )
+            }
+        return [
+            run_counts.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
+            for experiment_id in experiment_ids
+        ]
phoenix/server/api/helpers/dataset_helpers.py CHANGED
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, Literal, Mapping, Optional, Protocol

 from openinference.semconv.trace import (
@@ -128,14 +129,14 @@ def _get_generic_io_value(
     Makes a best-effort attempt to extract the input or output value from a span
     and returns it as a dictionary.
     """
-    if isinstance(io_value, str) and (
-        mime_type == OpenInferenceMimeTypeValues.TEXT.value or mime_type is None
-    ):
+    if mime_type == OpenInferenceMimeTypeValues.JSON.value:
+        parsed_value = json.loads(io_value)
+        if isinstance(parsed_value, dict):
+            return parsed_value
+        else:
+            return {kind: parsed_value}
+    if isinstance(io_value, str):
         return {kind: io_value}
-    if isinstance(io_value, dict) and (
-        mime_type == OpenInferenceMimeTypeValues.JSON.value or mime_type is None
-    ):
-        return io_value
     return {}


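To make the behavior change concrete, a minimal standalone sketch (a simplified stand-in, not the library function itself): a JSON value that parses to a dict is returned as-is, any other JSON value is wrapped under the kind key, and plain strings are wrapped regardless of mime type.

import json

def _extract(io_value, mime_type, kind):  # simplified stand-in for _get_generic_io_value
    if mime_type == "application/json":  # OpenInferenceMimeTypeValues.JSON.value
        parsed_value = json.loads(io_value)
        return parsed_value if isinstance(parsed_value, dict) else {kind: parsed_value}
    if isinstance(io_value, str):
        return {kind: io_value}
    return {}

assert _extract('{"question": "hi"}', "application/json", "input") == {"question": "hi"}
assert _extract('[1, 2, 3]', "application/json", "input") == {"input": [1, 2, 3]}
assert _extract("plain text", "text/plain", "input") == {"input": "plain text"}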
phoenix/server/api/input_types/ClearProjectInput.py ADDED
@@ -0,0 +1,15 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from strawberry import UNSET
+from strawberry.relay import GlobalID
+
+
+@strawberry.input
+class ClearProjectInput:
+    id: GlobalID
+    end_time: Optional[datetime] = strawberry.field(
+        default=UNSET,
+        description="The time up to which to purge data. Time is right-open /non-inclusive.",
+    )
phoenix/server/api/mutations/project_mutations.py CHANGED
@@ -8,6 +8,7 @@ from phoenix.config import DEFAULT_PROJECT_NAME
 from phoenix.db import models
 from phoenix.db.insertion.span import ClearProjectSpansEvent
 from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.ClearProjectInput import ClearProjectInput
 from phoenix.server.api.mutations.auth import IsAuthenticated
 from phoenix.server.api.queries import Query
 from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -32,11 +33,15 @@ class ProjectMutationMixin:
         return Query()

     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
-    async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
-        project_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project")
+    async def clear_project(self, info: Info[Context, None], input: ClearProjectInput) -> Query:
+        project_id = from_global_id_with_expected_type(
+            global_id=input.id, expected_type_name="Project"
+        )
         delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
+        if input.end_time is not None:
+            delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
         async with info.context.db() as session:
             await session.execute(delete_statement)
-            if cache := info.context.cache_for_dataloaders:
-                cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
+        if cache := info.context.cache_for_dataloaders:
+            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
         return Query()
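A hedged example of calling the updated mutation, which now takes a ClearProjectInput instead of a bare GlobalID. The host, the /graphql path, and the id value are assumptions for illustration; endTime is optional and, when given, only traces that started before it are purged.

import httpx

mutation = """
mutation ClearProject($input: ClearProjectInput!) {
  clearProject(input: $input) {
    __typename
  }
}
"""
variables = {
    "input": {
        "id": "UHJvamVjdDox",  # placeholder relay GlobalID for the project
        "endTime": "2024-06-01T00:00:00+00:00",  # optional cutoff; omit to clear everything
    }
}
httpx.post(
    "http://localhost:6006/graphql",  # assumes a local Phoenix server on the default port
    json={"query": mutation, "variables": variables},
)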
phoenix/server/api/routers/v1/datasets.py CHANGED
@@ -13,11 +13,12 @@ from typing import (
     Awaitable,
     Callable,
     Coroutine,
-    Dict,
     FrozenSet,
     Iterator,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
     cast,
@@ -32,8 +33,8 @@ from starlette.datastructures import FormData, UploadFile
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
 from starlette.status import (
-    HTTP_403_FORBIDDEN,
     HTTP_404_NOT_FOUND,
+    HTTP_409_CONFLICT,
     HTTP_422_UNPROCESSABLE_ENTITY,
     HTTP_429_TOO_MANY_REQUESTS,
 )
@@ -44,6 +45,7 @@ from phoenix.db import models
 from phoenix.db.insertion.dataset import (
     DatasetAction,
     DatasetExampleAdditionEvent,
+    ExampleContent,
     add_dataset_examples,
 )
 from phoenix.server.api.types.Dataset import Dataset
@@ -350,7 +352,7 @@ async def get_dataset_versions(request: Request) -> Response:

 async def post_datasets_upload(request: Request) -> Response:
     """
-    summary: Upload CSV or PyArrow file as dataset
+    summary: Upload dataset as either JSON or file (CSV or PyArrow)
     operationId: uploadDataset
     tags:
       - datasets
@@ -362,6 +364,32 @@ async def post_datasets_upload(request: Request) -> Response:
             type: boolean
     requestBody:
       content:
+        application/json:
+          schema:
+            type: object
+            required:
+              - name
+              - inputs
+            properties:
+              action:
+                type: string
+                enum: [create, append]
+              name:
+                type: string
+              description:
+                type: string
+              inputs:
+                type: array
+                items:
+                  type: object
+              outputs:
+                type: array
+                items:
+                  type: object
+              metadata:
+                type: array
+                items:
+                  type: object
         multipart/form-data:
           schema:
             type: object
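For illustration, a hedged client-side sketch of the new JSON upload body. The host and the /v1/datasets/upload route are assumptions (the route itself is not shown in this diff); outputs and metadata are optional but, when provided, must be lists of objects the same length as inputs.

import httpx

payload = {
    "action": "create",
    "name": "qa-examples",
    "description": "toy dataset uploaded as JSON",
    "inputs": [{"question": "What is Phoenix?"}, {"question": "What changed in rc5?"}],
    "outputs": [{"answer": "An LLM observability tool."}, {"answer": "JSON dataset uploads."}],
    "metadata": [{"source": "docs"}, {"source": "changelog"}],
}
response = httpx.post(
    "http://localhost:6006/v1/datasets/upload",  # assumed local server and route
    params={"sync": "true"},  # block until the examples are inserted
    json=payload,
)
response.raise_for_status()  # 409 if a dataset with the same name already exists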
@@ -401,22 +429,18 @@ async def post_datasets_upload(request: Request) -> Response:
           description: Success
         403:
           description: Forbidden
+        409:
+          description: Dataset of the same name already exists
         422:
           description: Request body is invalid
     """
-    if request.app.state.read_only:
-        return Response(status_code=HTTP_403_FORBIDDEN)
-    async with request.form() as form:
+    request_content_type = request.headers["content-type"]
+    examples: Union[Examples, Awaitable[Examples]]
+    if request_content_type.startswith("application/json"):
         try:
-            (
-                action,
-                name,
-                description,
-                input_keys,
-                output_keys,
-                metadata_keys,
-                file,
-            ) = await _parse_form_data(form)
+            examples, action, name, description = await run_in_threadpool(
+                _process_json, await request.json()
+            )
         except ValueError as e:
             return Response(
                 content=str(e),
@@ -426,24 +450,53 @@ async def post_datasets_upload(request: Request) -> Response:
         async with request.app.state.db() as session:
             if await _check_table_exists(session, name):
                 return Response(
-                    content=f"Dataset already exists: {name=}",
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                    content=f"Dataset with the same name already exists: {name=}",
+                    status_code=HTTP_409_CONFLICT,
                 )
-        content = await file.read()
-        try:
-            examples: Union[Examples, Awaitable[Examples]]
-            content_type = FileContentType(file.content_type)
-            if content_type is FileContentType.CSV:
-                encoding = FileContentEncoding(file.headers.get("content-encoding"))
-                examples, column_headers = await _process_csv(content, encoding)
-            elif content_type is FileContentType.PYARROW:
-                examples, column_headers = await _process_pyarrow(content)
-            else:
-                assert_never(content_type)
-            _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
-        except ValueError as e:
+    elif request_content_type.startswith("multipart/form-data"):
+        async with request.form() as form:
+            try:
+                (
+                    action,
+                    name,
+                    description,
+                    input_keys,
+                    output_keys,
+                    metadata_keys,
+                    file,
+                ) = await _parse_form_data(form)
+            except ValueError as e:
+                return Response(
+                    content=str(e),
+                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                )
+        if action is DatasetAction.CREATE:
+            async with request.app.state.db() as session:
+                if await _check_table_exists(session, name):
+                    return Response(
+                        content=f"Dataset with the same name already exists: {name=}",
+                        status_code=HTTP_409_CONFLICT,
+                    )
+        content = await file.read()
+        try:
+            file_content_type = FileContentType(file.content_type)
+            if file_content_type is FileContentType.CSV:
+                encoding = FileContentEncoding(file.headers.get("content-encoding"))
+                examples = await _process_csv(
+                    content, encoding, input_keys, output_keys, metadata_keys
+                )
+            elif file_content_type is FileContentType.PYARROW:
+                examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
+            else:
+                assert_never(file_content_type)
+        except ValueError as e:
+            return Response(
+                content=str(e),
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            )
+    else:
         return Response(
-            content=str(e),
+            content=str("Invalid request Content-Type"),
             status_code=HTTP_422_UNPROCESSABLE_ENTITY,
         )
     operation = cast(
@@ -454,9 +507,6 @@
             action=action,
             name=name,
             description=description,
-            input_keys=input_keys,
-            output_keys=output_keys,
-            metadata_keys=metadata_keys,
         ),
     )
     if request.query_params.get("sync") == "true":
@@ -505,13 +555,46 @@ InputKeys: TypeAlias = FrozenSet[str]
 OutputKeys: TypeAlias = FrozenSet[str]
 MetadataKeys: TypeAlias = FrozenSet[str]
 DatasetId: TypeAlias = int
-Examples: TypeAlias = Iterator[Dict[str, Any]]
+Examples: TypeAlias = Iterator[ExampleContent]
+
+
+def _process_json(
+    data: Mapping[str, Any],
+) -> Tuple[Examples, DatasetAction, Name, Description]:
+    name = data.get("name")
+    if not name:
+        raise ValueError("Dataset name is required")
+    description = data.get("description") or ""
+    inputs = data.get("inputs")
+    if not inputs:
+        raise ValueError("input is required")
+    if not isinstance(inputs, list) or not _is_all_dict(inputs):
+        raise ValueError("Input should be a list containing only dictionary objects")
+    outputs, metadata = data.get("outputs"), data.get("metadata")
+    for k, v in {"outputs": outputs, "metadata": metadata}.items():
+        if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
+            raise ValueError(
+                f"{k} should be a list of same length as input containing only dictionary objects"
+            )
+    examples: List[ExampleContent] = []
+    for i, obj in enumerate(inputs):
+        example = ExampleContent(
+            input=obj,
+            output=outputs[i] if outputs else {},
+            metadata=metadata[i] if metadata else {},
+        )
+        examples.append(example)
+    action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
+    return iter(examples), action, name, description


 async def _process_csv(
     content: bytes,
     content_encoding: FileContentEncoding,
-) -> Tuple[Examples, FrozenSet[str]]:
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+) -> Examples:
     if content_encoding is FileContentEncoding.GZIP:
         content = await run_in_threadpool(gzip.decompress, content)
     elif content_encoding is FileContentEncoding.DEFLATE:
@@ -525,22 +608,39 @@ async def _process_csv(
         if freq > 1:
             raise ValueError(f"Duplicated column header in CSV file: {header}")
     column_headers = frozenset(reader.fieldnames)
-    return reader, column_headers
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    return (
+        ExampleContent(
+            input={k: row.get(k) for k in input_keys},
+            output={k: row.get(k) for k in output_keys},
+            metadata={k: row.get(k) for k in metadata_keys},
+        )
+        for row in iter(reader)
+    )


 async def _process_pyarrow(
     content: bytes,
-) -> Tuple[Awaitable[Examples], FrozenSet[str]]:
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+) -> Awaitable[Examples]:
     try:
         reader = pa.ipc.open_stream(content)
     except pa.ArrowInvalid as e:
         raise ValueError("File is not valid pyarrow") from e
     column_headers = frozenset(reader.schema.names)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+
+    def get_examples() -> Iterator[ExampleContent]:
+        for row in reader.read_pandas().to_dict(orient="records"):
+            yield ExampleContent(
+                input={k: row.get(k) for k in input_keys},
+                output={k: row.get(k) for k in output_keys},
+                metadata={k: row.get(k) for k in metadata_keys},
+            )

-    def get_examples() -> Iterator[Dict[str, Any]]:
-        yield from reader.read_pandas().to_dict(orient="records")
-
-    return run_in_threadpool(get_examples), column_headers
+    return run_in_threadpool(get_examples)


 async def _check_table_exists(session: AsyncSession, name: str) -> bool:
@@ -859,3 +959,7 @@ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetEx
             raise ValueError("Dataset does not exist.")
         examples = [r async for r in await session.stream_scalars(stmt)]
     return dataset_name, examples
+
+
+def _is_all_dict(seq: Sequence[Any]) -> bool:
+    return all(map(lambda obj: isinstance(obj, dict), seq))
phoenix/server/api/routers/v1/experiment_evaluations.py CHANGED
@@ -43,6 +43,7 @@ async def create_experiment_evaluation(request: Request) -> Response:
             metadata_=metadata,
             start_time=datetime.fromisoformat(start_time),
             end_time=datetime.fromisoformat(end_time),
+            trace_id=payload.get("trace_id"),
         )
         session.add(exp_eval_run)
         await session.flush()
phoenix/server/api/routers/v1/experiment_runs.py CHANGED
@@ -61,7 +61,7 @@ async def create_experiment_run(request: Request) -> Response:
         experiment_id=str(experiment_gid),
         dataset_example_id=str(example_gid),
         repetition_number=exp_run.repetition_number,
-        output=ExperimentResult(result=exp_run.output),
+        output=ExperimentResult.from_dict(exp_run.output) if exp_run.output else None,
         error=exp_run.error,
         id=str(run_gid),
         trace_id=exp_run.trace_id,
@@ -99,7 +99,7 @@ async def list_experiment_runs(request: Request) -> Response:
         experiment_id=str(experiment_gid),
         dataset_example_id=str(example_gid),
         repetition_number=exp_run.repetition_number,
-        output=ExperimentResult(result=exp_run.output),
+        output=ExperimentResult.from_dict(exp_run.output) if exp_run.output else None,
         error=exp_run.error,
         id=str(run_gid),
         trace_id=exp_run.trace_id,
phoenix/server/api/types/Experiment.py CHANGED
@@ -75,6 +75,11 @@ class Experiment(Node):
         ).all()
         return connection_from_list([to_gql_experiment_run(run) for run in runs], args)

+    @strawberry.field
+    async def run_count(self, info: Info[Context, None]) -> int:
+        experiment_id = self.id_attr
+        return await info.context.data_loaders.experiment_run_counts.load(experiment_id)
+
     @strawberry.field
     async def annotation_summaries(
         self, info: Info[Context, None]
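A hedged query sketch for the new field, resolved through the batching ExperimentRunCountsDataLoader added above. The relay-style node query, the /graphql path, and the experiment GlobalID are assumptions for illustration.

import httpx

query = """
query ExperimentRunCount($id: GlobalID!) {
  node(id: $id) {
    ... on Experiment {
      runCount
    }
  }
}
"""
httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={"query": query, "variables": {"id": "RXhwZXJpbWVudDox"}},  # placeholder GlobalID
)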
phoenix/server/api/types/ExperimentRun.py CHANGED
@@ -84,7 +84,7 @@ def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
         trace_id=trace_id
         if (trace := run.trace) and (trace_id := trace.trace_id) is not None
         else None,
-        output=run.output,
+        output=run.output.get("result"),
         start_time=run.start_time,
         end_time=run.end_time,
         error=run.error,
phoenix/server/api/types/ExperimentRunAnnotation.py CHANGED
@@ -33,7 +33,7 @@ class ExperimentRunAnnotation(Node):
         if (trace := await dataloader.load(self.trace_id)) is None:
             return None
         trace_row_id, project_row_id = trace
-        return Trace(id_attr=trace_row_id, trace_id=trace.trace_id, project_rowid=project_row_id)
+        return Trace(id_attr=trace_row_id, trace_id=self.trace_id, project_rowid=project_row_id)


 def to_gql_experiment_run_annotation(
phoenix/server/app.py CHANGED
@@ -65,6 +65,7 @@ from phoenix.server.api.dataloaders import (
     EvaluationSummaryDataLoader,
     ExperimentAnnotationSummaryDataLoader,
     ExperimentErrorRatesDataLoader,
+    ExperimentRunCountsDataLoader,
     ExperimentSequenceNumberDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
@@ -208,6 +209,7 @@ class GraphQLWithContext(GraphQL):  # type: ignore
             ),
             experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(self.db),
             experiment_error_rates=ExperimentErrorRatesDataLoader(self.db),
+            experiment_run_counts=ExperimentRunCountsDataLoader(self.db),
             experiment_sequence_number=ExperimentSequenceNumberDataLoader(self.db),
             latency_ms_quantile=LatencyMsQuantileDataLoader(
                 self.db,