arize-phoenix 12.8.0__py3-none-any.whl → 12.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the registry's advisory page for this release for more details.

Files changed (70):
  1. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +65 -210
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/static/.vite/manifest.json +9 -9
  63. phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
  64. phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
  65. phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
  66. phoenix/version.py +1 -1
  67. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  68. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  69. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  70. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/DocumentAnnotation.py
@@ -1,92 +1,212 @@
1
1
  from datetime import datetime
2
+ from math import isfinite
2
3
  from typing import TYPE_CHECKING, Annotated, Optional
3
4
 
4
5
  import strawberry
5
- from strawberry import Private
6
6
  from strawberry.relay import Node, NodeID
7
7
  from strawberry.scalars import JSON
8
8
  from strawberry.types import Info
9
9
 
10
10
  from phoenix.db import models
11
11
  from phoenix.server.api.context import Context
12
- from phoenix.server.api.interceptor import GqlValueMediator
13
12
 
14
13
  from .Annotation import Annotation
15
14
  from .AnnotationSource import AnnotationSource
16
15
  from .AnnotatorKind import AnnotatorKind
17
- from .User import User, to_gql_user
18
16
 
19
17
  if TYPE_CHECKING:
20
18
  from .Span import Span
19
+ from .User import User
21
20
 
22
21
 
23
22
  @strawberry.type
24
23
  class DocumentAnnotation(Node, Annotation):
