arize-phoenix: arize_phoenix-11.32.1-py3-none-any.whl → arize_phoenix-11.34.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (63)
  1. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/METADATA +1 -1
  2. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/RECORD +57 -50
  3. phoenix/config.py +44 -0
  4. phoenix/db/bulk_inserter.py +111 -116
  5. phoenix/inferences/inferences.py +1 -2
  6. phoenix/server/api/context.py +20 -0
  7. phoenix/server/api/dataloaders/__init__.py +20 -0
  8. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  9. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  10. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  11. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  12. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +59 -0
  13. phoenix/server/api/dataloaders/experiment_repetition_counts.py +39 -0
  14. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  15. phoenix/server/api/helpers/playground_clients.py +4 -0
  16. phoenix/server/api/mutations/prompt_label_mutations.py +67 -58
  17. phoenix/server/api/queries.py +52 -37
  18. phoenix/server/api/routers/v1/documents.py +1 -1
  19. phoenix/server/api/routers/v1/evaluations.py +4 -4
  20. phoenix/server/api/routers/v1/experiment_runs.py +1 -1
  21. phoenix/server/api/routers/v1/experiments.py +1 -1
  22. phoenix/server/api/routers/v1/spans.py +2 -2
  23. phoenix/server/api/routers/v1/traces.py +18 -3
  24. phoenix/server/api/types/DatasetExample.py +49 -1
  25. phoenix/server/api/types/Experiment.py +12 -2
  26. phoenix/server/api/types/ExperimentComparison.py +3 -9
  27. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +146 -0
  28. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  29. phoenix/server/api/types/ExperimentRun.py +12 -19
  30. phoenix/server/api/types/Prompt.py +11 -0
  31. phoenix/server/api/types/PromptLabel.py +2 -19
  32. phoenix/server/api/types/node.py +10 -0
  33. phoenix/server/app.py +78 -20
  34. phoenix/server/cost_tracking/model_cost_manifest.json +1 -1
  35. phoenix/server/daemons/span_cost_calculator.py +10 -8
  36. phoenix/server/grpc_server.py +9 -9
  37. phoenix/server/prometheus.py +30 -6
  38. phoenix/server/static/.vite/manifest.json +43 -43
  39. phoenix/server/static/assets/components-CdQiQTvs.js +5778 -0
  40. phoenix/server/static/assets/{index-D1FDMBMV.js → index-B1VuXYRI.js} +12 -21
  41. phoenix/server/static/assets/pages-CnfZ3RhB.js +9163 -0
  42. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  43. phoenix/server/static/assets/vendor-Cfrr9FCF.js +903 -0
  44. phoenix/server/static/assets/{vendor-arizeai-DsYDNOqt.js → vendor-arizeai-Dz0kN-lQ.js} +4 -4
  45. phoenix/server/static/assets/vendor-codemirror-ClqtONZQ.js +25 -0
  46. phoenix/server/static/assets/{vendor-recharts-BTHn5Y2R.js → vendor-recharts-D6kvOpmb.js} +2 -2
  47. phoenix/server/static/assets/{vendor-shiki-BAcocHFl.js → vendor-shiki-xSOiKxt0.js} +1 -1
  48. phoenix/session/client.py +55 -1
  49. phoenix/session/data_extractor.py +5 -0
  50. phoenix/session/evaluation.py +8 -4
  51. phoenix/session/session.py +13 -0
  52. phoenix/trace/projects.py +1 -2
  53. phoenix/version.py +1 -1
  54. phoenix/server/static/assets/components-Cs9c4Nxp.js +0 -5698
  55. phoenix/server/static/assets/pages-Cbj9SjBx.js +0 -8928
  56. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  57. phoenix/server/static/assets/vendor-RdRDaQiR.js +0 -905
  58. phoenix/server/static/assets/vendor-codemirror-BzJDUbEx.js +0 -25
  59. phoenix/utilities/deprecation.py +0 -31
  60. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/WHEEL +0 -0
  61. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/entry_points.txt +0 -0
  62. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/IP_NOTICE +0 -0
  63. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/LICENSE +0 -0
