arize-phoenix 11.32.1__py3-none-any.whl → 11.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; consult the registry's advisory page for this release for more details.

Files changed (63) hide show
  1. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/METADATA +1 -1
  2. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/RECORD +57 -50
  3. phoenix/config.py +44 -0
  4. phoenix/db/bulk_inserter.py +111 -116
  5. phoenix/inferences/inferences.py +1 -2
  6. phoenix/server/api/context.py +20 -0
  7. phoenix/server/api/dataloaders/__init__.py +20 -0
  8. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  9. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  10. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  11. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  12. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +59 -0
  13. phoenix/server/api/dataloaders/experiment_repetition_counts.py +39 -0
  14. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  15. phoenix/server/api/helpers/playground_clients.py +4 -0
  16. phoenix/server/api/mutations/prompt_label_mutations.py +67 -58
  17. phoenix/server/api/queries.py +52 -37
  18. phoenix/server/api/routers/v1/documents.py +1 -1
  19. phoenix/server/api/routers/v1/evaluations.py +4 -4
  20. phoenix/server/api/routers/v1/experiment_runs.py +1 -1
  21. phoenix/server/api/routers/v1/experiments.py +1 -1
  22. phoenix/server/api/routers/v1/spans.py +2 -2
  23. phoenix/server/api/routers/v1/traces.py +18 -3
  24. phoenix/server/api/types/DatasetExample.py +49 -1
  25. phoenix/server/api/types/Experiment.py +12 -2
  26. phoenix/server/api/types/ExperimentComparison.py +3 -9
  27. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +146 -0
  28. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  29. phoenix/server/api/types/ExperimentRun.py +12 -19
  30. phoenix/server/api/types/Prompt.py +11 -0
  31. phoenix/server/api/types/PromptLabel.py +2 -19
  32. phoenix/server/api/types/node.py +10 -0
  33. phoenix/server/app.py +78 -20
  34. phoenix/server/cost_tracking/model_cost_manifest.json +1 -1
  35. phoenix/server/daemons/span_cost_calculator.py +10 -8
  36. phoenix/server/grpc_server.py +9 -9
  37. phoenix/server/prometheus.py +30 -6
  38. phoenix/server/static/.vite/manifest.json +43 -43
  39. phoenix/server/static/assets/components-CdQiQTvs.js +5778 -0
  40. phoenix/server/static/assets/{index-D1FDMBMV.js → index-B1VuXYRI.js} +12 -21
  41. phoenix/server/static/assets/pages-CnfZ3RhB.js +9163 -0
  42. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  43. phoenix/server/static/assets/vendor-Cfrr9FCF.js +903 -0
  44. phoenix/server/static/assets/{vendor-arizeai-DsYDNOqt.js → vendor-arizeai-Dz0kN-lQ.js} +4 -4
  45. phoenix/server/static/assets/vendor-codemirror-ClqtONZQ.js +25 -0
  46. phoenix/server/static/assets/{vendor-recharts-BTHn5Y2R.js → vendor-recharts-D6kvOpmb.js} +2 -2
  47. phoenix/server/static/assets/{vendor-shiki-BAcocHFl.js → vendor-shiki-xSOiKxt0.js} +1 -1
  48. phoenix/session/client.py +55 -1
  49. phoenix/session/data_extractor.py +5 -0
  50. phoenix/session/evaluation.py +8 -4
  51. phoenix/session/session.py +13 -0
  52. phoenix/trace/projects.py +1 -2
  53. phoenix/version.py +1 -1
  54. phoenix/server/static/assets/components-Cs9c4Nxp.js +0 -5698
  55. phoenix/server/static/assets/pages-Cbj9SjBx.js +0 -8928
  56. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  57. phoenix/server/static/assets/vendor-RdRDaQiR.js +0 -905
  58. phoenix/server/static/assets/vendor-codemirror-BzJDUbEx.js +0 -25
  59. phoenix/utilities/deprecation.py +0 -31
  60. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/WHEEL +0 -0
  61. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/entry_points.txt +0 -0
  62. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/IP_NOTICE +0 -0
  63. {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/LICENSE +0 -0
from dataclasses import dataclass
from typing import Optional

from sqlalchemy import func, select, tuple_
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentID: TypeAlias = int
DatasetExampleID: TypeAlias = int
AnnotationName: TypeAlias = str
MeanAnnotationScore: TypeAlias = float


@dataclass
class AnnotationSummary:
    # Name of the experiment-run annotation.
    annotation_name: AnnotationName
    # AVG(score) over the group's runs; None when no run in the group has a score.
    mean_score: Optional[MeanAnnotationScore]


Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
Result: TypeAlias = list[AnnotationSummary]


class ExperimentRepeatedRunGroupAnnotationSummariesDataLoader(DataLoader[Key, Result]):
    """Batch-loads annotation score summaries for repeated-run groups.

    A repeated-run group is the set of experiment runs sharing one
    (experiment_id, dataset_example_id) pair. For each requested key this
    loader returns one ``AnnotationSummary`` per distinct annotation name,
    sorted by annotation name; keys with no annotations yield an empty list.
    """

    def __init__(
        self,
        db: DbSessionFactory,
    ) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(
                models.ExperimentRun.experiment_id.label("experiment_id"),
                models.ExperimentRun.dataset_example_id.label("dataset_example_id"),
                models.ExperimentRunAnnotation.name.label("annotation_name"),
                func.avg(models.ExperimentRunAnnotation.score).label("mean_score"),
            )
            .select_from(models.ExperimentRunAnnotation)
            .join(
                models.ExperimentRun,
                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
            )
            .where(
                # De-duplicate the requested keys before the IN filter; the
                # rows are fanned back out per requested key below.
                tuple_(
                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
                ).in_(set(keys))
            )
            .group_by(
                models.ExperimentRun.experiment_id,
                models.ExperimentRun.dataset_example_id,
                models.ExperimentRunAnnotation.name,
            )
        )
        async with self._db() as session:
            rows = (await session.execute(stmt)).all()
        grouped: dict[Key, list[AnnotationSummary]] = {}
        for row in rows:
            grouped.setdefault((row.experiment_id, row.dataset_example_id), []).append(
                AnnotationSummary(
                    annotation_name=row.annotation_name,
                    mean_score=row.mean_score,
                )
            )
        return [
            sorted(grouped.get(key, []), key=lambda s: s.annotation_name)
            for key in keys
        ]