25
- id_attr: NodeID[int]
26
- user_id: Private[Optional[int]]
27
- name: str = strawberry.field(
28
- description="Name of the annotation, e.g. 'helpfulness' or 'relevance'."
29
- )
30
- annotator_kind: AnnotatorKind
31
- label: Optional[str] = strawberry.field(
24
+ id: NodeID[int]
25
+ db_record: strawberry.Private[Optional[models.DocumentAnnotation]] = None
26
+
27
+ def __post_init__(self) -> None:
28
+ if self.db_record and self.id != self.db_record.id:
29
+ raise ValueError("DocumentAnnotation ID mismatch")
30
+
31
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
32
+ async def name(
33
+ self,
34
+ info: Info[Context, None],
35
+ ) -> str:
36
+ if self.db_record:
37
+ val = self.db_record.name
38
+ else:
39
+ val = await info.context.data_loaders.document_annotation_fields.load(
40
+ (self.id, models.DocumentAnnotation.name),
41
+ )
42
+ return val
43
+
44
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
45
+ async def annotator_kind(
46
+ self,
47
+ info: Info[Context, None],
48
+ ) -> AnnotatorKind:
49
+ if self.db_record:
50
+ val = self.db_record.annotator_kind
51
+ else:
52
+ val = await info.context.data_loaders.document_annotation_fields.load(
53
+ (self.id, models.DocumentAnnotation.annotator_kind),
54
+ )
55
+ return AnnotatorKind(val)
56
+
57
+ @strawberry.field(
32
58
  description="Value of the annotation in the form of a string, e.g. "
33
59
  "'helpful' or 'not helpful'. Note that the label is not necessarily binary."
34
- )
35
- score: Optional[float] = strawberry.field(
60
+ ) # type: ignore
61
+ async def label(
62
+ self,
63
+ info: Info[Context, None],
64
+ ) -> Optional[str]:
65
+ if self.db_record:
66
+ val = self.db_record.label
67
+ else:
68
+ val = await info.context.data_loaders.document_annotation_fields.load(
69
+ (self.id, models.DocumentAnnotation.label),
70
+ )
71
+ return val
72
+
73
+ @strawberry.field(
36
74
  description="Value of the annotation in the form of a numeric score.",
37
- default=GqlValueMediator(),
38
- )
39
- explanation: Optional[str] = strawberry.field(
75
+ ) # type: ignore
76
+ async def score(
77
+ self,
78
+ info: Info[Context, None],
79
+ ) -> Optional[float]:
80
+ if self.db_record:
81
+ val = self.db_record.score
82
+ else:
83
+ val = await info.context.data_loaders.document_annotation_fields.load(
84
+ (self.id, models.DocumentAnnotation.score),
85
+ )
86
+ return val if val is not None and isfinite(val) else None
87
+
88
+ @strawberry.field(
40
89
  description="The annotator's explanation for the annotation result (i.e. "
41
90
  "score or label, or both) given to the subject."
42
- )
43
- metadata: JSON
44
- document_position: int
45
- span_rowid: Private[int]
46
- identifier: str
47
- source: AnnotationSource
48
- created_at: datetime = strawberry.field(
49
- description="The date and time when the annotation was created."
50
- )
51
- updated_at: datetime = strawberry.field(
52
- description="The date and time when the annotation was last updated."
53
- )
54
-
55
- @strawberry.field
56
- async def span(self) -> Annotated["Span", strawberry.lazy(".Span")]:
57
- from phoenix.server.api.types.Span import Span
58
-
59
- return Span(span_rowid=self.span_rowid)
60
-
61
- @strawberry.field
91
+ ) # type: ignore
92
+ async def explanation(
93
+ self,
94
+ info: Info[Context, None],
95
+ ) -> Optional[str]:
96
+ if self.db_record:
97
+ val = self.db_record.explanation
98
+ else:
99
+ val = await info.context.data_loaders.document_annotation_fields.load(
100
+ (self.id, models.DocumentAnnotation.explanation),
101
+ )
102
+ return val
103
+
104
+ @strawberry.field(description="The metadata associated with the annotation.") # type: ignore
105
+ async def metadata(
106
+ self,
107
+ info: Info[Context, None],
108
+ ) -> JSON:
109
+ if self.db_record:
110
+ val = self.db_record.metadata_
111
+ else:
112
+ val = await info.context.data_loaders.document_annotation_fields.load(
113
+ (self.id, models.DocumentAnnotation.metadata_),
114
+ )
115
+ return val
116
+
117
+ @strawberry.field(description="The position of the annotation in the document.") # type: ignore
118
+ async def document_position(
119
+ self,
120
+ info: Info[Context, None],
121
+ ) -> int:
122
+ if self.db_record:
123
+ val = self.db_record.document_position
124
+ else:
125
+ val = await info.context.data_loaders.document_annotation_fields.load(
126
+ (self.id, models.DocumentAnnotation.document_position),
127
+ )
128
+ return val
129
+
130
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
131
+ async def identifier(
132
+ self,
133
+ info: Info[Context, None],
134
+ ) -> str:
135
+ if self.db_record:
136
+ val = self.db_record.identifier
137
+ else:
138
+ val = await info.context.data_loaders.document_annotation_fields.load(
139
+ (self.id, models.DocumentAnnotation.identifier),
140
+ )
141
+ return val
142
+
143
+ @strawberry.field(description="The source of the annotation.") # type: ignore
144
+ async def source(
145
+ self,
146
+ info: Info[Context, None],
147
+ ) -> AnnotationSource:
148
+ if self.db_record:
149
+ val = self.db_record.source
150
+ else:
151
+ val = await info.context.data_loaders.document_annotation_fields.load(
152
+ (self.id, models.DocumentAnnotation.source),
153
+ )
154
+ return AnnotationSource(val)
155
+
156
+ @strawberry.field(description="The date and time when the annotation was created.") # type: ignore
157
+ async def created_at(
158
+ self,
159
+ info: Info[Context, None],
160
+ ) -> datetime:
161
+ if self.db_record:
162
+ val = self.db_record.created_at
163
+ else:
164
+ val = await info.context.data_loaders.document_annotation_fields.load(
165
+ (self.id, models.DocumentAnnotation.created_at),
166
+ )
167
+ return val
168
+
169
+ @strawberry.field(description="The date and time when the annotation was last updated.") # type: ignore
170
+ async def updated_at(
171
+ self,
172
+ info: Info[Context, None],
173
+ ) -> datetime:
174
+ if self.db_record:
175
+ val = self.db_record.updated_at
176
+ else:
177
+ val = await info.context.data_loaders.document_annotation_fields.load(
178
+ (self.id, models.DocumentAnnotation.updated_at),
179
+ )
180
+ return val
181
+
182
+ @strawberry.field(description="The span associated with the annotation.") # type: ignore
183
+ async def span(
184
+ self,
185
+ info: Info[Context, None],
186
+ ) -> Annotated["Span", strawberry.lazy(".Span")]:
187
+ if self.db_record:
188
+ span_rowid = self.db_record.span_rowid
189
+ else:
190
+ span_rowid = await info.context.data_loaders.document_annotation_fields.load(
191
+ (self.id, models.DocumentAnnotation.span_rowid),
192
+ )
193
+ from .Span import Span
194
+
195
+ return Span(id=span_rowid)
196
+
197
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
62
198
  async def user(
63
199
  self,
64
200
  info: Info[Context, None],
65
- ) -> Optional[User]:
66
- if self.user_id is None:
201
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
202
+ if self.db_record:
203
+ user_id = self.db_record.user_id
204
+ else:
205
+ user_id = await info.context.data_loaders.document_annotation_fields.load(
206
+ (self.id, models.DocumentAnnotation.user_id),
207
+ )
208
+ if user_id is None:
67
209
  return None