@@ -102,7 +102,7 @@ async def post_evaluations(
102
102
  detail="Evaluation name must not be blank/empty",
103
103
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
104
104
  )
105
- await request.state.queue_evaluation_for_bulk_insert(evaluation)
105
+ await request.state.enqueue_evaluation(evaluation)
106
106
  return Response()
107
107
 
108
108
 
@@ -221,7 +221,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
221
221
  explanation=explanation,
222
222
  metadata_={},
223
223
  )
224
- await state.enqueue(document_annotation)
224
+ await state.enqueue_annotations(document_annotation)
225
225
  elif len(names) == 1 and names[0] in ("context.span_id", "span_id"):
226
226
  for index, row in dataframe.iterrows():
227
227
  score, label, explanation = _get_annotation_result(row)
@@ -235,7 +235,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
235
235
  explanation=explanation,
236
236
  metadata_={},
237
237
  )
238
- await state.enqueue(span_annotation)
238
+ await state.enqueue_annotations(span_annotation)
239
239
  elif len(names) == 1 and names[0] in ("context.trace_id", "trace_id"):
240
240
  for index, row in dataframe.iterrows():
241
241
  score, label, explanation = _get_annotation_result(row)
@@ -249,7 +249,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
249
249
  explanation=explanation,
250
250
  metadata_={},
251
251
  )
252
- await state.enqueue(trace_annotation)
252
+ await state.enqueue_annotations(trace_annotation)
253
253
 
254
254
 
255
255
  def _get_annotation_result(
@@ -27,7 +27,7 @@ class ExperimentRun(V1RoutesBaseModel):
27
27
  description="The ID of the dataset example used in the experiment run"
28
28
  )
29
29
  output: Any = Field(description="The output of the experiment task")
30
- repetition_number: int = Field(description="The repetition number of the experiment run")
30
+ repetition_number: int = Field(description="The repetition number of the experiment run", gt=0)
31
31
  start_time: datetime = Field(description="The start time of the experiment run")
32
32
  end_time: datetime = Field(description="The end time of the experiment run")