from dataclasses import dataclass

from sqlalchemy import select, tuple_
from sqlalchemy.orm import joinedload
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentID: TypeAlias = int
DatasetExampleID: TypeAlias = int
Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]


@dataclass
class ExperimentRepeatedRunGroup:
    # Row ids identifying the group.
    experiment_rowid: int
    dataset_example_rowid: int
    # Runs ordered by repetition_number; empty when the pair has no runs.
    runs: list[models.ExperimentRun]


Result: TypeAlias = ExperimentRepeatedRunGroup


class ExperimentRepeatedRunGroupsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the experiment runs for each (experiment, dataset example) pair.

    Every requested key yields an ``ExperimentRepeatedRunGroup``; pairs with
    no runs get an empty ``runs`` list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(models.ExperimentRun)
            .where(
                tuple_(
                    models.ExperimentRun.experiment_id,
                    models.ExperimentRun.dataset_example_id,
                ).in_(set(keys))
            )
            .order_by(models.ExperimentRun.repetition_number)
            # Eagerly fetch just the trace_id of the associated trace so that
            # resolvers don't trigger lazy loads after the session is closed.
            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
        )

        grouped: dict[Key, list[models.ExperimentRun]] = {}
        async with self._db() as session:
            for run in await session.scalars(stmt):
                grouped.setdefault((run.experiment_id, run.dataset_example_id), []).append(run)

        return [
            ExperimentRepeatedRunGroup(
                experiment_rowid=exp_id,
                dataset_example_rowid=example_id,
                runs=grouped.get((exp_id, example_id), []),
            )
            for exp_id, example_id in keys
        ]
from sqlalchemy import func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentID: TypeAlias = int
RepetitionCount: TypeAlias = int
Key: TypeAlias = ExperimentID
Result: TypeAlias = RepetitionCount


class ExperimentRepetitionCountsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the repetition count of each experiment.

    The count is MAX(repetition_number) over the experiment's runs;
    experiments with no runs yield 0.
    """

    def __init__(
        self,
        db: DbSessionFactory,
    ) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(
                models.ExperimentRun.experiment_id,
                func.max(models.ExperimentRun.repetition_number).label("repetition_count"),
            )
            .where(models.ExperimentRun.experiment_id.in_(keys))
            .group_by(models.ExperimentRun.experiment_id)
        )
        async with self._db() as session:
            # Each row is an (experiment_id, repetition_count) pair.
            counts = dict((await session.execute(stmt)).all())
        return [counts.get(experiment_id, 0) for experiment_id in keys]
from sqlalchemy import func, select, tuple_
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
from phoenix.server.types import DbSessionFactory

ExperimentId: TypeAlias = int
DatasetExampleId: TypeAlias = int
Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(DataLoader[Key, Result]):
    """Batch-loads aggregated span costs per (experiment, dataset example) group.

    Costs and token counts are summed over the ``SpanCost`` rows of the
    traces joined to the group's experiment runs. Keys with no matching
    rows yield an empty ``SpanCostSummary``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(
                models.ExperimentRun.experiment_id,
                models.ExperimentRun.dataset_example_id,
                func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
                func.sum(models.SpanCost.completion_cost).label("completion_cost"),
                func.sum(models.SpanCost.total_cost).label("total_cost"),
                func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
                func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
                func.sum(models.SpanCost.total_tokens).label("total_tokens"),
            )
            .select_from(models.ExperimentRun)
            .join(models.Trace, models.ExperimentRun.trace_id == models.Trace.trace_id)
            .join(models.SpanCost, models.SpanCost.trace_rowid == models.Trace.id)
            .where(
                tuple_(
                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
                ).in_(set(keys))
            )
            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id)
        )

        summaries: dict[Key, Result] = {}
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row in rows:
                (
                    experiment_id,
                    example_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[(experiment_id, example_id)] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries.get(key, SpanCostSummary()) for key in keys]
@@ -1669,7 +1669,11 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
1669
1669
  provider_key=GenerativeProviderKey.GOOGLE,
1670
1670
  model_names=[
1671
1671
  PROVIDER_DEFAULT,
1672
+ "gemini-2.5-flash",
1673
+ "gemini-2.5-flash-lite",
1674
+ "gemini-2.5-pro",
1672
1675
  "gemini-2.5-pro-preview-03-25",
1676
+ "gemini-2.0-flash",
1673
1677
  "gemini-2.0-flash-lite",
1674
1678
  "gemini-2.0-flash-001",
1675
1679
  "gemini-2.0-flash-thinking-exp-01-21",
@@ -10,12 +10,10 @@ from strawberry.relay import GlobalID
10
10
  from strawberry.types import Info
11
11
 
12
12
  from phoenix.db import models
13
- from phoenix.db.types.identifier import Identifier as IdentifierModel
14
13
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly
15
14
  from phoenix.server.api.context import Context
16
15
  from phoenix.server.api.exceptions import Conflict, NotFound
17
16
  from phoenix.server.api.queries import Query
18
- from phoenix.server.api.types.Identifier import Identifier
19
17
  from phoenix.server.api.types.node import from_global_id_with_expected_type
20
18
  from phoenix.server.api.types.Prompt import Prompt
21
19
  from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
@@ -23,37 +21,49 @@ from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_labe
23
21
 
24
22
  @strawberry.input
25
23
  class CreatePromptLabelInput:
26
- name: Identifier
24
+ name: str
27
25
  description: Optional[str] = None
26
+ color: str
28
27
 
29
28
 
30
29
  @strawberry.input
31
30
  class PatchPromptLabelInput:
32
31
  prompt_label_id: GlobalID
33
- name: Optional[Identifier] = None
32
+ name: Optional[str] = None
34
33
  description: Optional[str] = None
35
34
 
36
35
 
37
36
  @strawberry.input
38
- class DeletePromptLabelInput:
39
- prompt_label_id: GlobalID
37
+ class DeletePromptLabelsInput:
38
+ prompt_label_ids: list[GlobalID]
40
39
 
41
40
 
42
41
  @strawberry.input
43
- class SetPromptLabelInput:
42
+ class SetPromptLabelsInput:
44
43
  prompt_id: GlobalID
45
- prompt_label_id: GlobalID
44
+ prompt_label_ids: list[GlobalID]
46
45
 
47
46
 
48
47
  @strawberry.input
49
- class UnsetPromptLabelInput:
48
+ class UnsetPromptLabelsInput:
50
49
  prompt_id: GlobalID
51
- prompt_label_id: GlobalID
50
+ prompt_label_ids: list[GlobalID]
52
51
 
53
52
 
54
53
  @strawberry.type
55
54
  class PromptLabelMutationPayload:
56
- prompt_label: Optional["PromptLabel"]
55
+ prompt_labels: list["PromptLabel"]
56
+ query: "Query"
57
+
58
+
59
+ @strawberry.type
60
+ class PromptLabelDeleteMutationPayload:
61
+ deleted_prompt_label_ids: list["GlobalID"]
62
+ query: "Query"
63
+
64
+
65
+ @strawberry.type
66
+ class PromptLabelAssociationMutationPayload:
57
67
  query: "Query"
58
68
 
59
69
 
@@ -64,17 +74,18 @@ class PromptLabelMutationMixin:
64
74
  self, info: Info[Context, None], input: CreatePromptLabelInput
65
75
  ) -> PromptLabelMutationPayload:
66
76
  async with info.context.db() as session:
67
- name = IdentifierModel.model_validate(str(input.name))
68
- label_orm = models.PromptLabel(name=name, description=input.description)
77
+ label_orm = models.PromptLabel(
78
+ name=input.name, description=input.description, color=input.color
79
+ )
69
80
  session.add(label_orm)
70
81
 
71
82
  try:
72
83
  await session.commit()
73
84
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
74
- raise Conflict(f"A prompt label named '{name}' already exists.")
85
+ raise Conflict(f"A prompt label named '{input.name}' already exists.")
75
86
 
76
87
  return PromptLabelMutationPayload(
77
- prompt_label=to_gql_prompt_label(label_orm),
88
+ prompt_labels=[to_gql_prompt_label(label_orm)],
78
89
  query=Query(),
79
90
  )
80
91
 
@@ -82,7 +93,6 @@ class PromptLabelMutationMixin:
82
93
  async def patch_prompt_label(
83
94
  self, info: Info[Context, None], input: PatchPromptLabelInput
84
95
  ) -> PromptLabelMutationPayload:
85
- validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
86
96
  async with info.context.db() as session:
87
97
  label_id = from_global_id_with_expected_type(
88
98
  input.prompt_label_id, PromptLabel.__name__
@@ -92,8 +102,8 @@ class PromptLabelMutationMixin:
92
102
  if not label_orm:
93
103
  raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
94
104
 
95
- if validated_name is not None:
96
- label_orm.name = validated_name.root
105
+ if input.name is not None:
106
+ label_orm.name = input.name
97
107
  if input.description is not None:
98
108
  label_orm.description = input.description
99
109
 
@@ -103,46 +113,48 @@ class PromptLabelMutationMixin:
103
113
  raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
104
114
 
105
115
  return PromptLabelMutationPayload(
106
- prompt_label=to_gql_prompt_label(label_orm),
116
+ prompt_labels=[to_gql_prompt_label(label_orm)],
107
117
  query=Query(),
108
118
  )
109
119
 
110
120
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
111
- async def delete_prompt_label(
112
- self, info: Info[Context, None], input: DeletePromptLabelInput
113
- ) -> PromptLabelMutationPayload:
121
+ async def delete_prompt_labels(
122
+ self, info: Info[Context, None], input: DeletePromptLabelsInput
123
+ ) -> PromptLabelDeleteMutationPayload:
114
124
  """
115
125
  Deletes a PromptLabel (and any crosswalk references).
116
126
  """
117
127
  async with info.context.db() as session:
118
- label_id = from_global_id_with_expected_type(
119
- input.prompt_label_id, PromptLabel.__name__
120
- )
121
- stmt = delete(models.PromptLabel).where(models.PromptLabel.id == label_id)
122
- result = await session.execute(stmt)
123
-
124
- if result.rowcount == 0:
125
- raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
128
+ label_ids = [
129
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
130
+ for prompt_label_id in input.prompt_label_ids
131
+ ]
132
+ stmt = delete(models.PromptLabel).where(models.PromptLabel.id.in_(label_ids))
133
+ await session.execute(stmt)
126
134
 
127
135
  await session.commit()
128
136
 
129
- return PromptLabelMutationPayload(
130
- prompt_label=None,
137
+ return PromptLabelDeleteMutationPayload(
138
+ deleted_prompt_label_ids=input.prompt_label_ids,
131
139
  query=Query(),
132
140
  )
133
141
 
134
142
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
135
- async def set_prompt_label(
136
- self, info: Info[Context, None], input: SetPromptLabelInput
137
- ) -> PromptLabelMutationPayload:
143
+ async def set_prompt_labels(
144
+ self, info: Info[Context, None], input: SetPromptLabelsInput
145
+ ) -> PromptLabelAssociationMutationPayload:
138
146
  async with info.context.db() as session:
139
147
  prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
140
- label_id = from_global_id_with_expected_type(
141
- input.prompt_label_id, PromptLabel.__name__
142
- )
148
+ label_ids = [
149
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
150
+ for prompt_label_id in input.prompt_label_ids
151
+ ]
143
152
 
144
- crosswalk = models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
145
- session.add(crosswalk)
153
+ crosswalk_items = [
154
+ models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
155
+ for label_id in label_ids
156
+ ]
157
+ session.add_all(crosswalk_items)
146
158
 
147
159
  try:
148
160
  await session.commit()
@@ -152,41 +164,38 @@ class PromptLabelMutationMixin:
152
164
  # - Foreign key violation => prompt_id or label_id doesn't exist
153
165
  raise Conflict("Failed to associate PromptLabel with Prompt.") from e
154
166
 
155
- label_orm = await session.get(models.PromptLabel, label_id)
156
- if not label_orm:
157
- raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
158
-
159
- return PromptLabelMutationPayload(
160
- prompt_label=to_gql_prompt_label(label_orm),
167
+ return PromptLabelAssociationMutationPayload(
161
168
  query=Query(),
162
169
  )
163
170
 
164
171
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
165
- async def unset_prompt_label(
166
- self, info: Info[Context, None], input: UnsetPromptLabelInput
167
- ) -> PromptLabelMutationPayload:
172
+ async def unset_prompt_labels(
173
+ self, info: Info[Context, None], input: UnsetPromptLabelsInput
174
+ ) -> PromptLabelAssociationMutationPayload:
168
175
  """
169
176
  Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
170
177
  """
171
178
  async with info.context.db() as session:
172
179
  prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
173
- label_id = from_global_id_with_expected_type(
174
- input.prompt_label_id, PromptLabel.__name__
175
- )
180
+ label_ids = [
181
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
182
+ for prompt_label_id in input.prompt_label_ids
183
+ ]
176
184
 
177
185
  stmt = delete(models.PromptPromptLabel).where(
178
186
  (models.PromptPromptLabel.prompt_id == prompt_id)
179
- & (models.PromptPromptLabel.prompt_label_id == label_id)
187
+ & (models.PromptPromptLabel.prompt_label_id.in_(label_ids))
180
188
  )
181
189
  result = await session.execute(stmt)
182
190
 
183
- if result.rowcount == 0:
184
- raise NotFound(f"No association between prompt={prompt_id} and label={label_id}.")
191
+ if result.rowcount != len(label_ids):
192
+ label_ids_str = ", ".join(str(i) for i in label_ids)
193
+ raise NotFound(
194
+ f"No association between prompt={prompt_id} and labels={label_ids_str}."
195
+ )
185
196
 
186
197
  await session.commit()
187
198
 
188
- label_orm = await session.get(models.PromptLabel, label_id)
189
- return PromptLabelMutationPayload(
190
- prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
199
+ return PromptLabelAssociationMutationPayload(
191
200
  query=Query(),
192
201
  )
@@ -56,15 +56,25 @@ from phoenix.server.api.types.EmbeddingDimension import (
56
56
  to_gql_embedding_dimension,
57
57
  )
58
58
  from phoenix.server.api.types.Event import create_event_id, unpack_event_id
59
- from phoenix.server.api.types.Experiment import Experiment
60
- from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
59
+ from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
60
+ from phoenix.server.api.types.ExperimentComparison import (
61
+ ExperimentComparison,
62
+ )
63
+ from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
64
+ ExperimentRepeatedRunGroup,
65
+ parse_experiment_repeated_run_group_node_id,
66
+ )
61
67
  from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
62
68
  from phoenix.server.api.types.Functionality import Functionality
63
69
  from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
64
70
  from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
65
71
  from phoenix.server.api.types.InferenceModel import InferenceModel
66
72
  from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
67
- from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
73
+ from phoenix.server.api.types.node import (
74
+ from_global_id,
75
+ from_global_id_with_expected_type,
76
+ is_global_id,
77
+ )
68
78
  from phoenix.server.api.types.pagination import (
69
79
  ConnectionArgs,
70
80
  Cursor,
@@ -513,11 +523,12 @@ class Query:
513
523
 
514
524
  cursors_and_nodes = []
515
525
  for example in examples:
516
- run_comparison_items = []
526
+ repeated_run_groups = []
517
527
  for experiment_id in experiment_rowids:
518
- run_comparison_items.append(
519
- RunComparisonItem(
520
- experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
528
+ repeated_run_groups.append(
529
+ ExperimentRepeatedRunGroup(
530
+ experiment_rowid=experiment_id,
531
+ dataset_example_rowid=example.id,
521
532
  runs=[
522
533
  to_gql_experiment_run(run)
523
534
  for run in sorted(
@@ -533,7 +544,7 @@ class Query:
533
544
  created_at=example.created_at,
534
545
  version_id=base_experiment.dataset_version_id,
535
546
  ),
536
- run_comparison_items=run_comparison_items,
547
+ repeated_run_groups=repeated_run_groups,
537
548
  )
538
549
  cursors_and_nodes.append((Cursor(rowid=example.id), experiment_comparison))
539
550
 
@@ -863,8 +874,37 @@ class Query:
863
874
  return InferenceModel()
864
875
 
865
876
  @strawberry.field
866
- async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
867
- type_name, node_id = from_global_id(id)
877
+ async def node(self, id: strawberry.ID, info: Info[Context, None]) -> Node:
878
+ if not is_global_id(id):
879
+ try:
880
+ experiment_rowid, dataset_example_rowid = (
881
+ parse_experiment_repeated_run_group_node_id(id)
882
+ )
883
+ except Exception:
884
+ raise NotFound(f"Unknown node: {id}")
885
+
886
+ async with info.context.db() as session:
887
+ runs = (
888
+ await session.scalars(
889
+ select(models.ExperimentRun)
890
+ .where(models.ExperimentRun.experiment_id == experiment_rowid)
891
+ .where(models.ExperimentRun.dataset_example_id == dataset_example_rowid)
892
+ .order_by(models.ExperimentRun.repetition_number.asc())
893
+ .options(
894
+ joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
895
+ )
896
+ )
897
+ ).all()
898
+ if not runs:
899
+ raise NotFound(f"Unknown experiment or dataset example: {id}")
900
+ return ExperimentRepeatedRunGroup(
901
+ experiment_rowid=experiment_rowid,
902
+ dataset_example_rowid=dataset_example_rowid,
903
+ runs=[to_gql_experiment_run(run) for run in runs],
904
+ )
905
+
906
+ global_id = GlobalID.from_id(id)
907
+ type_name, node_id = from_global_id(global_id)
868
908
  if type_name == "Dimension":
869
909
  dimension = info.context.model.scalar_dimensions[node_id]
870
910
  return to_gql_dimension(node_id, dimension)
@@ -909,26 +949,9 @@ class Query:
909
949
  return to_gql_dataset(dataset)
910
950
  elif type_name == DatasetExample.__name__:
911
951
  example_id = node_id
912
- latest_revision_id = (
913
- select(func.max(models.DatasetExampleRevision.id))
914
- .where(models.DatasetExampleRevision.dataset_example_id == example_id)
915
- .scalar_subquery()
916
- )
917
952
  async with info.context.db() as session:
918
953
  example = await session.scalar(
919
- select(models.DatasetExample)
920
- .join(
921
- models.DatasetExampleRevision,
922
- onclause=models.DatasetExampleRevision.dataset_example_id
923
- == models.DatasetExample.id,
924
- )
925
- .where(
926
- and_(
927
- models.DatasetExample.id == example_id,
928
- models.DatasetExampleRevision.id == latest_revision_id,
929
- models.DatasetExampleRevision.revision_kind != "DELETE",
930
- )
931
- )
954
+ select(models.DatasetExample).where(models.DatasetExample.id == example_id)
932
955
  )
933
956
  if not example:
934
957
  raise NotFound(f"Unknown dataset example: {id}")
@@ -943,15 +966,7 @@ class Query:
943
966
  )
944
967
  if not experiment:
945
968
  raise NotFound(f"Unknown experiment: {id}")
946
- return Experiment(
947
- id_attr=experiment.id,
948
- name=experiment.name,
949
- project_name=experiment.project_name,
950
- description=experiment.description,
951
- created_at=experiment.created_at,
952
- updated_at=experiment.updated_at,
953
- metadata=experiment.metadata_,
954
- )
969
+ return to_gql_experiment(experiment)
955
970
  elif type_name == ExperimentRun.__name__:
956
971
  async with info.context.db() as session:
957
972
  if not (
@@ -82,7 +82,7 @@ async def annotate_span_documents(
82
82
  annotation.as_precursor(user_id=user_id) for annotation in span_document_annotations
83
83
  ]
84
84
  if not sync:
85
- await request.state.enqueue(*precursors)
85
+ await request.state.enqueue_annotations(*precursors)
86
86
  return AnnotateSpanDocumentsResponseBody(data=[])
87
87
 
88
88
  span_ids = {p.span_id for p in precursors}