68
- user = await info.context.data_loaders.users.load(self.user_id)
69
- if user is None:
70
- return None
71
- return to_gql_user(user)
72
-
73
-
74
- def to_gql_document_annotation(
75
- annotation: models.DocumentAnnotation,
76
- ) -> DocumentAnnotation:
77
- return DocumentAnnotation(
78
- id_attr=annotation.id,
79
- user_id=annotation.user_id,
80
- name=annotation.name,
81
- annotator_kind=AnnotatorKind(annotation.annotator_kind),
82
- label=annotation.label,
83
- score=annotation.score,
84
- explanation=annotation.explanation,
85
- metadata=annotation.metadata_,
86
- span_rowid=annotation.span_rowid,
87
- source=AnnotationSource(annotation.source),
88
- identifier=annotation.identifier,
89
- document_position=annotation.document_position,
90
- created_at=annotation.created_at,
91
- updated_at=annotation.updated_at,
92
- )
210
+ from .User import User
211
+
212
+ return User(id=user_id)
phoenix/server/api/types/Experiment.py
@@ -1,9 +1,8 @@
1
1
  from datetime import datetime
2
- from typing import ClassVar, Optional
2
+ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
5
  from sqlalchemy import func, select
6
- from sqlalchemy.orm import joinedload
7
6
  from strawberry import UNSET, Private
8
7
  from strawberry.relay import Connection, GlobalID, Node, NodeID
9
8
  from strawberry.scalars import JSON
@@ -18,10 +17,10 @@ from phoenix.server.api.input_types.ExperimentRunSort import (
18
17
  get_experiment_run_cursor,
19
18
  )
20
19
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
21
- from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
20
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
22
21
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
23
22
  from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
