arize-phoenix 12.7.1__py3-none-any.whl → 12.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
- phoenix/config.py +131 -9
- phoenix/db/engines.py +127 -14
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/pg_config.py +10 -0
- phoenix/server/api/context.py +23 -0
- phoenix/server/api/dataloaders/__init__.py +6 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/helpers/playground_clients.py +3 -3
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
- phoenix/server/api/mutations/api_key_mutations.py +2 -15
- phoenix/server/api/mutations/chat_mutations.py +3 -2
- phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
- phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
- phoenix/server/api/mutations/prompt_mutations.py +24 -117
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
- phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/queries.py +80 -213
- phoenix/server/api/subscriptions.py +4 -4
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/Dataset.py +88 -48
- phoenix/server/api/types/DatasetExample.py +34 -30
- phoenix/server/api/types/DatasetLabel.py +47 -13
- phoenix/server/api/types/DatasetSplit.py +87 -21
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +182 -62
- phoenix/server/api/types/Experiment.py +146 -55
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
- phoenix/server/api/types/ExperimentRun.py +118 -61
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +70 -75
- phoenix/server/api/types/ProjectSession.py +69 -37
- phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +82 -44
- phoenix/server/api/types/PromptLabel.py +47 -13
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +116 -115
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +45 -44
- phoenix/server/api/types/TraceAnnotation.py +144 -48
- phoenix/server/api/types/User.py +103 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/app.py +29 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
- phoenix/server/static/.vite/manifest.json +43 -43
- phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
- phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
- phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
- phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
- phoenix/version.py +1 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
from dataclasses import asdict, dataclass
|
|
3
3
|
from datetime import datetime
|
|
4
|
-
from typing import TYPE_CHECKING, Annotated,
|
|
4
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
5
5
|
|
|
6
6
|
import pandas as pd
|
|
7
7
|
import strawberry
|
|
8
8
|
from openinference.semconv.trace import SpanAttributes
|
|
9
9
|
from sqlalchemy import select
|
|
10
|
-
from strawberry import UNSET, Info,
|
|
10
|
+
from strawberry import UNSET, Info, lazy
|
|
11
11
|
from strawberry.relay import Connection, Node, NodeID
|
|
12
12
|
|
|
13
13
|
from phoenix.db import models
|
|
@@ -30,12 +30,51 @@ if TYPE_CHECKING:
|
|
|
30
30
|
|
|
31
31
|
@strawberry.type
|
|
32
32
|
class ProjectSession(Node):
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
33
|
+
id: NodeID[int]
|
|
34
|
+
db_record: strawberry.Private[Optional[models.ProjectSession]] = None
|
|
35
|
+
|
|
36
|
+
def __post_init__(self) -> None:
|
|
37
|
+
if self.db_record and self.id != self.db_record.id:
|
|
38
|
+
raise ValueError("ProjectSession ID mismatch")
|
|
39
|
+
|
|
40
|
+
@strawberry.field
|
|
41
|
+
async def session_id(
|
|
42
|
+
self,
|
|
43
|
+
info: Info[Context, None],
|
|
44
|
+
) -> str:
|
|
45
|
+
if self.db_record:
|
|
46
|
+
val = self.db_record.session_id
|
|
47
|
+
else:
|
|
48
|
+
val = await info.context.data_loaders.project_session_fields.load(
|
|
49
|
+
(self.id, models.ProjectSession.session_id),
|
|
50
|
+
)
|
|
51
|
+
return val
|
|
52
|
+
|
|
53
|
+
@strawberry.field
|
|
54
|
+
async def start_time(
|
|
55
|
+
self,
|
|
56
|
+
info: Info[Context, None],
|
|
57
|
+
) -> datetime:
|
|
58
|
+
if self.db_record:
|
|
59
|
+
val = self.db_record.start_time
|
|
60
|
+
else:
|
|
61
|
+
val = await info.context.data_loaders.project_session_fields.load(
|
|
62
|
+
(self.id, models.ProjectSession.start_time),
|
|
63
|
+
)
|
|
64
|
+
return val
|
|
65
|
+
|
|
66
|
+
@strawberry.field
|
|
67
|
+
async def end_time(
|
|
68
|
+
self,
|
|
69
|
+
info: Info[Context, None],
|
|
70
|
+
) -> datetime:
|
|
71
|
+
if self.db_record:
|
|
72
|
+
val = self.db_record.end_time
|
|
73
|
+
else:
|
|
74
|
+
val = await info.context.data_loaders.project_session_fields.load(
|
|
75
|
+
(self.id, models.ProjectSession.end_time),
|
|
76
|
+
)
|
|
77
|
+
return val
|
|
39
78
|
|
|
40
79
|
@strawberry.field
|
|
41
80
|
async def project(
|
|
@@ -44,28 +83,34 @@ class ProjectSession(Node):
|
|
|
44
83
|
) -> Annotated["Project", lazy(".Project")]:
|
|
45
84
|
from phoenix.server.api.types.Project import Project
|
|
46
85
|
|
|
47
|
-
|
|
86
|
+
if self.db_record:
|
|
87
|
+
project_rowid = self.db_record.project_id
|
|
88
|
+
else:
|
|
89
|
+
project_rowid = await info.context.data_loaders.project_session_fields.load(
|
|
90
|
+
(self.id, models.ProjectSession.project_id),
|
|
91
|
+
)
|
|
92
|
+
return Project(id=project_rowid)
|
|
48
93
|
|
|
49
94
|
@strawberry.field
|
|
50
95
|
async def num_traces(
|
|
51
96
|
self,
|
|
52
97
|
info: Info[Context, None],
|
|
53
98
|
) -> int:
|
|
54
|
-
return await info.context.data_loaders.session_num_traces.load(self.
|
|
99
|
+
return await info.context.data_loaders.session_num_traces.load(self.id)
|
|
55
100
|
|
|
56
101
|
@strawberry.field
|
|
57
102
|
async def num_traces_with_error(
|
|
58
103
|
self,
|
|
59
104
|
info: Info[Context, None],
|
|
60
105
|
) -> int:
|
|
61
|
-
return await info.context.data_loaders.session_num_traces_with_error.load(self.
|
|
106
|
+
return await info.context.data_loaders.session_num_traces_with_error.load(self.id)
|
|
62
107
|
|
|
63
108
|
@strawberry.field
|
|
64
109
|
async def first_input(
|
|
65
110
|
self,
|
|
66
111
|
info: Info[Context, None],
|
|
67
112
|
) -> Optional[SpanIOValue]:
|
|
68
|
-
record = await info.context.data_loaders.session_first_inputs.load(self.
|
|
113
|
+
record = await info.context.data_loaders.session_first_inputs.load(self.id)
|
|
69
114
|
if record is None:
|
|
70
115
|
return None
|
|
71
116
|
return SpanIOValue(
|
|
@@ -78,7 +123,7 @@ class ProjectSession(Node):
|
|
|
78
123
|
self,
|
|
79
124
|
info: Info[Context, None],
|
|
80
125
|
) -> Optional[SpanIOValue]:
|
|
81
|
-
record = await info.context.data_loaders.session_last_outputs.load(self.
|
|
126
|
+
record = await info.context.data_loaders.session_last_outputs.load(self.id)
|
|
82
127
|
if record is None:
|
|
83
128
|
return None
|
|
84
129
|
return SpanIOValue(
|
|
@@ -91,7 +136,7 @@ class ProjectSession(Node):
|
|
|
91
136
|
self,
|
|
92
137
|
info: Info[Context, None],
|
|
93
138
|
) -> TokenUsage:
|
|
94
|
-
usage = await info.context.data_loaders.session_token_usages.load(self.
|
|
139
|
+
usage = await info.context.data_loaders.session_token_usages.load(self.id)
|
|
95
140
|
return TokenUsage(
|
|
96
141
|
prompt=usage.prompt,
|
|
97
142
|
completion=usage.completion,
|
|
@@ -116,12 +161,12 @@ class ProjectSession(Node):
|
|
|
116
161
|
)
|
|
117
162
|
stmt = (
|
|
118
163
|
select(models.Trace)
|
|
119
|
-
.filter_by(project_session_rowid=self.
|
|
164
|
+
.filter_by(project_session_rowid=self.id)
|
|
120
165
|
.order_by(models.Trace.start_time)
|
|
121
166
|
)
|
|
122
167
|
async with info.context.db() as session:
|
|
123
168
|
traces = await session.stream_scalars(stmt)
|
|
124
|
-
data = [Trace(
|
|
169
|
+
data = [Trace(id=trace.id, db_record=trace) async for trace in traces]
|
|
125
170
|
return connection_from_list(data=data, args=args)
|
|
126
171
|
|
|
127
172
|
@strawberry.field
|
|
@@ -131,7 +176,7 @@ class ProjectSession(Node):
|
|
|
131
176
|
probability: float,
|
|
132
177
|
) -> Optional[float]:
|
|
133
178
|
return await info.context.data_loaders.session_trace_latency_ms_quantile.load(
|
|
134
|
-
(self.
|
|
179
|
+
(self.id, probability)
|
|
135
180
|
)
|
|
136
181
|
|
|
137
182
|
@strawberry.field
|
|
@@ -140,7 +185,7 @@ class ProjectSession(Node):
|
|
|
140
185
|
info: Info[Context, None],
|
|
141
186
|
) -> SpanCostSummary:
|
|
142
187
|
loader = info.context.data_loaders.span_cost_summary_by_project_session
|
|
143
|
-
summary = await loader.load(self.
|
|
188
|
+
summary = await loader.load(self.id)
|
|
144
189
|
return SpanCostSummary(
|
|
145
190
|
prompt=CostBreakdown(
|
|
146
191
|
tokens=summary.prompt.tokens,
|
|
@@ -162,7 +207,7 @@ class ProjectSession(Node):
|
|
|
162
207
|
info: Info[Context, None],
|
|
163
208
|
) -> list[SpanCostDetailSummaryEntry]:
|
|
164
209
|
loader = info.context.data_loaders.span_cost_detail_summary_entries_by_project_session
|
|
165
|
-
summary = await loader.load(self.
|
|
210
|
+
summary = await loader.load(self.id)
|
|
166
211
|
return [
|
|
167
212
|
SpanCostDetailSummaryEntry(
|
|
168
213
|
token_type=entry.token_type,
|
|
@@ -181,15 +226,14 @@ class ProjectSession(Node):
|
|
|
181
226
|
info: Info[Context, None],
|
|
182
227
|
) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
|
|
183
228
|
"""Get all annotations for this session."""
|
|
184
|
-
from
|
|
185
|
-
to_gql_project_session_annotation,
|
|
186
|
-
)
|
|
229
|
+
from .ProjectSessionAnnotation import ProjectSessionAnnotation
|
|
187
230
|
|
|
188
|
-
stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.
|
|
231
|
+
stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id)
|
|
189
232
|
async with info.context.db() as session:
|
|
190
233
|
annotations = await session.stream_scalars(stmt)
|
|
191
234
|
return [
|
|
192
|
-
|
|
235
|
+
ProjectSessionAnnotation(id=annotation.id, db_record=annotation)
|
|
236
|
+
async for annotation in annotations
|
|
193
237
|
]
|
|
194
238
|
|
|
195
239
|
@strawberry.field(
|
|
@@ -217,9 +261,7 @@ class ProjectSession(Node):
|
|
|
217
261
|
- data: A list of dictionaries with label statistics
|
|
218
262
|
"""
|
|
219
263
|
# Load all annotations for this span from the data loader
|
|
220
|
-
annotations = await info.context.data_loaders.session_annotations_by_session.load(
|
|
221
|
-
self.id_attr
|
|
222
|
-
)
|
|
264
|
+
annotations = await info.context.data_loaders.session_annotations_by_session.load(self.id)
|
|
223
265
|
|
|
224
266
|
# Apply filter if provided to narrow down the annotations
|
|
225
267
|
if filter:
|
|
@@ -251,16 +293,6 @@ class ProjectSession(Node):
|
|
|
251
293
|
return result
|
|
252
294
|
|
|
253
295
|
|
|
254
|
-
def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
|
|
255
|
-
return ProjectSession(
|
|
256
|
-
id_attr=project_session.id,
|
|
257
|
-
session_id=project_session.session_id,
|
|
258
|
-
start_time=project_session.start_time,
|
|
259
|
-
project_rowid=project_session.project_id,
|
|
260
|
-
end_time=project_session.end_time,
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
|
|
264
296
|
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
|
|
265
297
|
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
|
|
266
298
|
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from
|
|
1
|
+
from math import isfinite
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
2
3
|
|
|
3
4
|
import strawberry
|
|
4
|
-
from strawberry import Private
|
|
5
5
|
from strawberry.relay import GlobalID, Node, NodeID
|
|
6
6
|
from strawberry.scalars import JSON
|
|
7
7
|
from strawberry.types import Info
|
|
@@ -10,59 +10,178 @@ from phoenix.db import models
|
|
|
10
10
|
from phoenix.server.api.context import Context
|
|
11
11
|
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
12
12
|
|
|
13
|
+
from .Annotation import Annotation
|
|
13
14
|
from .AnnotationSource import AnnotationSource
|
|
14
|
-
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from .ProjectSession import ProjectSession
|
|
18
|
+
from .User import User
|
|
15
19
|
|
|
16
20
|
|
|
17
21
|
@strawberry.type
|
|
18
|
-
class ProjectSessionAnnotation(Node):
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
22
|
+
class ProjectSessionAnnotation(Node, Annotation):
|
|
23
|
+
id: NodeID[int]
|
|
24
|
+
db_record: strawberry.Private[Optional[models.ProjectSessionAnnotation]] = None
|
|
25
|
+
|
|
26
|
+
def __post_init__(self) -> None:
|
|
27
|
+
if self.db_record and self.id != self.db_record.id:
|
|
28
|
+
raise ValueError("ProjectSessionAnnotation ID mismatch")
|
|
29
|
+
|
|
30
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
31
|
+
async def name(
|
|
32
|
+
self,
|
|
33
|
+
info: Info[Context, None],
|
|
34
|
+
) -> str:
|
|
35
|
+
if self.db_record:
|
|
36
|
+
val = self.db_record.name
|
|
37
|
+
else:
|
|
38
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
39
|
+
(self.id, models.ProjectSessionAnnotation.name),
|
|
40
|
+
)
|
|
41
|
+
return val
|
|
42
|
+
|
|
43
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
44
|
+
async def annotator_kind(
|
|
45
|
+
self,
|
|
46
|
+
info: Info[Context, None],
|
|
47
|
+
) -> AnnotatorKind:
|
|
48
|
+
if self.db_record:
|
|
49
|
+
val = self.db_record.annotator_kind
|
|
50
|
+
else:
|
|
51
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
52
|
+
(self.id, models.ProjectSessionAnnotation.annotator_kind),
|
|
53
|
+
)
|
|
54
|
+
return AnnotatorKind(val)
|
|
55
|
+
|
|
56
|
+
@strawberry.field(
|
|
57
|
+
description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
|
|
58
|
+
) # type: ignore
|
|
59
|
+
async def label(
|
|
60
|
+
self,
|
|
61
|
+
info: Info[Context, None],
|
|
62
|
+
) -> Optional[str]:
|
|
63
|
+
if self.db_record:
|
|
64
|
+
val = self.db_record.label
|
|
65
|
+
else:
|
|
66
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
67
|
+
(self.id, models.ProjectSessionAnnotation.label),
|
|
68
|
+
)
|
|
69
|
+
return val
|
|
70
|
+
|
|
71
|
+
@strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
|
|
72
|
+
async def score(
|
|
73
|
+
self,
|
|
74
|
+
info: Info[Context, None],
|
|
75
|
+
) -> Optional[float]:
|
|
76
|
+
if self.db_record:
|
|
77
|
+
val = self.db_record.score
|
|
78
|
+
else:
|
|
79
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
80
|
+
(self.id, models.ProjectSessionAnnotation.score),
|
|
81
|
+
)
|
|
82
|
+
return val if val is not None and isfinite(val) else None
|
|
83
|
+
|
|
84
|
+
@strawberry.field(
|
|
85
|
+
description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
|
|
86
|
+
) # type: ignore
|
|
87
|
+
async def explanation(
|
|
88
|
+
self,
|
|
89
|
+
info: Info[Context, None],
|
|
90
|
+
) -> Optional[str]:
|
|
91
|
+
if self.db_record:
|
|
92
|
+
val = self.db_record.explanation
|
|
93
|
+
else:
|
|
94
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
95
|
+
(self.id, models.ProjectSessionAnnotation.explanation),
|
|
96
|
+
)
|
|
97
|
+
return val
|
|
98
|
+
|
|
99
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
100
|
+
async def metadata(
|
|
101
|
+
self,
|
|
102
|
+
info: Info[Context, None],
|
|
103
|
+
) -> JSON:
|
|
104
|
+
if self.db_record:
|
|
105
|
+
val = self.db_record.metadata_
|
|
106
|
+
else:
|
|
107
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
108
|
+
(self.id, models.ProjectSessionAnnotation.metadata_),
|
|
109
|
+
)
|
|
110
|
+
return val
|
|
111
|
+
|
|
112
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
113
|
+
async def identifier(
|
|
114
|
+
self,
|
|
115
|
+
info: Info[Context, None],
|
|
116
|
+
) -> str:
|
|
117
|
+
if self.db_record:
|
|
118
|
+
val = self.db_record.identifier
|
|
119
|
+
else:
|
|
120
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
121
|
+
(self.id, models.ProjectSessionAnnotation.identifier),
|
|
122
|
+
)
|
|
123
|
+
return val
|
|
124
|
+
|
|
125
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
126
|
+
async def source(
|
|
127
|
+
self,
|
|
128
|
+
info: Info[Context, None],
|
|
129
|
+
) -> AnnotationSource:
|
|
130
|
+
if self.db_record:
|
|
131
|
+
val = self.db_record.source
|
|
132
|
+
else:
|
|
133
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
134
|
+
(self.id, models.ProjectSessionAnnotation.source),
|
|
135
|
+
)
|
|
136
|
+
return AnnotationSource(val)
|
|
137
|
+
|
|
138
|
+
@strawberry.field(description="The project session associated with the annotation.") # type: ignore
|
|
139
|
+
async def project_session_id(
|
|
140
|
+
self,
|
|
141
|
+
info: Info[Context, None],
|
|
142
|
+
) -> GlobalID:
|
|
33
143
|
from phoenix.server.api.types.ProjectSession import ProjectSession
|
|
34
144
|
|
|
35
|
-
|
|
145
|
+
if self.db_record:
|
|
146
|
+
project_session_id = self.db_record.project_session_id
|
|
147
|
+
else:
|
|
148
|
+
project_session_id = (
|
|
149
|
+
await info.context.data_loaders.project_session_annotation_fields.load(
|
|
150
|
+
(self.id, models.ProjectSessionAnnotation.project_session_id),
|
|
151
|
+
)
|
|
152
|
+
)
|
|
153
|
+
return GlobalID(type_name=ProjectSession.__name__, node_id=str(project_session_id))
|
|
154
|
+
|
|
155
|
+
@strawberry.field(description="The project session associated with the annotation.") # type: ignore
|
|
156
|
+
async def project_session(
|
|
157
|
+
self,
|
|
158
|
+
info: Info[Context, None],
|
|
159
|
+
) -> Annotated["ProjectSession", strawberry.lazy(".ProjectSession")]:
|
|
160
|
+
if self.db_record:
|
|
161
|
+
project_session_id = self.db_record.project_session_id
|
|
162
|
+
else:
|
|
163
|
+
project_session_id = (
|
|
164
|
+
await info.context.data_loaders.project_session_annotation_fields.load(
|
|
165
|
+
(self.id, models.ProjectSessionAnnotation.project_session_id),
|
|
166
|
+
)
|
|
167
|
+
)
|
|
168
|
+
from .ProjectSession import ProjectSession
|
|
169
|
+
|
|
170
|
+
return ProjectSession(id=project_session_id)
|
|
36
171
|
|
|
37
|
-
@strawberry.field
|
|
172
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
38
173
|
async def user(
|
|
39
174
|
self,
|
|
40
175
|
info: Info[Context, None],
|
|
41
|
-
) -> Optional[User]:
|
|
42
|
-
if self.
|
|
176
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
177
|
+
if self.db_record:
|
|
178
|
+
user_id = self.db_record.user_id
|
|
179
|
+
else:
|
|
180
|
+
user_id = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
181
|
+
(self.id, models.ProjectSessionAnnotation.user_id),
|
|
182
|
+
)
|
|
183
|
+
if user_id is None:
|
|
43
184
|
return None
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
return to_gql_user(user)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def to_gql_project_session_annotation(
|
|
51
|
-
annotation: models.ProjectSessionAnnotation,
|
|
52
|
-
) -> ProjectSessionAnnotation:
|
|
53
|
-
"""
|
|
54
|
-
Converts an ORM projectSession annotation to a GraphQL ProjectSessionAnnotation.
|
|
55
|
-
"""
|
|
56
|
-
return ProjectSessionAnnotation(
|
|
57
|
-
id_attr=annotation.id,
|
|
58
|
-
user_id=annotation.user_id,
|
|
59
|
-
_project_session_id=annotation.project_session_id,
|
|
60
|
-
name=annotation.name,
|
|
61
|
-
annotator_kind=AnnotatorKind(annotation.annotator_kind),
|
|
62
|
-
label=annotation.label,
|
|
63
|
-
score=annotation.score,
|
|
64
|
-
explanation=annotation.explanation,
|
|
65
|
-
metadata=JSON(annotation.metadata_),
|
|
66
|
-
identifier=annotation.identifier,
|
|
67
|
-
source=AnnotationSource(annotation.source),
|
|
68
|
-
)
|
|
185
|
+
from .User import User
|
|
186
|
+
|
|
187
|
+
return User(id=user_id)
|
|
@@ -106,5 +106,5 @@ class ProjectTraceRetentionPolicy(Node):
|
|
|
106
106
|
project_rowids = await info.context.data_loaders.projects_by_trace_retention_policy_id.load(
|
|
107
107
|
self.id
|
|
108
108
|
)
|
|
109
|
-
data = [Project(
|
|
109
|
+
data = [Project(id=project_rowid) for project_rowid in project_rowids]
|
|
110
110
|
return connection_from_list(data=data, args=args)
|
|
@@ -20,21 +20,76 @@ from phoenix.server.api.types.pagination import (
|
|
|
20
20
|
connection_from_list,
|
|
21
21
|
)
|
|
22
22
|
|
|
23
|
-
from .PromptLabel import PromptLabel
|
|
23
|
+
from .PromptLabel import PromptLabel
|
|
24
24
|
from .PromptVersion import (
|
|
25
25
|
PromptVersion,
|
|
26
26
|
to_gql_prompt_version,
|
|
27
27
|
)
|
|
28
|
-
from .PromptVersionTag import PromptVersionTag
|
|
28
|
+
from .PromptVersionTag import PromptVersionTag
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
@strawberry.type
|
|
32
32
|
class Prompt(Node):
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
33
|
+
id: NodeID[int]
|
|
34
|
+
db_record: strawberry.Private[Optional[models.Prompt]] = None
|
|
35
|
+
|
|
36
|
+
def __post_init__(self) -> None:
|
|
37
|
+
if self.db_record and self.id != self.db_record.id:
|
|
38
|
+
raise ValueError("Prompt ID mismatch")
|
|
39
|
+
|
|
40
|
+
@strawberry.field
|
|
41
|
+
async def source_prompt_id(
|
|
42
|
+
self,
|
|
43
|
+
info: Info[Context, None],
|
|
44
|
+
) -> Optional[GlobalID]:
|
|
45
|
+
if self.db_record:
|
|
46
|
+
source_id = self.db_record.source_prompt_id
|
|
47
|
+
else:
|
|
48
|
+
source_id = await info.context.data_loaders.prompt_fields.load(
|
|
49
|
+
(self.id, models.Prompt.source_prompt_id),
|
|
50
|
+
)
|
|
51
|
+
if not source_id:
|
|
52
|
+
return None
|
|
53
|
+
return GlobalID(Prompt.__name__, str(source_id))
|
|
54
|
+
|
|
55
|
+
@strawberry.field
|
|
56
|
+
async def name(
|
|
57
|
+
self,
|
|
58
|
+
info: Info[Context, None],
|
|
59
|
+
) -> Identifier:
|
|
60
|
+
if self.db_record:
|
|
61
|
+
val = self.db_record.name
|
|
62
|
+
else:
|
|
63
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
64
|
+
(self.id, models.Prompt.name),
|
|
65
|
+
)
|
|
66
|
+
return Identifier(val.root)
|
|
67
|
+
|
|
68
|
+
@strawberry.field
|
|
69
|
+
async def description(
|
|
70
|
+
self,
|
|
71
|
+
info: Info[Context, None],
|
|
72
|
+
) -> Optional[str]:
|
|
73
|
+
if self.db_record:
|
|
74
|
+
val = self.db_record.description
|
|
75
|
+
else:
|
|
76
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
77
|
+
(self.id, models.Prompt.description),
|
|
78
|
+
)
|
|
79
|
+
return val
|
|
80
|
+
|
|
81
|
+
@strawberry.field
|
|
82
|
+
async def created_at(
|
|
83
|
+
self,
|
|
84
|
+
info: Info[Context, None],
|
|
85
|
+
) -> datetime:
|
|
86
|
+
if self.db_record:
|
|
87
|
+
val = self.db_record.created_at
|
|
88
|
+
else:
|
|
89
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
90
|
+
(self.id, models.Prompt.created_at),
|
|
91
|
+
)
|
|
92
|
+
return val
|
|
38
93
|
|
|
39
94
|
@strawberry.field
|
|
40
95
|
async def version(
|
|
@@ -49,7 +104,7 @@ class Prompt(Node):
|
|
|
49
104
|
version = await session.scalar(
|
|
50
105
|
select(models.PromptVersion).where(
|
|
51
106
|
models.PromptVersion.id == v_id,
|
|
52
|
-
models.PromptVersion.prompt_id == self.
|
|
107
|
+
models.PromptVersion.prompt_id == self.id,
|
|
53
108
|
)
|
|
54
109
|
)
|
|
55
110
|
if not version:
|
|
@@ -61,7 +116,7 @@ class Prompt(Node):
|
|
|
61
116
|
raise NotFound(f"Prompt version tag not found: {tag_name}")
|
|
62
117
|
version = await session.scalar(
|
|
63
118
|
select(models.PromptVersion)
|
|
64
|
-
.where(models.PromptVersion.prompt_id == self.
|
|
119
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
65
120
|
.join_from(models.PromptVersion, models.PromptVersionTag)
|
|
66
121
|
.where(models.PromptVersionTag.name == name)
|
|
67
122
|
)
|
|
@@ -70,7 +125,7 @@ class Prompt(Node):
|
|
|
70
125
|
else:
|
|
71
126
|
stmt = (
|
|
72
127
|
select(models.PromptVersion)
|
|
73
|
-
.where(models.PromptVersion.prompt_id == self.
|
|
128
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
74
129
|
.order_by(models.PromptVersion.id.desc())
|
|
75
130
|
.limit(1)
|
|
76
131
|
)
|
|
@@ -83,10 +138,11 @@ class Prompt(Node):
|
|
|
83
138
|
async def version_tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
|
|
84
139
|
async with info.context.db() as session:
|
|
85
140
|
stmt = select(models.PromptVersionTag).where(
|
|
86
|
-
models.PromptVersionTag.prompt_id == self.
|
|
141
|
+
models.PromptVersionTag.prompt_id == self.id
|
|
87
142
|
)
|
|
88
143
|
return [
|
|
89
|
-
|
|
144
|
+
PromptVersionTag(id=tag.id, db_record=tag)
|
|
145
|
+
async for tag in await session.stream_scalars(stmt)
|
|
90
146
|
]
|
|
91
147
|
|
|
92
148
|
@strawberry.field
|
|
@@ -107,7 +163,7 @@ class Prompt(Node):
|
|
|
107
163
|
row_number = func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
|
|
108
164
|
stmt = (
|
|
109
165
|
select(models.PromptVersion, row_number)
|
|
110
|
-
.where(models.PromptVersion.prompt_id == self.
|
|
166
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
111
167
|
.order_by(models.PromptVersion.id.desc())
|
|
112
168
|
)
|
|
113
169
|
async with info.context.db() as session:
|
|
@@ -119,20 +175,19 @@ class Prompt(Node):
|
|
|
119
175
|
|
|
120
176
|
@strawberry.field
|
|
121
177
|
async def source_prompt(self, info: Info[Context, None]) -> Optional["Prompt"]:
|
|
122
|
-
if
|
|
178
|
+
if self.db_record:
|
|
179
|
+
id_ = self.db_record.source_prompt_id
|
|
180
|
+
else:
|
|
181
|
+
id_ = await info.context.data_loaders.prompt_fields.load(
|
|
182
|
+
(self.id, models.Prompt.source_prompt_id),
|
|
183
|
+
)
|
|
184
|
+
if not id_:
|
|
123
185
|
return None
|
|
124
|
-
|
|
125
|
-
source_prompt_id = from_global_id_with_expected_type(
|
|
126
|
-
global_id=self.source_prompt_id, expected_type_name=Prompt.__name__
|
|
127
|
-
)
|
|
128
|
-
|
|
129
186
|
async with info.context.db() as session:
|
|
130
|
-
source_prompt = await session.
|
|
131
|
-
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
|
|
135
|
-
return to_gql_prompt_from_orm(source_prompt)
|
|
187
|
+
source_prompt = await session.get(models.Prompt, id_)
|
|
188
|
+
if not source_prompt:
|
|
189
|
+
raise NotFound(f"Source prompt not found: {id_}")
|
|
190
|
+
return Prompt(id=source_prompt.id, db_record=source_prompt)
|
|
136
191
|
|
|
137
192
|
@strawberry.field
|
|
138
193
|
async def labels(self, info: Info[Context, None]) -> list["PromptLabel"]:
|
|
@@ -140,23 +195,6 @@ class Prompt(Node):
|
|
|
140
195
|
labels = await session.scalars(
|
|
141
196
|
select(models.PromptLabel)
|
|
142
197
|
.join(models.PromptPromptLabel)
|
|
143
|
-
.where(models.PromptPromptLabel.prompt_id == self.
|
|
198
|
+
.where(models.PromptPromptLabel.prompt_id == self.id)
|
|
144
199
|
)
|
|
145
|
-
return [
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
|
|
149
|
-
if not orm_model.source_prompt_id:
|
|
150
|
-
source_prompt_gid = None
|
|
151
|
-
else:
|
|
152
|
-
source_prompt_gid = GlobalID(
|
|
153
|
-
Prompt.__name__,
|
|
154
|
-
str(orm_model.source_prompt_id),
|
|
155
|
-
)
|
|
156
|
-
return Prompt(
|
|
157
|
-
id_attr=orm_model.id,
|
|
158
|
-
source_prompt_id=source_prompt_gid,
|
|
159
|
-
name=Identifier(orm_model.name.root),
|
|
160
|
-
description=orm_model.description,
|
|
161
|
-
created_at=orm_model.created_at,
|
|
162
|
-
)
|
|
200
|
+
return [PromptLabel(id=label.id, db_record=label) for label in labels]
|