33
33
  trace_id: Optional[str] = Field(
@@ -46,7 +46,7 @@ class Experiment(V1RoutesBaseModel):
46
46
  dataset_version_id: str = Field(
47
47
  description="The ID of the dataset version associated with the experiment"
48
48
  )
49
- repetitions: int = Field(description="Number of times the experiment is repeated")
49
+ repetitions: int = Field(description="Number of times the experiment is repeated", gt=0)
50
50
  metadata: dict[str, Any] = Field(description="Metadata of the experiment")
51
51
  project_name: Optional[str] = Field(
52
52
  description="The name of the project associated with the experiment"
@@ -897,7 +897,7 @@ async def annotate_spans(
897
897
  )
898
898
  precursors = [d.as_precursor(user_id=user_id) for d in filtered_span_annotations]
899
899
  if not sync:
900
- await request.state.enqueue(*precursors)
900
+ await request.state.enqueue_annotations(*precursors)
901
901
  return AnnotateSpansResponseBody(data=[])
902
902
 
903
903
  span_ids = {p.span_id for p in precursors}
@@ -1072,7 +1072,7 @@ async def create_spans(
1072
1072
 
1073
1073
  # All spans are valid, queue them all
1074
1074
  for span_for_insertion, project_name in spans_to_queue:
1075
- await request.state.queue_span_for_bulk_insert(span_for_insertion, project_name)
1075
+ await request.state.enqueue_span(span_for_insertion, project_name)
1076
1076
 
1077
1077
  return CreateSpansResponseBody(
1078
1078
  total_received=total_received,
@@ -18,6 +18,7 @@ from starlette.status import (
18
18
  HTTP_404_NOT_FOUND,
19
19
  HTTP_415_UNSUPPORTED_MEDIA_TYPE,
20
20
  HTTP_422_UNPROCESSABLE_ENTITY,
21
+ HTTP_503_SERVICE_UNAVAILABLE,
21
22
  )
22
23
  from strawberry.relay import GlobalID
23
24
 
@@ -29,6 +30,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
29
30
  from phoenix.server.authorization import is_not_locked
30
31
  from phoenix.server.bearer_auth import PhoenixUser
31
32
  from phoenix.server.dml_event import SpanDeleteEvent, TraceAnnotationInsertEvent
33
+ from phoenix.server.prometheus import SPAN_QUEUE_REJECTIONS
32
34
  from phoenix.trace.otel import decode_otlp_span
33
35
  from phoenix.utilities.project import get_project_name
34
36
 
@@ -42,9 +44,18 @@ from .utils import (
42
44
  router = APIRouter(tags=["traces"])
43
45
 
44
46
 
47
+ def is_not_at_capacity(request: Request) -> None:
48
+ if request.app.state.span_queue_is_full():
49
+ SPAN_QUEUE_REJECTIONS.inc()
50
+ raise HTTPException(
51
+ detail="Server is at capacity and cannot process more requests",
52
+ status_code=HTTP_503_SERVICE_UNAVAILABLE,
53
+ )
54
+
55
+
45
56
  @router.post(
46
57
  "/traces",
47
- dependencies=[Depends(is_not_locked)],
58
+ dependencies=[Depends(is_not_locked), Depends(is_not_at_capacity)],
48
59
  operation_id="addTraces",
49
60
  summary="Send traces",
50
61
  responses=add_errors_to_responses(
@@ -56,6 +67,10 @@ router = APIRouter(tags=["traces"])
56
67
  ),
57
68
  },
58
69
  {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
70
+ {
71
+ "status_code": HTTP_503_SERVICE_UNAVAILABLE,
72
+ "description": "Server is at capacity and cannot process more requests",
73
+ },
59
74
  ]
60
75
  ),
61
76
  openapi_extra={
@@ -145,7 +160,7 @@ async def annotate_traces(
145
160
 
146
161
  precursors = [d.as_precursor(user_id=user_id) for d in request_body.data]
147
162
  if not sync:
148
- await request.state.enqueue(*precursors)
163
+ await request.state.enqueue_annotations(*precursors)
149
164
  return AnnotateTracesResponseBody(data=[])
150
165
 
151
166
  trace_ids = {p.trace_id for p in precursors}
@@ -193,7 +208,7 @@ async def _add_spans(req: ExportTraceServiceRequest, state: State) -> None:
193
208
  for scope_span in resource_spans.scope_spans:
194
209
  for otlp_span in scope_span.spans:
195
210
  span = await run_in_threadpool(decode_otlp_span, otlp_span)
196
- await state.queue_span_for_bulk_insert(span, project_name)
211
+ await state.enqueue_span(span, project_name)
197
212
 
198
213
 
199
214
  @router.delete(
@@ -10,8 +10,12 @@ from strawberry.types import Info
10
10
 
11
11
  from phoenix.db import models
12
12
  from phoenix.server.api.context import Context
13
+ from phoenix.server.api.exceptions import BadRequest
13
14
  from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
14
15
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
16
+ from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
17
+ ExperimentRepeatedRunGroup,
18
+ )
15
19
  from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
16
20
  from phoenix.server.api.types.node import from_global_id_with_expected_type
17
21
  from phoenix.server.api.types.pagination import (
@@ -65,6 +69,7 @@ class DatasetExample(Node):
65
69
  last: Optional[int] = UNSET,
66
70
  after: Optional[CursorString] = UNSET,
67
71
  before: Optional[CursorString] = UNSET,
72
+ experiment_ids: Optional[list[GlobalID]] = UNSET,
68
73
  ) -> Connection[ExperimentRun]:
69
74
  args = ConnectionArgs(
70
75
  first=first,
@@ -78,8 +83,51 @@ class DatasetExample(Node):
78
83
  .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
79
84
  .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
80
85
  .where(models.ExperimentRun.dataset_example_id == example_id)
81
- .order_by(models.Experiment.id.desc())
86
+ .order_by(
87
+ models.ExperimentRun.experiment_id.asc(),
88
+ models.ExperimentRun.repetition_number.asc(),
89
+ )
82
90
  )
91
+ if experiment_ids:
92
+ experiment_db_ids = [
93
+ from_global_id_with_expected_type(
94
+ global_id=experiment_id,
95
+ expected_type_name=models.Experiment.__name__,
96
+ )
97
+ for experiment_id in experiment_ids or []
98
+ ]
99
+ query = query.where(models.ExperimentRun.experiment_id.in_(experiment_db_ids))
83
100
  async with info.context.db() as session:
84
101
  runs = (await session.scalars(query)).all()
85
102
  return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
103
+
104
+ @strawberry.field
105
+ async def experiment_repeated_run_groups(
106
+ self,
107
+ info: Info[Context, None],
108
+ experiment_ids: list[GlobalID],
109
+ ) -> list[ExperimentRepeatedRunGroup]:
110
+ example_rowid = self.id_attr
111
+ experiment_rowids = []
112
+ for experiment_id in experiment_ids:
113
+ try:
114
+ experiment_rowid = from_global_id_with_expected_type(
115
+ global_id=experiment_id,
116
+ expected_type_name=models.Experiment.__name__,
117
+ )
118
+ except Exception:
119
+ raise BadRequest(f"Invalid experiment ID: {experiment_id}")
120
+ experiment_rowids.append(experiment_rowid)
121
+ repeated_run_groups = (
122
+ await info.context.data_loaders.experiment_repeated_run_groups.load_many(
123
+ [(experiment_rowid, example_rowid) for experiment_rowid in experiment_rowids]
124
+ )
125
+ )
126
+ return [
127
+ ExperimentRepeatedRunGroup(
128
+ experiment_rowid=group.experiment_rowid,
129
+ dataset_example_rowid=group.dataset_example_rowid,
130
+ runs=[to_gql_experiment_run(run) for run in group.runs],
131
+ )
132
+ for group in repeated_run_groups
133
+ ]
@@ -5,13 +5,14 @@ import strawberry
5
5
  from sqlalchemy import func, select
6
6
  from sqlalchemy.orm import joinedload
7
7
  from strawberry import UNSET, Private
8
- from strawberry.relay import Connection, Node, NodeID
8
+ from strawberry.relay import Connection, GlobalID, Node, NodeID
9
9
  from strawberry.scalars import JSON
10
10
  from strawberry.types import Info
11
11
 
12
12
  from phoenix.db import models
13
13
  from phoenix.server.api.context import Context
14
14
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
15
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
15
16
  from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
16
17
  from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
17
18
  from phoenix.server.api.types.pagination import (
@@ -32,6 +33,7 @@ class Experiment(Node):
32
33
  name: str
33
34
  project_name: Optional[str]
34
35
  description: Optional[str]
36
+ dataset_version_id: GlobalID
35
37
  metadata: JSON
36
38
  created_at: datetime
37
39
  updated_at: datetime
@@ -71,7 +73,10 @@ class Experiment(Node):
71
73
  await session.scalars(
72
74
  select(models.ExperimentRun)
73
75
  .where(models.ExperimentRun.experiment_id == experiment_id)
74
- .order_by(models.ExperimentRun.id.desc())
76
+ .order_by(
77
+ models.ExperimentRun.dataset_example_id.asc(),
78
+ models.ExperimentRun.repetition_number.asc(),
79
+ )
75
80
  .options(
76
81
  joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
77
82
  )
@@ -187,6 +192,10 @@ class Experiment(Node):
187
192
  async for token_type, is_prompt, cost, tokens in data
188
193
  ]
189
194
 
195
+ @strawberry.field
196
+ async def repetition_count(self, info: Info[Context, None]) -> int:
197
+ return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
198
+
190
199
 
191
200
  def to_gql_experiment(
192
201
  experiment: models.Experiment,
@@ -201,6 +210,7 @@ def to_gql_experiment(
201
210
  name=experiment.name,
202
211
  project_name=experiment.project_name,
203
212
  description=experiment.description,
213
+ dataset_version_id=GlobalID(DatasetVersion.__name__, str(experiment.dataset_version_id)),
204
214
  metadata=experiment.metadata_,
205
215
  created_at=experiment.created_at,
206
216
  updated_at=experiment.updated_at,
@@ -1,18 +1,12 @@
1
1
  import strawberry
2
- from strawberry.relay import GlobalID, Node, NodeID
2
+ from strawberry.relay import Node, NodeID
3
3
 
4
4
  from phoenix.server.api.types.DatasetExample import DatasetExample
5
- from phoenix.server.api.types.ExperimentRun import ExperimentRun
6
-
7
-
8
- @strawberry.type
9
- class RunComparisonItem:
10
- experiment_id: GlobalID
11
- runs: list[ExperimentRun]
5
+ from phoenix.server.api.types.ExperimentRepeatedRunGroup import ExperimentRepeatedRunGroup
12
6
 
13
7
 
14
8
  @strawberry.type
15
9
  class ExperimentComparison(Node):
16
10
  id_attr: NodeID[int]
17
11
  example: DatasetExample
18
- run_comparison_items: list[RunComparisonItem]
12
+ repeated_run_groups: list[ExperimentRepeatedRunGroup]
@@ -0,0 +1,146 @@
1
+ import re
2
+ from base64 import b64decode
3
+ from typing import Optional
4
+
5
+ import strawberry
6
+ from sqlalchemy import func, select
7
+ from strawberry.relay import GlobalID, Node
8
+ from strawberry.types import Info
9
+ from typing_extensions import Self, TypeAlias
10
+
11
+ from phoenix.db import models
12
+ from phoenix.server.api.context import Context
13
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
14
+ from phoenix.server.api.types.ExperimentRepeatedRunGroupAnnotationSummary import (
15
+ ExperimentRepeatedRunGroupAnnotationSummary,
16
+ )
17
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
18
+ from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
19
+ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
20
+
21
+ ExperimentRowId: TypeAlias = int
22
+ DatasetExampleRowId: TypeAlias = int
23
+
24
+
25
+ @strawberry.type
26
+ class ExperimentRepeatedRunGroup(Node):
27
+ experiment_rowid: strawberry.Private[ExperimentRowId]
28
+ dataset_example_rowid: strawberry.Private[DatasetExampleRowId]
29
+ runs: list[ExperimentRun]
30
+
31
+ @classmethod
32
+ def resolve_id(
33
+ cls,
34
+ root: Self,
35
+ *,
36
+ info: Info,
37
+ ) -> str:
38
+ return (
39
+ f"experiment_id={root.experiment_rowid}:dataset_example_id={root.dataset_example_rowid}"
40
+ )
41
+
42
+ @strawberry.field
43
+ def experiment_id(self) -> strawberry.ID:
44
+ from phoenix.server.api.types.Experiment import Experiment
45
+
46
+ return strawberry.ID(str(GlobalID(Experiment.__name__, str(self.experiment_rowid))))
47
+
48
+ @strawberry.field
49
+ async def average_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
50
+ return await info.context.data_loaders.average_experiment_repeated_run_group_latency.load(
51
+ (self.experiment_rowid, self.dataset_example_rowid)
52
+ )
53
+
54
+ @strawberry.field
55
+ async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
56
+ experiment_id = self.experiment_rowid
57
+ example_id = self.dataset_example_rowid
58
+ summary = (
59
+ await info.context.data_loaders.span_cost_summary_by_experiment_repeated_run_group.load(
60
+ (experiment_id, example_id)
61
+ )
62
+ )
63
+ return SpanCostSummary(
64
+ prompt=CostBreakdown(
65
+ tokens=summary.prompt.tokens,
66
+ cost=summary.prompt.cost,
67
+ ),
68
+ completion=CostBreakdown(
69
+ tokens=summary.completion.tokens,
70
+ cost=summary.completion.cost,
71
+ ),
72
+ total=CostBreakdown(
73
+ tokens=summary.total.tokens,
74
+ cost=summary.total.cost,
75
+ ),
76
+ )
77
+
78
+ @strawberry.field
79
+ async def cost_detail_summary_entries(
80
+ self, info: Info[Context, None]
81
+ ) -> list[SpanCostDetailSummaryEntry]:
82
+ experiment_id = self.experiment_rowid
83
+ example_id = self.dataset_example_rowid
84
+ stmt = (
85
+ select(
86
+ models.SpanCostDetail.token_type,
87
+ models.SpanCostDetail.is_prompt,
88
+ func.sum(models.SpanCostDetail.cost).label("cost"),
89
+ func.sum(models.SpanCostDetail.tokens).label("tokens"),
90
+ )
91
+ .select_from(models.SpanCostDetail)
92
+ .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
93
+ .join(models.Trace, models.SpanCost.trace_rowid == models.Trace.id)
94
+ .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
95
+ .where(models.ExperimentRun.experiment_id == experiment_id)
96
+ .where(models.ExperimentRun.dataset_example_id == example_id)
97
+ .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
98
+ )
99
+
100
+ async with info.context.db() as session:
101
+ data = await session.stream(stmt)
102
+ return [
103
+ SpanCostDetailSummaryEntry(
104
+ token_type=token_type,
105
+ is_prompt=is_prompt,
106
+ value=CostBreakdown(tokens=tokens, cost=cost),
107
+ )
108
+ async for token_type, is_prompt, cost, tokens in data
109
+ ]
110
+
111
+ @strawberry.field
112
+ async def annotation_summaries(
113
+ self,
114
+ info: Info[Context, None],
115
+ ) -> list[ExperimentRepeatedRunGroupAnnotationSummary]:
116
+ loader = info.context.data_loaders.experiment_repeated_run_group_annotation_summaries
117
+ summaries = await loader.load((self.experiment_rowid, self.dataset_example_rowid))
118
+ return [
119
+ ExperimentRepeatedRunGroupAnnotationSummary(
120
+ annotation_name=summary.annotation_name,
121
+ mean_score=summary.mean_score,
122
+ )
123
+ for summary in summaries
124
+ ]
125
+
126
+
127
+ _EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN = re.compile(
128
+ r"ExperimentRepeatedRunGroup:experiment_id=(\d+):dataset_example_id=(\d+)"
129
+ )
130
+
131
+
132
+ def parse_experiment_repeated_run_group_node_id(
133
+ node_id: str,
134
+ ) -> tuple[ExperimentRowId, DatasetExampleRowId]:
135
+ decoded_node_id = _base64_decode(node_id)
136
+ match = re.match(_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN, decoded_node_id)
137
+ if not match:
138
+ raise ValueError(f"Invalid node ID format: {node_id}")
139
+
140
+ experiment_id = int(match.group(1))
141
+ dataset_example_id = int(match.group(2))
142
+ return experiment_id, dataset_example_id
143
+
144
+
145
+ def _base64_decode(string: str) -> str:
146
+ return b64decode(string.encode()).decode()
@@ -0,0 +1,9 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+
5
+
6
+ @strawberry.type
7
+ class ExperimentRepeatedRunGroupAnnotationSummary:
8
+ annotation_name: str
9
+ mean_score: Optional[float]
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
5
  from sqlalchemy import func, select
6
- from sqlalchemy.orm import load_only
7
6
  from sqlalchemy.sql.functions import coalesce
8
7
  from strawberry import UNSET
9
8
  from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -34,12 +33,17 @@ if TYPE_CHECKING:
34
33
  class ExperimentRun(Node):
35
34
  id_attr: NodeID[int]
36
35
  experiment_id: GlobalID
36
+ repetition_number: int
37
37
  trace_id: Optional[str]
38
38
  output: Optional[JSON]
39
39
  start_time: datetime
40
40
  end_time: datetime
41
41
  error: Optional[str]
42
42
 
43
+ @strawberry.field
44
+ def latency_ms(self) -> float:
45
+ return (self.end_time - self.start_time).total_seconds() * 1000
46
+
43
47
  @strawberry.field
44
48
  async def annotations(
45
49
  self,
@@ -78,24 +82,12 @@ class ExperimentRun(Node):
78
82
  ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
79
83
  from phoenix.server.api.types.DatasetExample import DatasetExample
80
84
 
81
- async with info.context.db() as session:
82
- assert (
83
- result := await session.execute(
84
- select(models.DatasetExample, models.Experiment.dataset_version_id)
85
- .select_from(models.ExperimentRun)
86
- .join(
87
- models.DatasetExample,
88
- models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
89
- )
90
- .join(
91
- models.Experiment,
92
- models.Experiment.id == models.ExperimentRun.experiment_id,
93
- )
94
- .where(models.ExperimentRun.id == self.id_attr)
95
- .options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
96
- )
97
- ) is not None
98
- example, version_id = result.first()
85
+ (
86
+ example,
87
+ version_id,
88
+ ) = await info.context.data_loaders.dataset_examples_and_versions_by_experiment_run.load(
89
+ self.id_attr
90
+ )
99
91
  return DatasetExample(
100
92
  id_attr=example.id,
101
93
  created_at=example.created_at,
@@ -165,6 +157,7 @@ def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
165
157
  return ExperimentRun(
166
158
  id_attr=run.id,
167
159
  experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
160
+ repetition_number=run.repetition_number,
168
161
  trace_id=run.trace.trace_id if run.trace else None,
169
162
  output=run.output.get("task_output"),
170
163
  start_time=run.start_time,
@@ -19,6 +19,7 @@ from phoenix.server.api.types.pagination import (
19
19
  connection_from_list,
20
20
  )
21
21
 
22
+ from .PromptLabel import PromptLabel, to_gql_prompt_label
22
23
  from .PromptVersion import (
23
24
  PromptVersion,
24
25
  to_gql_prompt_version,
@@ -116,6 +117,16 @@ class Prompt(Node):
116
117
  raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
117
118
  return to_gql_prompt_from_orm(source_prompt)
118
119
 
120
+ @strawberry.field
121
+ async def labels(self, info: Info[Context, None]) -> list["PromptLabel"]:
122
+ async with info.context.db() as session:
123
+ labels = await session.scalars(
124
+ select(models.PromptLabel)
125
+ .join(models.PromptPromptLabel)
126
+ .where(models.PromptPromptLabel.prompt_id == self.id_attr)
127
+ )
128
+ return [to_gql_prompt_label(label) for label in labels]
129
+
119
130
 
120
131
  def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
121
132
  if not orm_model.source_prompt_id:
@@ -1,14 +1,10 @@
1
1
  from typing import Optional
2
2
 
3
3
  import strawberry
4
- from sqlalchemy import select
5
4
  from strawberry.relay import Node, NodeID
6
- from strawberry.types import Info
7
5
 
8
6
  from phoenix.db import models
9
- from phoenix.server.api.context import Context
10
7
  from phoenix.server.api.types.Identifier import Identifier
11
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
12
8
 
13
9
 
14
10
  @strawberry.type
@@ -16,21 +12,7 @@ class PromptLabel(Node):
16
12
  id_attr: NodeID[int]
17
13
  name: Identifier
18
14
  description: Optional[str] = None
19
-
20
- @strawberry.field
21
- async def prompts(self, info: Info[Context, None]) -> list[Prompt]:
22
- async with info.context.db() as session:
23
- statement = (
24
- select(models.Prompt)
25
- .join(
26
- models.PromptPromptLabel, models.Prompt.id == models.PromptPromptLabel.prompt_id
27
- )
28
- .where(models.PromptPromptLabel.prompt_label_id == self.id_attr)
29
- )
30
- return [
31
- to_gql_prompt_from_orm(prompt_orm)
32
- async for prompt_orm in await session.stream_scalars(statement)
33
- ]
15
+ color: str
34
16
 
35
17
 
36
18
  def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel:
@@ -38,4 +20,5 @@ def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel:
38
20
  id_attr=label_orm.id,
39
21
  name=Identifier(label_orm.name),
40
22
  description=label_orm.description,
23
+ color=label_orm.color,
41
24
  )
@@ -1,5 +1,15 @@
1
+ import re
2
+ from base64 import b64decode
3
+
1
4
  from strawberry.relay import GlobalID
2
5
 
6
+ _GLOBAL_ID_PATTERN = re.compile(r"[a-zA-Z]+:[0-9]+")
7
+
8
+
9
+ def is_global_id(node_id: str) -> bool:
10
+ decoded_node_id = b64decode(node_id).decode()
11
+ return _GLOBAL_ID_PATTERN.match(decoded_node_id) is not None
12
+
3
13
 
4
14
  def from_global_id(global_id: GlobalID) -> tuple[str, int]:
5
15
  """