24
- from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
23
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
25
24
  from phoenix.server.api.types.pagination import (
26
25
  ConnectionArgs,
27
26
  Cursor,
@@ -29,26 +28,128 @@ from phoenix.server.api.types.pagination import (
29
28
  connection_from_cursors_and_nodes,
30
29
  connection_from_list,
31
30
  )
32
- from phoenix.server.api.types.Project import Project
33
31
  from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
34
32
  from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
35
33
 
36
34
  _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE = 50
37
35
 
36
+ if TYPE_CHECKING:
37
+ from .Project import Project
38
+
38
39
 
39
40
  @strawberry.type
40
41
  class Experiment(Node):
41
- _table: ClassVar[type[models.Base]] = models.Experiment
42
+ id: NodeID[int]
43
+ db_record: strawberry.Private[Optional[models.Experiment]] = None
42
44
  cached_sequence_number: Private[Optional[int]] = None
43
- id_attr: NodeID[int]
44
- name: str
45
- project_name: Optional[str]
46
- description: Optional[str]
47
- repetitions: int
48
- dataset_version_id: GlobalID
49
- metadata: JSON
50
- created_at: datetime
51
- updated_at: datetime
45
+
46
+ def __post_init__(self) -> None:
47
+ if self.db_record and self.id != self.db_record.id:
48
+ raise ValueError("Experiment ID mismatch")
49
+
50
+ @strawberry.field
51
+ async def name(
52
+ self,
53
+ info: Info[Context, None],
54
+ ) -> str:
55
+ if self.db_record:
56
+ val = self.db_record.name
57
+ else:
58
+ val = await info.context.data_loaders.experiment_fields.load(
59
+ (self.id, models.Experiment.name),
60
+ )
61
+ return val
62
+
63
+ @strawberry.field
64
+ async def project_name(
65
+ self,
66
+ info: Info[Context, None],
67
+ ) -> Optional[str]:
68
+ if self.db_record:
69
+ val = self.db_record.project_name
70
+ else:
71
+ val = await info.context.data_loaders.experiment_fields.load(
72
+ (self.id, models.Experiment.project_name),
73
+ )
74
+ return val
75
+
76
+ @strawberry.field
77
+ async def description(
78
+ self,
79
+ info: Info[Context, None],
80
+ ) -> Optional[str]:
81
+ if self.db_record:
82
+ val = self.db_record.description
83
+ else:
84
+ val = await info.context.data_loaders.experiment_fields.load(
85
+ (self.id, models.Experiment.description),
86
+ )
87
+ return val
88
+
89
+ @strawberry.field
90
+ async def repetitions(
91
+ self,
92
+ info: Info[Context, None],
93
+ ) -> int:
94
+ if self.db_record:
95
+ val = self.db_record.repetitions
96
+ else:
97
+ val = await info.context.data_loaders.experiment_fields.load(
98
+ (self.id, models.Experiment.repetitions),
99
+ )
100
+ return val
101
+
102
+ @strawberry.field
103
+ async def dataset_version_id(
104
+ self,
105
+ info: Info[Context, None],
106
+ ) -> GlobalID:
107
+ if self.db_record:
108
+ version_id = self.db_record.dataset_version_id
109
+ else:
110
+ version_id = await info.context.data_loaders.experiment_fields.load(
111
+ (self.id, models.Experiment.dataset_version_id),
112
+ )
113
+ return GlobalID(DatasetVersion.__name__, str(version_id))
114
+
115
+ @strawberry.field
116
+ async def metadata(
117
+ self,
118
+ info: Info[Context, None],
119
+ ) -> JSON:
120
+ if self.db_record:
121
+ val = self.db_record.metadata_
122
+ else:
123
+ val = await info.context.data_loaders.experiment_fields.load(
124
+ (self.id, models.Experiment.metadata_),
125
+ )
126
+ return val
127
+
128
+ @strawberry.field
129
+ async def created_at(
130
+ self,
131
+ info: Info[Context, None],
132
+ ) -> datetime:
133
+ if self.db_record:
134
+ val = self.db_record.created_at
135
+ else:
136
+ val = await info.context.data_loaders.experiment_fields.load(
137
+ (self.id, models.Experiment.created_at),
138
+ )
139
+ return val
140
+
141
+ @strawberry.field
142
+ async def updated_at(
143
+ self,
144
+ info: Info[Context, None],
145
+ ) -> datetime:
146
+ if self.db_record:
147
+ val = self.db_record.updated_at
148
+ else:
149
+ val = await info.context.data_loaders.experiment_fields.load(
150
+ (self.id, models.Experiment.updated_at),
151
+ )
152
+ return val
52
153
 
53
154
  @strawberry.field(
54
155
  description="Sequence number (1-based) of experiments belonging to the same dataset"
@@ -58,9 +159,9 @@ class Experiment(Node):
58
159
  info: Info[Context, None],
59
160
  ) -> int:
60
161
  if self.cached_sequence_number is None:
61
- seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id_attr)
162
+ seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id)
62
163
  if seq_num is None:
