arize-phoenix 12.8.0__py3-none-any.whl → 12.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
- phoenix/config.py +131 -9
- phoenix/db/engines.py +127 -14
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/pg_config.py +10 -0
- phoenix/server/api/context.py +23 -0
- phoenix/server/api/dataloaders/__init__.py +6 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/helpers/playground_clients.py +3 -3
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
- phoenix/server/api/mutations/api_key_mutations.py +2 -15
- phoenix/server/api/mutations/chat_mutations.py +3 -2
- phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
- phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
- phoenix/server/api/mutations/prompt_mutations.py +24 -117
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
- phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/queries.py +65 -210
- phoenix/server/api/subscriptions.py +4 -4
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/Dataset.py +88 -48
- phoenix/server/api/types/DatasetExample.py +34 -30
- phoenix/server/api/types/DatasetLabel.py +47 -13
- phoenix/server/api/types/DatasetSplit.py +87 -21
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +182 -62
- phoenix/server/api/types/Experiment.py +146 -55
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
- phoenix/server/api/types/ExperimentRun.py +118 -61
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +70 -75
- phoenix/server/api/types/ProjectSession.py +69 -37
- phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +82 -44
- phoenix/server/api/types/PromptLabel.py +47 -13
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +116 -115
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +45 -44
- phoenix/server/api/types/TraceAnnotation.py +144 -48
- phoenix/server/api/types/User.py +103 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/app.py +29 -0
- phoenix/server/static/.vite/manifest.json +9 -9
- phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
- phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
- phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
- phoenix/version.py +1 -1
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,92 +1,212 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
+
from math import isfinite
|
|
2
3
|
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
4
|
|
|
4
5
|
import strawberry
|
|
5
|
-
from strawberry import Private
|
|
6
6
|
from strawberry.relay import Node, NodeID
|
|
7
7
|
from strawberry.scalars import JSON
|
|
8
8
|
from strawberry.types import Info
|
|
9
9
|
|
|
10
10
|
from phoenix.db import models
|
|
11
11
|
from phoenix.server.api.context import Context
|
|
12
|
-
from phoenix.server.api.interceptor import GqlValueMediator
|
|
13
12
|
|
|
14
13
|
from .Annotation import Annotation
|
|
15
14
|
from .AnnotationSource import AnnotationSource
|
|
16
15
|
from .AnnotatorKind import AnnotatorKind
|
|
17
|
-
from .User import User, to_gql_user
|
|
18
16
|
|
|
19
17
|
if TYPE_CHECKING:
|
|
20
18
|
from .Span import Span
|
|
19
|
+
from .User import User
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
@strawberry.type
|
|
24
23
|
class DocumentAnnotation(Node, Annotation):
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
24
|
+
id: NodeID[int]
|
|
25
|
+
db_record: strawberry.Private[Optional[models.DocumentAnnotation]] = None
|
|
26
|
+
|
|
27
|
+
def __post_init__(self) -> None:
|
|
28
|
+
if self.db_record and self.id != self.db_record.id:
|
|
29
|
+
raise ValueError("DocumentAnnotation ID mismatch")
|
|
30
|
+
|
|
31
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
32
|
+
async def name(
|
|
33
|
+
self,
|
|
34
|
+
info: Info[Context, None],
|
|
35
|
+
) -> str:
|
|
36
|
+
if self.db_record:
|
|
37
|
+
val = self.db_record.name
|
|
38
|
+
else:
|
|
39
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
40
|
+
(self.id, models.DocumentAnnotation.name),
|
|
41
|
+
)
|
|
42
|
+
return val
|
|
43
|
+
|
|
44
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
45
|
+
async def annotator_kind(
|
|
46
|
+
self,
|
|
47
|
+
info: Info[Context, None],
|
|
48
|
+
) -> AnnotatorKind:
|
|
49
|
+
if self.db_record:
|
|
50
|
+
val = self.db_record.annotator_kind
|
|
51
|
+
else:
|
|
52
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
53
|
+
(self.id, models.DocumentAnnotation.annotator_kind),
|
|
54
|
+
)
|
|
55
|
+
return AnnotatorKind(val)
|
|
56
|
+
|
|
57
|
+
@strawberry.field(
|
|
32
58
|
description="Value of the annotation in the form of a string, e.g. "
|
|
33
59
|
"'helpful' or 'not helpful'. Note that the label is not necessarily binary."
|
|
34
|
-
)
|
|
35
|
-
|
|
60
|
+
) # type: ignore
|
|
61
|
+
async def label(
|
|
62
|
+
self,
|
|
63
|
+
info: Info[Context, None],
|
|
64
|
+
) -> Optional[str]:
|
|
65
|
+
if self.db_record:
|
|
66
|
+
val = self.db_record.label
|
|
67
|
+
else:
|
|
68
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
69
|
+
(self.id, models.DocumentAnnotation.label),
|
|
70
|
+
)
|
|
71
|
+
return val
|
|
72
|
+
|
|
73
|
+
@strawberry.field(
|
|
36
74
|
description="Value of the annotation in the form of a numeric score.",
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
75
|
+
) # type: ignore
|
|
76
|
+
async def score(
|
|
77
|
+
self,
|
|
78
|
+
info: Info[Context, None],
|
|
79
|
+
) -> Optional[float]:
|
|
80
|
+
if self.db_record:
|
|
81
|
+
val = self.db_record.score
|
|
82
|
+
else:
|
|
83
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
84
|
+
(self.id, models.DocumentAnnotation.score),
|
|
85
|
+
)
|
|
86
|
+
return val if val is not None and isfinite(val) else None
|
|
87
|
+
|
|
88
|
+
@strawberry.field(
|
|
40
89
|
description="The annotator's explanation for the annotation result (i.e. "
|
|
41
90
|
"score or label, or both) given to the subject."
|
|
42
|
-
)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
@strawberry.field
|
|
56
|
-
async def
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
91
|
+
) # type: ignore
|
|
92
|
+
async def explanation(
|
|
93
|
+
self,
|
|
94
|
+
info: Info[Context, None],
|
|
95
|
+
) -> Optional[str]:
|
|
96
|
+
if self.db_record:
|
|
97
|
+
val = self.db_record.explanation
|
|
98
|
+
else:
|
|
99
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
100
|
+
(self.id, models.DocumentAnnotation.explanation),
|
|
101
|
+
)
|
|
102
|
+
return val
|
|
103
|
+
|
|
104
|
+
@strawberry.field(description="The metadata associated with the annotation.") # type: ignore
|
|
105
|
+
async def metadata(
|
|
106
|
+
self,
|
|
107
|
+
info: Info[Context, None],
|
|
108
|
+
) -> JSON:
|
|
109
|
+
if self.db_record:
|
|
110
|
+
val = self.db_record.metadata_
|
|
111
|
+
else:
|
|
112
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
113
|
+
(self.id, models.DocumentAnnotation.metadata_),
|
|
114
|
+
)
|
|
115
|
+
return val
|
|
116
|
+
|
|
117
|
+
@strawberry.field(description="The position of the annotation in the document.") # type: ignore
|
|
118
|
+
async def document_position(
|
|
119
|
+
self,
|
|
120
|
+
info: Info[Context, None],
|
|
121
|
+
) -> int:
|
|
122
|
+
if self.db_record:
|
|
123
|
+
val = self.db_record.document_position
|
|
124
|
+
else:
|
|
125
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
126
|
+
(self.id, models.DocumentAnnotation.document_position),
|
|
127
|
+
)
|
|
128
|
+
return val
|
|
129
|
+
|
|
130
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
131
|
+
async def identifier(
|
|
132
|
+
self,
|
|
133
|
+
info: Info[Context, None],
|
|
134
|
+
) -> str:
|
|
135
|
+
if self.db_record:
|
|
136
|
+
val = self.db_record.identifier
|
|
137
|
+
else:
|
|
138
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
139
|
+
(self.id, models.DocumentAnnotation.identifier),
|
|
140
|
+
)
|
|
141
|
+
return val
|
|
142
|
+
|
|
143
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
144
|
+
async def source(
|
|
145
|
+
self,
|
|
146
|
+
info: Info[Context, None],
|
|
147
|
+
) -> AnnotationSource:
|
|
148
|
+
if self.db_record:
|
|
149
|
+
val = self.db_record.source
|
|
150
|
+
else:
|
|
151
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
152
|
+
(self.id, models.DocumentAnnotation.source),
|
|
153
|
+
)
|
|
154
|
+
return AnnotationSource(val)
|
|
155
|
+
|
|
156
|
+
@strawberry.field(description="The date and time when the annotation was created.") # type: ignore
|
|
157
|
+
async def created_at(
|
|
158
|
+
self,
|
|
159
|
+
info: Info[Context, None],
|
|
160
|
+
) -> datetime:
|
|
161
|
+
if self.db_record:
|
|
162
|
+
val = self.db_record.created_at
|
|
163
|
+
else:
|
|
164
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
165
|
+
(self.id, models.DocumentAnnotation.created_at),
|
|
166
|
+
)
|
|
167
|
+
return val
|
|
168
|
+
|
|
169
|
+
@strawberry.field(description="The date and time when the annotation was last updated.") # type: ignore
|
|
170
|
+
async def updated_at(
|
|
171
|
+
self,
|
|
172
|
+
info: Info[Context, None],
|
|
173
|
+
) -> datetime:
|
|
174
|
+
if self.db_record:
|
|
175
|
+
val = self.db_record.updated_at
|
|
176
|
+
else:
|
|
177
|
+
val = await info.context.data_loaders.document_annotation_fields.load(
|
|
178
|
+
(self.id, models.DocumentAnnotation.updated_at),
|
|
179
|
+
)
|
|
180
|
+
return val
|
|
181
|
+
|
|
182
|
+
@strawberry.field(description="The span associated with the annotation.") # type: ignore
|
|
183
|
+
async def span(
|
|
184
|
+
self,
|
|
185
|
+
info: Info[Context, None],
|
|
186
|
+
) -> Annotated["Span", strawberry.lazy(".Span")]:
|
|
187
|
+
if self.db_record:
|
|
188
|
+
span_rowid = self.db_record.span_rowid
|
|
189
|
+
else:
|
|
190
|
+
span_rowid = await info.context.data_loaders.document_annotation_fields.load(
|
|
191
|
+
(self.id, models.DocumentAnnotation.span_rowid),
|
|
192
|
+
)
|
|
193
|
+
from .Span import Span
|
|
194
|
+
|
|
195
|
+
return Span(id=span_rowid)
|
|
196
|
+
|
|
197
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
62
198
|
async def user(
|
|
63
199
|
self,
|
|
64
200
|
info: Info[Context, None],
|
|
65
|
-
) -> Optional[User]:
|
|
66
|
-
if self.
|
|
201
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
202
|
+
if self.db_record:
|
|
203
|
+
user_id = self.db_record.user_id
|
|
204
|
+
else:
|
|
205
|
+
user_id = await info.context.data_loaders.document_annotation_fields.load(
|
|
206
|
+
(self.id, models.DocumentAnnotation.user_id),
|
|
207
|
+
)
|
|
208
|
+
if user_id is None:
|
|
67
209
|
return None
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
return to_gql_user(user)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def to_gql_document_annotation(
|
|
75
|
-
annotation: models.DocumentAnnotation,
|
|
76
|
-
) -> DocumentAnnotation:
|
|
77
|
-
return DocumentAnnotation(
|
|
78
|
-
id_attr=annotation.id,
|
|
79
|
-
user_id=annotation.user_id,
|
|
80
|
-
name=annotation.name,
|
|
81
|
-
annotator_kind=AnnotatorKind(annotation.annotator_kind),
|
|
82
|
-
label=annotation.label,
|
|
83
|
-
score=annotation.score,
|
|
84
|
-
explanation=annotation.explanation,
|
|
85
|
-
metadata=annotation.metadata_,
|
|
86
|
-
span_rowid=annotation.span_rowid,
|
|
87
|
-
source=AnnotationSource(annotation.source),
|
|
88
|
-
identifier=annotation.identifier,
|
|
89
|
-
document_position=annotation.document_position,
|
|
90
|
-
created_at=annotation.created_at,
|
|
91
|
-
updated_at=annotation.updated_at,
|
|
92
|
-
)
|
|
210
|
+
from .User import User
|
|
211
|
+
|
|
212
|
+
return User(id=user_id)
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from sqlalchemy import func, select
|
|
6
|
-
from sqlalchemy.orm import joinedload
|
|
7
6
|
from strawberry import UNSET, Private
|
|
8
7
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
9
8
|
from strawberry.scalars import JSON
|
|
@@ -18,10 +17,10 @@ from phoenix.server.api.input_types.ExperimentRunSort import (
|
|
|
18
17
|
get_experiment_run_cursor,
|
|
19
18
|
)
|
|
20
19
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
21
|
-
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
20
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
22
21
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
23
22
|
from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
|
|
24
|
-
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
23
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
25
24
|
from phoenix.server.api.types.pagination import (
|
|
26
25
|
ConnectionArgs,
|
|
27
26
|
Cursor,
|
|
@@ -29,26 +28,128 @@ from phoenix.server.api.types.pagination import (
|
|
|
29
28
|
connection_from_cursors_and_nodes,
|
|
30
29
|
connection_from_list,
|
|
31
30
|
)
|
|
32
|
-
from phoenix.server.api.types.Project import Project
|
|
33
31
|
from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
|
|
34
32
|
from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
|
|
35
33
|
|
|
36
34
|
_DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE = 50
|
|
37
35
|
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from .Project import Project
|
|
38
|
+
|
|
38
39
|
|
|
39
40
|
@strawberry.type
|
|
40
41
|
class Experiment(Node):
|
|
41
|
-
|
|
42
|
+
id: NodeID[int]
|
|
43
|
+
db_record: strawberry.Private[Optional[models.Experiment]] = None
|
|
42
44
|
cached_sequence_number: Private[Optional[int]] = None
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
45
|
+
|
|
46
|
+
def __post_init__(self) -> None:
|
|
47
|
+
if self.db_record and self.id != self.db_record.id:
|
|
48
|
+
raise ValueError("Experiment ID mismatch")
|
|
49
|
+
|
|
50
|
+
@strawberry.field
|
|
51
|
+
async def name(
|
|
52
|
+
self,
|
|
53
|
+
info: Info[Context, None],
|
|
54
|
+
) -> str:
|
|
55
|
+
if self.db_record:
|
|
56
|
+
val = self.db_record.name
|
|
57
|
+
else:
|
|
58
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
59
|
+
(self.id, models.Experiment.name),
|
|
60
|
+
)
|
|
61
|
+
return val
|
|
62
|
+
|
|
63
|
+
@strawberry.field
|
|
64
|
+
async def project_name(
|
|
65
|
+
self,
|
|
66
|
+
info: Info[Context, None],
|
|
67
|
+
) -> Optional[str]:
|
|
68
|
+
if self.db_record:
|
|
69
|
+
val = self.db_record.project_name
|
|
70
|
+
else:
|
|
71
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
72
|
+
(self.id, models.Experiment.project_name),
|
|
73
|
+
)
|
|
74
|
+
return val
|
|
75
|
+
|
|
76
|
+
@strawberry.field
|
|
77
|
+
async def description(
|
|
78
|
+
self,
|
|
79
|
+
info: Info[Context, None],
|
|
80
|
+
) -> Optional[str]:
|
|
81
|
+
if self.db_record:
|
|
82
|
+
val = self.db_record.description
|
|
83
|
+
else:
|
|
84
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
85
|
+
(self.id, models.Experiment.description),
|
|
86
|
+
)
|
|
87
|
+
return val
|
|
88
|
+
|
|
89
|
+
@strawberry.field
|
|
90
|
+
async def repetitions(
|
|
91
|
+
self,
|
|
92
|
+
info: Info[Context, None],
|
|
93
|
+
) -> int:
|
|
94
|
+
if self.db_record:
|
|
95
|
+
val = self.db_record.repetitions
|
|
96
|
+
else:
|
|
97
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
98
|
+
(self.id, models.Experiment.repetitions),
|
|
99
|
+
)
|
|
100
|
+
return val
|
|
101
|
+
|
|
102
|
+
@strawberry.field
|
|
103
|
+
async def dataset_version_id(
|
|
104
|
+
self,
|
|
105
|
+
info: Info[Context, None],
|
|
106
|
+
) -> GlobalID:
|
|
107
|
+
if self.db_record:
|
|
108
|
+
version_id = self.db_record.dataset_version_id
|
|
109
|
+
else:
|
|
110
|
+
version_id = await info.context.data_loaders.experiment_fields.load(
|
|
111
|
+
(self.id, models.Experiment.dataset_version_id),
|
|
112
|
+
)
|
|
113
|
+
return GlobalID(DatasetVersion.__name__, str(version_id))
|
|
114
|
+
|
|
115
|
+
@strawberry.field
|
|
116
|
+
async def metadata(
|
|
117
|
+
self,
|
|
118
|
+
info: Info[Context, None],
|
|
119
|
+
) -> JSON:
|
|
120
|
+
if self.db_record:
|
|
121
|
+
val = self.db_record.metadata_
|
|
122
|
+
else:
|
|
123
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
124
|
+
(self.id, models.Experiment.metadata_),
|
|
125
|
+
)
|
|
126
|
+
return val
|
|
127
|
+
|
|
128
|
+
@strawberry.field
|
|
129
|
+
async def created_at(
|
|
130
|
+
self,
|
|
131
|
+
info: Info[Context, None],
|
|
132
|
+
) -> datetime:
|
|
133
|
+
if self.db_record:
|
|
134
|
+
val = self.db_record.created_at
|
|
135
|
+
else:
|
|
136
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
137
|
+
(self.id, models.Experiment.created_at),
|
|
138
|
+
)
|
|
139
|
+
return val
|
|
140
|
+
|
|
141
|
+
@strawberry.field
|
|
142
|
+
async def updated_at(
|
|
143
|
+
self,
|
|
144
|
+
info: Info[Context, None],
|
|
145
|
+
) -> datetime:
|
|
146
|
+
if self.db_record:
|
|
147
|
+
val = self.db_record.updated_at
|
|
148
|
+
else:
|
|
149
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
150
|
+
(self.id, models.Experiment.updated_at),
|
|
151
|
+
)
|
|
152
|
+
return val
|
|
52
153
|
|
|
53
154
|
@strawberry.field(
|
|
54
155
|
description="Sequence number (1-based) of experiments belonging to the same dataset"
|
|
@@ -58,9 +159,9 @@ class Experiment(Node):
|
|
|
58
159
|
info: Info[Context, None],
|
|
59
160
|
) -> int:
|
|
60
161
|
if self.cached_sequence_number is None:
|
|
61
|
-
seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.
|
|
162
|
+
seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id)
|
|
62
163
|
if seq_num is None:
|
|
63
|
-
raise ValueError(f"invalid experiment: id={self.
|
|
164
|
+
raise ValueError(f"invalid experiment: id={self.id}")
|
|
64
165
|
self.cached_sequence_number = seq_num
|
|
65
166
|
return self.cached_sequence_number
|
|
66
167
|
|
|
@@ -74,12 +175,10 @@ class Experiment(Node):
|
|
|
74
175
|
) -> Connection[ExperimentRun]:
|
|
75
176
|
if first is not None and first <= 0:
|
|
76
177
|
raise BadRequest("first must be a positive integer if set")
|
|
77
|
-
experiment_rowid = self.id_attr
|
|
78
178
|
page_size = first or _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE
|
|
79
179
|
experiment_runs_query = (
|
|
80
180
|
select(models.ExperimentRun)
|
|
81
|
-
.where(models.ExperimentRun.experiment_id ==
|
|
82
|
-
.options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
|
|
181
|
+
.where(models.ExperimentRun.experiment_id == self.id)
|
|
83
182
|
.limit(page_size + 1)
|
|
84
183
|
)
|
|
85
184
|
|
|
@@ -94,7 +193,7 @@ class Experiment(Node):
|
|
|
94
193
|
experiment_runs_query = add_order_by_and_page_start_to_query(
|
|
95
194
|
query=experiment_runs_query,
|
|
96
195
|
sort=sort,
|
|
97
|
-
experiment_rowid=
|
|
196
|
+
experiment_rowid=self.id,
|
|
98
197
|
after_experiment_run_rowid=after_experiment_run_rowid,
|
|
99
198
|
after_sort_column_value=after_sort_column_value,
|
|
100
199
|
)
|
|
@@ -111,7 +210,7 @@ class Experiment(Node):
|
|
|
111
210
|
for result in results:
|
|
112
211
|
run = result[0]
|
|
113
212
|
annotation_score = result[1] if len(result) > 1 else None
|
|
114
|
-
gql_run =
|
|
213
|
+
gql_run = ExperimentRun(id=run.id, db_record=run)
|
|
115
214
|
cursor = get_experiment_run_cursor(
|
|
116
215
|
run=run, annotation_score=annotation_score, sort=sort
|
|
117
216
|
)
|
|
@@ -125,14 +224,13 @@ class Experiment(Node):
|
|
|
125
224
|
|
|
126
225
|
@strawberry.field
|
|
127
226
|
async def run_count(self, info: Info[Context, None]) -> int:
|
|
128
|
-
|
|
129
|
-
return await info.context.data_loaders.experiment_run_counts.load(experiment_id)
|
|
227
|
+
return await info.context.data_loaders.experiment_run_counts.load(self.id)
|
|
130
228
|
|
|
131
229
|
@strawberry.field
|
|
132
230
|
async def annotation_summaries(
|
|
133
231
|
self, info: Info[Context, None]
|
|
134
232
|
) -> list[ExperimentAnnotationSummary]:
|
|
135
|
-
experiment_id = self.
|
|
233
|
+
experiment_id = self.id
|
|
136
234
|
return [
|
|
137
235
|
ExperimentAnnotationSummary(
|
|
138
236
|
annotation_name=summary.annotation_name,
|
|
@@ -149,40 +247,42 @@ class Experiment(Node):
|
|
|
149
247
|
|
|
150
248
|
@strawberry.field
|
|
151
249
|
async def error_rate(self, info: Info[Context, None]) -> Optional[float]:
|
|
152
|
-
return await info.context.data_loaders.experiment_error_rates.load(self.
|
|
250
|
+
return await info.context.data_loaders.experiment_error_rates.load(self.id)
|
|
153
251
|
|
|
154
252
|
@strawberry.field
|
|
155
253
|
async def average_run_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
|
|
156
|
-
latency_ms = await info.context.data_loaders.average_experiment_run_latency.load(
|
|
157
|
-
self.id_attr
|
|
158
|
-
)
|
|
254
|
+
latency_ms = await info.context.data_loaders.average_experiment_run_latency.load(self.id)
|
|
159
255
|
return latency_ms
|
|
160
256
|
|
|
161
257
|
@strawberry.field
|
|
162
|
-
async def project(
|
|
163
|
-
|
|
258
|
+
async def project(
|
|
259
|
+
self, info: Info[Context, None]
|
|
260
|
+
) -> Optional[Annotated["Project", strawberry.lazy(".Project")]]:
|
|
261
|
+
if self.db_record:
|
|
262
|
+
project_name = self.db_record.project_name
|
|
263
|
+
else:
|
|
264
|
+
project_name = await info.context.data_loaders.experiment_fields.load(
|
|
265
|
+
(self.id, models.Experiment.project_name),
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
if project_name is None:
|
|
164
269
|
return None
|
|
165
270
|
|
|
166
|
-
db_project = await info.context.data_loaders.project_by_name.load(
|
|
271
|
+
db_project = await info.context.data_loaders.project_by_name.load(project_name)
|
|
167
272
|
|
|
168
273
|
if db_project is None:
|
|
169
274
|
return None
|
|
275
|
+
from .Project import Project
|
|
170
276
|
|
|
171
|
-
return Project(
|
|
172
|
-
project_rowid=db_project.id,
|
|
173
|
-
db_project=db_project,
|
|
174
|
-
)
|
|
277
|
+
return Project(id=db_project.id, db_record=db_project)
|
|
175
278
|
|
|
176
279
|
@strawberry.field
|
|
177
280
|
def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
|
|
178
|
-
return info.context.last_updated_at.get(
|
|
281
|
+
return info.context.last_updated_at.get(models.Experiment, self.id)
|
|
179
282
|
|
|
180
283
|
@strawberry.field
|
|
181
284
|
async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
|
|
182
|
-
|
|
183
|
-
summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(
|
|
184
|
-
experiment_id
|
|
185
|
-
)
|
|
285
|
+
summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(self.id)
|
|
186
286
|
return SpanCostSummary(
|
|
187
287
|
prompt=CostBreakdown(
|
|
188
288
|
tokens=summary.prompt.tokens,
|
|
@@ -202,8 +302,6 @@ class Experiment(Node):
|
|
|
202
302
|
async def cost_detail_summary_entries(
|
|
203
303
|
self, info: Info[Context, None]
|
|
204
304
|
) -> list[SpanCostDetailSummaryEntry]:
|
|
205
|
-
experiment_id = self.id_attr
|
|
206
|
-
|
|
207
305
|
stmt = (
|
|
208
306
|
select(
|
|
209
307
|
models.SpanCostDetail.token_type,
|
|
@@ -216,7 +314,7 @@ class Experiment(Node):
|
|
|
216
314
|
.join(models.Span, models.SpanCost.span_rowid == models.Span.id)
|
|
217
315
|
.join(models.Trace, models.Span.trace_rowid == models.Trace.id)
|
|
218
316
|
.join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
|
|
219
|
-
.where(models.ExperimentRun.experiment_id ==
|
|
317
|
+
.where(models.ExperimentRun.experiment_id == self.id)
|
|
220
318
|
.group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
|
|
221
319
|
)
|
|
222
320
|
|
|
@@ -237,9 +335,9 @@ class Experiment(Node):
|
|
|
237
335
|
info: Info[Context, None],
|
|
238
336
|
) -> Connection[DatasetSplit]:
|
|
239
337
|
"""Returns the dataset splits associated with this experiment."""
|
|
240
|
-
splits = await info.context.data_loaders.experiment_dataset_splits.load(self.
|
|
338
|
+
splits = await info.context.data_loaders.experiment_dataset_splits.load(self.id)
|
|
241
339
|
return connection_from_list(
|
|
242
|
-
[
|
|
340
|
+
[DatasetSplit(id=split.id, db_record=split) for split in splits], ConnectionArgs()
|
|
243
341
|
)
|
|
244
342
|
|
|
245
343
|
|
|
@@ -251,14 +349,7 @@ def to_gql_experiment(
|
|
|
251
349
|
Converts an ORM experiment to a GraphQL Experiment.
|
|
252
350
|
"""
|
|
253
351
|
return Experiment(
|
|
352
|
+
id=experiment.id,
|
|
353
|
+
db_record=experiment,
|
|
254
354
|
cached_sequence_number=sequence_number,
|
|
255
|
-
id_attr=experiment.id,
|
|
256
|
-
name=experiment.name,
|
|
257
|
-
project_name=experiment.project_name,
|
|
258
|
-
description=experiment.description,
|
|
259
|
-
repetitions=experiment.repetitions,
|
|
260
|
-
dataset_version_id=GlobalID(DatasetVersion.__name__, str(experiment.dataset_version_id)),
|
|
261
|
-
metadata=experiment.metadata_,
|
|
262
|
-
created_at=experiment.created_at,
|
|
263
|
-
updated_at=experiment.updated_at,
|
|
264
355
|
)
|
|
@@ -26,7 +26,16 @@ DatasetExampleRowId: TypeAlias = int
|
|
|
26
26
|
class ExperimentRepeatedRunGroup(Node):
|
|
27
27
|
experiment_rowid: strawberry.Private[ExperimentRowId]
|
|
28
28
|
dataset_example_rowid: strawberry.Private[DatasetExampleRowId]
|
|
29
|
-
|
|
29
|
+
cached_runs: strawberry.Private[Optional[list[ExperimentRun]]] = None
|
|
30
|
+
|
|
31
|
+
@strawberry.field
|
|
32
|
+
async def runs(self, info: Info[Context, None]) -> list[ExperimentRun]:
|
|
33
|
+
if self.cached_runs is not None:
|
|
34
|
+
return self.cached_runs
|
|
35
|
+
runs = await info.context.data_loaders.experiment_runs_by_experiment_and_example.load(
|
|
36
|
+
(self.experiment_rowid, self.dataset_example_rowid)
|
|
37
|
+
)
|
|
38
|
+
return [ExperimentRun(id=run.id, db_record=run) for run in runs]
|
|
30
39
|
|
|
31
40
|
@classmethod
|
|
32
41
|
def resolve_id(
|