63
- raise ValueError(f"invalid experiment: id={self.id_attr}")
164
+ raise ValueError(f"invalid experiment: id={self.id}")
64
165
  self.cached_sequence_number = seq_num
65
166
  return self.cached_sequence_number
66
167
 
@@ -74,12 +175,10 @@ class Experiment(Node):
74
175
  ) -> Connection[ExperimentRun]:
75
176
  if first is not None and first <= 0:
76
177
  raise BadRequest("first must be a positive integer if set")
77
- experiment_rowid = self.id_attr
78
178
  page_size = first or _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE
79
179
  experiment_runs_query = (
80
180
  select(models.ExperimentRun)
81
- .where(models.ExperimentRun.experiment_id == experiment_rowid)
82
- .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
181
+ .where(models.ExperimentRun.experiment_id == self.id)
83
182
  .limit(page_size + 1)
84
183
  )
85
184
 
@@ -94,7 +193,7 @@ class Experiment(Node):
94
193
  experiment_runs_query = add_order_by_and_page_start_to_query(
95
194
  query=experiment_runs_query,
96
195
  sort=sort,
97
- experiment_rowid=experiment_rowid,
196
+ experiment_rowid=self.id,
98
197
  after_experiment_run_rowid=after_experiment_run_rowid,
99
198
  after_sort_column_value=after_sort_column_value,
100
199
  )
@@ -111,7 +210,7 @@ class Experiment(Node):
111
210
  for result in results:
112
211
  run = result[0]
113
212
  annotation_score = result[1] if len(result) > 1 else None
114
- gql_run = to_gql_experiment_run(run)
213
+ gql_run = ExperimentRun(id=run.id, db_record=run)
115
214
  cursor = get_experiment_run_cursor(
116
215
  run=run, annotation_score=annotation_score, sort=sort
117
216
  )
@@ -125,14 +224,13 @@ class Experiment(Node):
125
224
 
126
225
  @strawberry.field
127
226
  async def run_count(self, info: Info[Context, None]) -> int:
128
- experiment_id = self.id_attr
129
- return await info.context.data_loaders.experiment_run_counts.load(experiment_id)
227
+ return await info.context.data_loaders.experiment_run_counts.load(self.id)
130
228
 
131
229
  @strawberry.field
132
230
  async def annotation_summaries(
133
231
  self, info: Info[Context, None]
134
232
  ) -> list[ExperimentAnnotationSummary]:
135
- experiment_id = self.id_attr
233
+ experiment_id = self.id
136
234
  return [
137
235
  ExperimentAnnotationSummary(
138
236
  annotation_name=summary.annotation_name,
@@ -149,40 +247,42 @@ class Experiment(Node):
149
247
 
150
248
  @strawberry.field
151
249
  async def error_rate(self, info: Info[Context, None]) -> Optional[float]:
152
- return await info.context.data_loaders.experiment_error_rates.load(self.id_attr)
250
+ return await info.context.data_loaders.experiment_error_rates.load(self.id)
153
251
 
154
252
  @strawberry.field
155
253
  async def average_run_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
156
- latency_ms = await info.context.data_loaders.average_experiment_run_latency.load(
157
- self.id_attr
158
- )
254
+ latency_ms = await info.context.data_loaders.average_experiment_run_latency.load(self.id)
159
255
  return latency_ms
160
256
 
161
257
  @strawberry.field
162
- async def project(self, info: Info[Context, None]) -> Optional[Project]:
163
- if self.project_name is None:
258
+ async def project(
259
+ self, info: Info[Context, None]
260
+ ) -> Optional[Annotated["Project", strawberry.lazy(".Project")]]:
261
+ if self.db_record:
262
+ project_name = self.db_record.project_name
263
+ else:
264
+ project_name = await info.context.data_loaders.experiment_fields.load(
265
+ (self.id, models.Experiment.project_name),
266
+ )
267
+
268
+ if project_name is None:
164
269
  return None
165
270
 
166
- db_project = await info.context.data_loaders.project_by_name.load(self.project_name)
271
+ db_project = await info.context.data_loaders.project_by_name.load(project_name)
167
272
 
168
273
  if db_project is None:
169
274
  return None
275
+ from .Project import Project
170
276
 
171
- return Project(
172
- project_rowid=db_project.id,
173
- db_project=db_project,
174
- )
277
+ return Project(id=db_project.id, db_record=db_project)
175
278
 
176
279
  @strawberry.field
177
280
  def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
178
- return info.context.last_updated_at.get(self._table, self.id_attr)
281
+ return info.context.last_updated_at.get(models.Experiment, self.id)
179
282
 
180
283
  @strawberry.field
181
284
  async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
182
- experiment_id = self.id_attr
183
- summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(
184
- experiment_id
185
- )
285
+ summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(self.id)
186
286
  return SpanCostSummary(
187
287
  prompt=CostBreakdown(
188
288
  tokens=summary.prompt.tokens,
@@ -202,8 +302,6 @@ class Experiment(Node):
202
302
  async def cost_detail_summary_entries(
203
303
  self, info: Info[Context, None]
204
304
  ) -> list[SpanCostDetailSummaryEntry]:
205
- experiment_id = self.id_attr
206
-
207
305
  stmt = (
208
306
  select(
209
307
  models.SpanCostDetail.token_type,
@@ -216,7 +314,7 @@ class Experiment(Node):
216
314
  .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
217
315
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
218
316
  .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
219
- .where(models.ExperimentRun.experiment_id == experiment_id)
317
+ .where(models.ExperimentRun.experiment_id == self.id)
220
318
  .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
221
319
  )
222
320
 
@@ -237,9 +335,9 @@ class Experiment(Node):
237
335
  info: Info[Context, None],
238
336
  ) -> Connection[DatasetSplit]:
239
337
  """Returns the dataset splits associated with this experiment."""
240
- splits = await info.context.data_loaders.experiment_dataset_splits.load(self.id_attr)
338
+ splits = await info.context.data_loaders.experiment_dataset_splits.load(self.id)
241
339
  return connection_from_list(
242
- [to_gql_dataset_split(split) for split in splits], ConnectionArgs()
340
+ [DatasetSplit(id=split.id, db_record=split) for split in splits], ConnectionArgs()
243
341
  )
244
342
 
245
343
 
@@ -251,14 +349,7 @@ def to_gql_experiment(
251
349
  Converts an ORM experiment to a GraphQL Experiment.
252
350
  """
253
351
  return Experiment(
352
+ id=experiment.id,
353
+ db_record=experiment,
254
354
  cached_sequence_number=sequence_number,
255
- id_attr=experiment.id,
256
- name=experiment.name,
257
- project_name=experiment.project_name,
258
- description=experiment.description,
259
- repetitions=experiment.repetitions,
260
- dataset_version_id=GlobalID(DatasetVersion.__name__, str(experiment.dataset_version_id)),
261
- metadata=experiment.metadata_,
262
- created_at=experiment.created_at,
263
- updated_at=experiment.updated_at,
264
355
  )
phoenix/server/api/types/ExperimentRepeatedRunGroup.py
@@ -26,7 +26,16 @@ DatasetExampleRowId: TypeAlias = int
26
26
  class ExperimentRepeatedRunGroup(Node):
27
27
  experiment_rowid: strawberry.Private[ExperimentRowId]
28
28
  dataset_example_rowid: strawberry.Private[DatasetExampleRowId]
29
- runs: list[ExperimentRun]
29
+ cached_runs: strawberry.Private[Optional[list[ExperimentRun]]] = None
30
+
31
+ @strawberry.field
32
+ async def runs(self, info: Info[Context, None]) -> list[ExperimentRun]:
33
+ if self.cached_runs is not None:
34
+ return self.cached_runs
35
+ runs = await info.context.data_loaders.experiment_runs_by_experiment_and_example.load(
36
+ (self.experiment_rowid, self.dataset_example_rowid)
37
+ )
38
+ return [ExperimentRun(id=run.id, db_record=run) for run in runs]
30
39
 
31
40
  @classmethod
32
41
  def resolve_id(