arize-phoenix 12.7.1__py3-none-any.whl → 12.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
- phoenix/config.py +131 -9
- phoenix/db/engines.py +127 -14
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/pg_config.py +10 -0
- phoenix/server/api/context.py +23 -0
- phoenix/server/api/dataloaders/__init__.py +6 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/helpers/playground_clients.py +3 -3
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
- phoenix/server/api/mutations/api_key_mutations.py +2 -15
- phoenix/server/api/mutations/chat_mutations.py +3 -2
- phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
- phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
- phoenix/server/api/mutations/prompt_mutations.py +24 -117
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
- phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/queries.py +80 -213
- phoenix/server/api/subscriptions.py +4 -4
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/Dataset.py +88 -48
- phoenix/server/api/types/DatasetExample.py +34 -30
- phoenix/server/api/types/DatasetLabel.py +47 -13
- phoenix/server/api/types/DatasetSplit.py +87 -21
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +182 -62
- phoenix/server/api/types/Experiment.py +146 -55
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
- phoenix/server/api/types/ExperimentRun.py +118 -61
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +70 -75
- phoenix/server/api/types/ProjectSession.py +69 -37
- phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +82 -44
- phoenix/server/api/types/PromptLabel.py +47 -13
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +116 -115
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +45 -44
- phoenix/server/api/types/TraceAnnotation.py +144 -48
- phoenix/server/api/types/User.py +103 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/app.py +29 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
- phoenix/server/static/.vite/manifest.json +43 -43
- phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
- phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
- phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
- phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
- phoenix/version.py +1 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,7 +1,8 @@
|
|
|
1
|
-
from
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from math import isfinite
|
|
3
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
2
4
|
|
|
3
5
|
import strawberry
|
|
4
|
-
from strawberry import Private
|
|
5
6
|
from strawberry.relay import GlobalID, Node, NodeID
|
|
6
7
|
from strawberry.scalars import JSON
|
|
7
8
|
from strawberry.types import Info
|
|
@@ -12,56 +13,202 @@ from phoenix.server.api.context import Context
|
|
|
12
13
|
from .Annotation import Annotation
|
|
13
14
|
from .AnnotationSource import AnnotationSource
|
|
14
15
|
from .AnnotatorKind import AnnotatorKind
|
|
15
|
-
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from .Span import Span
|
|
19
|
+
from .User import User
|
|
16
20
|
|
|
17
21
|
|
|
18
22
|
@strawberry.type
|
|
19
23
|
class SpanAnnotation(Node, Annotation):
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
24
|
+
id: NodeID[int]
|
|
25
|
+
db_record: strawberry.Private[Optional[models.SpanAnnotation]] = None
|
|
26
|
+
|
|
27
|
+
def __post_init__(self) -> None:
|
|
28
|
+
if self.db_record and self.id != self.db_record.id:
|
|
29
|
+
raise ValueError("SpanAnnotation ID mismatch")
|
|
30
|
+
|
|
31
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
32
|
+
async def name(
|
|
33
|
+
self,
|
|
34
|
+
info: Info[Context, None],
|
|
35
|
+
) -> str:
|
|
36
|
+
if self.db_record:
|
|
37
|
+
val = self.db_record.name
|
|
38
|
+
else:
|
|
39
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
40
|
+
(self.id, models.SpanAnnotation.name),
|
|
41
|
+
)
|
|
42
|
+
return val
|
|
43
|
+
|
|
44
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
45
|
+
async def annotator_kind(
|
|
46
|
+
self,
|
|
47
|
+
info: Info[Context, None],
|
|
48
|
+
) -> AnnotatorKind:
|
|
49
|
+
if self.db_record:
|
|
50
|
+
val = self.db_record.annotator_kind
|
|
51
|
+
else:
|
|
52
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
53
|
+
(self.id, models.SpanAnnotation.annotator_kind),
|
|
54
|
+
)
|
|
55
|
+
return AnnotatorKind(val)
|
|
56
|
+
|
|
57
|
+
@strawberry.field(
|
|
58
|
+
description="Value of the annotation in the form of a string, e.g. "
|
|
59
|
+
"'helpful' or 'not helpful'. Note that the label is not necessarily binary."
|
|
60
|
+
) # type: ignore
|
|
61
|
+
async def label(
|
|
62
|
+
self,
|
|
63
|
+
info: Info[Context, None],
|
|
64
|
+
) -> Optional[str]:
|
|
65
|
+
if self.db_record:
|
|
66
|
+
val = self.db_record.label
|
|
67
|
+
else:
|
|
68
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
69
|
+
(self.id, models.SpanAnnotation.label),
|
|
70
|
+
)
|
|
71
|
+
return val
|
|
72
|
+
|
|
73
|
+
@strawberry.field(
|
|
74
|
+
description="Value of the annotation in the form of a numeric score.",
|
|
75
|
+
) # type: ignore
|
|
76
|
+
async def score(
|
|
77
|
+
self,
|
|
78
|
+
info: Info[Context, None],
|
|
79
|
+
) -> Optional[float]:
|
|
80
|
+
if self.db_record:
|
|
81
|
+
val = self.db_record.score
|
|
82
|
+
else:
|
|
83
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
84
|
+
(self.id, models.SpanAnnotation.score),
|
|
85
|
+
)
|
|
86
|
+
return val if val is not None and isfinite(val) else None
|
|
87
|
+
|
|
88
|
+
@strawberry.field(
|
|
89
|
+
description="The annotator's explanation for the annotation result (i.e. "
|
|
90
|
+
"score or label, or both) given to the subject."
|
|
91
|
+
) # type: ignore
|
|
92
|
+
async def explanation(
|
|
93
|
+
self,
|
|
94
|
+
info: Info[Context, None],
|
|
95
|
+
) -> Optional[str]:
|
|
96
|
+
if self.db_record:
|
|
97
|
+
val = self.db_record.explanation
|
|
98
|
+
else:
|
|
99
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
100
|
+
(self.id, models.SpanAnnotation.explanation),
|
|
101
|
+
)
|
|
102
|
+
return val
|
|
103
|
+
|
|
104
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
105
|
+
async def metadata(
|
|
106
|
+
self,
|
|
107
|
+
info: Info[Context, None],
|
|
108
|
+
) -> JSON:
|
|
109
|
+
if self.db_record:
|
|
110
|
+
val = self.db_record.metadata_
|
|
111
|
+
else:
|
|
112
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
113
|
+
(self.id, models.SpanAnnotation.metadata_),
|
|
114
|
+
)
|
|
115
|
+
return val
|
|
116
|
+
|
|
117
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
118
|
+
async def source(
|
|
119
|
+
self,
|
|
120
|
+
info: Info[Context, None],
|
|
121
|
+
) -> AnnotationSource:
|
|
122
|
+
if self.db_record:
|
|
123
|
+
val = self.db_record.source
|
|
124
|
+
else:
|
|
125
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
126
|
+
(self.id, models.SpanAnnotation.source),
|
|
127
|
+
)
|
|
128
|
+
return AnnotationSource(val)
|
|
129
|
+
|
|
130
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
131
|
+
async def identifier(
|
|
132
|
+
self,
|
|
133
|
+
info: Info[Context, None],
|
|
134
|
+
) -> str:
|
|
135
|
+
if self.db_record:
|
|
136
|
+
val = self.db_record.identifier
|
|
137
|
+
else:
|
|
138
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
139
|
+
(self.id, models.SpanAnnotation.identifier),
|
|
140
|
+
)
|
|
141
|
+
return val
|
|
142
|
+
|
|
143
|
+
@strawberry.field(description="The date and time when the annotation was created.") # type: ignore
|
|
144
|
+
async def created_at(
|
|
145
|
+
self,
|
|
146
|
+
info: Info[Context, None],
|
|
147
|
+
) -> datetime:
|
|
148
|
+
if self.db_record:
|
|
149
|
+
val = self.db_record.created_at
|
|
150
|
+
else:
|
|
151
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
152
|
+
(self.id, models.SpanAnnotation.created_at),
|
|
153
|
+
)
|
|
154
|
+
return val
|
|
155
|
+
|
|
156
|
+
@strawberry.field(description="The date and time when the annotation was last updated.") # type: ignore
|
|
157
|
+
async def updated_at(
|
|
158
|
+
self,
|
|
159
|
+
info: Info[Context, None],
|
|
160
|
+
) -> datetime:
|
|
161
|
+
if self.db_record:
|
|
162
|
+
val = self.db_record.updated_at
|
|
163
|
+
else:
|
|
164
|
+
val = await info.context.data_loaders.span_annotation_fields.load(
|
|
165
|
+
(self.id, models.SpanAnnotation.updated_at),
|
|
166
|
+
)
|
|
167
|
+
return val
|
|
27
168
|
|
|
28
169
|
@strawberry.field
|
|
29
|
-
async def span_id(
|
|
170
|
+
async def span_id(
|
|
171
|
+
self,
|
|
172
|
+
info: Info[Context, None],
|
|
173
|
+
) -> GlobalID:
|
|
30
174
|
from phoenix.server.api.types.Span import Span
|
|
31
175
|
|
|
32
|
-
|
|
176
|
+
if self.db_record:
|
|
177
|
+
span_rowid = self.db_record.span_rowid
|
|
178
|
+
else:
|
|
179
|
+
span_rowid = await info.context.data_loaders.span_annotation_fields.load(
|
|
180
|
+
(self.id, models.SpanAnnotation.span_rowid),
|
|
181
|
+
)
|
|
182
|
+
return GlobalID(type_name=Span.__name__, node_id=str(span_rowid))
|
|
33
183
|
|
|
34
|
-
@strawberry.field
|
|
184
|
+
@strawberry.field(description="The span associated with the annotation.") # type: ignore
|
|
185
|
+
async def span(
|
|
186
|
+
self,
|
|
187
|
+
info: Info[Context, None],
|
|
188
|
+
) -> Annotated["Span", strawberry.lazy(".Span")]:
|
|
189
|
+
if self.db_record:
|
|
190
|
+
span_rowid = self.db_record.span_rowid
|
|
191
|
+
else:
|
|
192
|
+
span_rowid = await info.context.data_loaders.span_annotation_fields.load(
|
|
193
|
+
(self.id, models.SpanAnnotation.span_rowid),
|
|
194
|
+
)
|
|
195
|
+
from .Span import Span
|
|
196
|
+
|
|
197
|
+
return Span(id=span_rowid)
|
|
198
|
+
|
|
199
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
35
200
|
async def user(
|
|
36
201
|
self,
|
|
37
202
|
info: Info[Context, None],
|
|
38
|
-
) -> Optional[User]:
|
|
39
|
-
if self.
|
|
203
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
204
|
+
if self.db_record:
|
|
205
|
+
user_id = self.db_record.user_id
|
|
206
|
+
else:
|
|
207
|
+
user_id = await info.context.data_loaders.span_annotation_fields.load(
|
|
208
|
+
(self.id, models.SpanAnnotation.user_id),
|
|
209
|
+
)
|
|
210
|
+
if user_id is None:
|
|
40
211
|
return None
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
return to_gql_user(user)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def to_gql_span_annotation(
|
|
48
|
-
annotation: models.SpanAnnotation,
|
|
49
|
-
) -> SpanAnnotation:
|
|
50
|
-
"""
|
|
51
|
-
Converts an ORM span annotation to a GraphQL SpanAnnotation.
|
|
52
|
-
"""
|
|
53
|
-
return SpanAnnotation(
|
|
54
|
-
id_attr=annotation.id,
|
|
55
|
-
user_id=annotation.user_id,
|
|
56
|
-
span_rowid=annotation.span_rowid,
|
|
57
|
-
name=annotation.name,
|
|
58
|
-
annotator_kind=AnnotatorKind(annotation.annotator_kind),
|
|
59
|
-
label=annotation.label,
|
|
60
|
-
score=annotation.score,
|
|
61
|
-
explanation=annotation.explanation,
|
|
62
|
-
metadata=annotation.metadata_,
|
|
63
|
-
source=AnnotationSource(annotation.source),
|
|
64
|
-
identifier=annotation.identifier,
|
|
65
|
-
created_at=annotation.created_at,
|
|
66
|
-
updated_at=annotation.updated_at,
|
|
67
|
-
)
|
|
212
|
+
from .User import User
|
|
213
|
+
|
|
214
|
+
return User(id=user_id)
|
|
@@ -1,9 +1,73 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
1
4
|
import strawberry
|
|
2
5
|
from strawberry.relay import Node, NodeID
|
|
6
|
+
from strawberry.types import Info
|
|
7
|
+
|
|
8
|
+
from phoenix.db.models import ApiKey as OrmApiKey
|
|
9
|
+
from phoenix.server.api.context import Context
|
|
3
10
|
|
|
4
11
|
from .ApiKey import ApiKey
|
|
5
12
|
|
|
6
13
|
|
|
7
14
|
@strawberry.type
|
|
8
15
|
class SystemApiKey(ApiKey, Node):
|
|
9
|
-
|
|
16
|
+
id: NodeID[int]
|
|
17
|
+
db_record: strawberry.Private[Optional[OrmApiKey]] = None
|
|
18
|
+
|
|
19
|
+
def __post_init__(self) -> None:
|
|
20
|
+
if self.db_record and self.id != self.db_record.id:
|
|
21
|
+
raise ValueError("SystemApiKey ID mismatch")
|
|
22
|
+
|
|
23
|
+
@strawberry.field(description="Name of the API key.") # type: ignore
|
|
24
|
+
async def name(
|
|
25
|
+
self,
|
|
26
|
+
info: Info[Context, None],
|
|
27
|
+
) -> str:
|
|
28
|
+
if self.db_record:
|
|
29
|
+
val = self.db_record.name
|
|
30
|
+
else:
|
|
31
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
32
|
+
(self.id, OrmApiKey.name),
|
|
33
|
+
)
|
|
34
|
+
return val
|
|
35
|
+
|
|
36
|
+
@strawberry.field(description="Description of the API key.") # type: ignore
|
|
37
|
+
async def description(
|
|
38
|
+
self,
|
|
39
|
+
info: Info[Context, None],
|
|
40
|
+
) -> Optional[str]:
|
|
41
|
+
if self.db_record:
|
|
42
|
+
val = self.db_record.description
|
|
43
|
+
else:
|
|
44
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
45
|
+
(self.id, OrmApiKey.description),
|
|
46
|
+
)
|
|
47
|
+
return val
|
|
48
|
+
|
|
49
|
+
@strawberry.field(description="The date and time the API key was created.") # type: ignore
|
|
50
|
+
async def created_at(
|
|
51
|
+
self,
|
|
52
|
+
info: Info[Context, None],
|
|
53
|
+
) -> datetime:
|
|
54
|
+
if self.db_record:
|
|
55
|
+
val = self.db_record.created_at
|
|
56
|
+
else:
|
|
57
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
58
|
+
(self.id, OrmApiKey.created_at),
|
|
59
|
+
)
|
|
60
|
+
return val
|
|
61
|
+
|
|
62
|
+
@strawberry.field(description="The date and time the API key will expire.") # type: ignore
|
|
63
|
+
async def expires_at(
|
|
64
|
+
self,
|
|
65
|
+
info: Info[Context, None],
|
|
66
|
+
) -> Optional[datetime]:
|
|
67
|
+
if self.db_record:
|
|
68
|
+
val = self.db_record.expires_at
|
|
69
|
+
else:
|
|
70
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
71
|
+
(self.id, OrmApiKey.expires_at),
|
|
72
|
+
)
|
|
73
|
+
return val
|
|
@@ -9,7 +9,7 @@ import pandas as pd
|
|
|
9
9
|
import strawberry
|
|
10
10
|
from openinference.semconv.trace import SpanAttributes
|
|
11
11
|
from sqlalchemy import desc, select
|
|
12
|
-
from strawberry import ID, UNSET,
|
|
12
|
+
from strawberry import ID, UNSET, lazy
|
|
13
13
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
14
14
|
from strawberry.types import Info
|
|
15
15
|
from typing_extensions import TypeAlias
|
|
@@ -29,7 +29,7 @@ from phoenix.server.api.types.SortDir import SortDir
|
|
|
29
29
|
from phoenix.server.api.types.Span import Span
|
|
30
30
|
from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
|
|
31
31
|
from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
|
|
32
|
-
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
|
|
32
|
+
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
|
|
33
33
|
|
|
34
34
|
if TYPE_CHECKING:
|
|
35
35
|
from phoenix.server.api.types.Project import Project
|
|
@@ -41,11 +41,11 @@ TraceRowId: TypeAlias = int
|
|
|
41
41
|
|
|
42
42
|
@strawberry.type
|
|
43
43
|
class Trace(Node):
|
|
44
|
-
|
|
45
|
-
|
|
44
|
+
id: NodeID[TraceRowId]
|
|
45
|
+
db_record: strawberry.Private[Optional[models.Trace]] = None
|
|
46
46
|
|
|
47
47
|
def __post_init__(self) -> None:
|
|
48
|
-
if self.
|
|
48
|
+
if self.db_record and self.id != self.db_record.id:
|
|
49
49
|
raise ValueError("Trace ID mismatch")
|
|
50
50
|
|
|
51
51
|
@strawberry.field
|
|
@@ -53,11 +53,11 @@ class Trace(Node):
|
|
|
53
53
|
self,
|
|
54
54
|
info: Info[Context, None],
|
|
55
55
|
) -> ID:
|
|
56
|
-
if self.
|
|
57
|
-
trace_id = self.
|
|
56
|
+
if self.db_record:
|
|
57
|
+
trace_id = self.db_record.trace_id
|
|
58
58
|
else:
|
|
59
59
|
trace_id = await info.context.data_loaders.trace_fields.load(
|
|
60
|
-
(self.
|
|
60
|
+
(self.id, models.Trace.trace_id),
|
|
61
61
|
)
|
|
62
62
|
return ID(trace_id)
|
|
63
63
|
|
|
@@ -66,11 +66,11 @@ class Trace(Node):
|
|
|
66
66
|
self,
|
|
67
67
|
info: Info[Context, None],
|
|
68
68
|
) -> datetime:
|
|
69
|
-
if self.
|
|
70
|
-
start_time = self.
|
|
69
|
+
if self.db_record:
|
|
70
|
+
start_time = self.db_record.start_time
|
|
71
71
|
else:
|
|
72
72
|
start_time = await info.context.data_loaders.trace_fields.load(
|
|
73
|
-
(self.
|
|
73
|
+
(self.id, models.Trace.start_time),
|
|
74
74
|
)
|
|
75
75
|
return start_time
|
|
76
76
|
|
|
@@ -79,11 +79,11 @@ class Trace(Node):
|
|
|
79
79
|
self,
|
|
80
80
|
info: Info[Context, None],
|
|
81
81
|
) -> datetime:
|
|
82
|
-
if self.
|
|
83
|
-
end_time = self.
|
|
82
|
+
if self.db_record:
|
|
83
|
+
end_time = self.db_record.end_time
|
|
84
84
|
else:
|
|
85
85
|
end_time = await info.context.data_loaders.trace_fields.load(
|
|
86
|
-
(self.
|
|
86
|
+
(self.id, models.Trace.end_time),
|
|
87
87
|
)
|
|
88
88
|
return end_time
|
|
89
89
|
|
|
@@ -92,11 +92,11 @@ class Trace(Node):
|
|
|
92
92
|
self,
|
|
93
93
|
info: Info[Context, None],
|
|
94
94
|
) -> Optional[float]:
|
|
95
|
-
if self.
|
|
96
|
-
latency_ms = self.
|
|
95
|
+
if self.db_record:
|
|
96
|
+
latency_ms = self.db_record.latency_ms
|
|
97
97
|
else:
|
|
98
98
|
latency_ms = await info.context.data_loaders.trace_fields.load(
|
|
99
|
-
(self.
|
|
99
|
+
(self.id, models.Trace.latency_ms),
|
|
100
100
|
)
|
|
101
101
|
return latency_ms
|
|
102
102
|
|
|
@@ -105,26 +105,26 @@ class Trace(Node):
|
|
|
105
105
|
self,
|
|
106
106
|
info: Info[Context, None],
|
|
107
107
|
) -> Annotated["Project", strawberry.lazy(".Project")]:
|
|
108
|
-
if self.
|
|
109
|
-
project_rowid = self.
|
|
108
|
+
if self.db_record:
|
|
109
|
+
project_rowid = self.db_record.project_rowid
|
|
110
110
|
else:
|
|
111
111
|
project_rowid = await info.context.data_loaders.trace_fields.load(
|
|
112
|
-
(self.
|
|
112
|
+
(self.id, models.Trace.project_rowid),
|
|
113
113
|
)
|
|
114
114
|
from phoenix.server.api.types.Project import Project
|
|
115
115
|
|
|
116
|
-
return Project(
|
|
116
|
+
return Project(id=project_rowid)
|
|
117
117
|
|
|
118
118
|
@strawberry.field
|
|
119
119
|
async def project_id(
|
|
120
120
|
self,
|
|
121
121
|
info: Info[Context, None],
|
|
122
122
|
) -> GlobalID:
|
|
123
|
-
if self.
|
|
124
|
-
project_rowid = self.
|
|
123
|
+
if self.db_record:
|
|
124
|
+
project_rowid = self.db_record.project_rowid
|
|
125
125
|
else:
|
|
126
126
|
project_rowid = await info.context.data_loaders.trace_fields.load(
|
|
127
|
-
(self.
|
|
127
|
+
(self.id, models.Trace.project_rowid),
|
|
128
128
|
)
|
|
129
129
|
from phoenix.server.api.types.Project import Project
|
|
130
130
|
|
|
@@ -135,11 +135,11 @@ class Trace(Node):
|
|
|
135
135
|
self,
|
|
136
136
|
info: Info[Context, None],
|
|
137
137
|
) -> Optional[GlobalID]:
|
|
138
|
-
if self.
|
|
139
|
-
project_session_rowid = self.
|
|
138
|
+
if self.db_record:
|
|
139
|
+
project_session_rowid = self.db_record.project_session_rowid
|
|
140
140
|
else:
|
|
141
141
|
project_session_rowid = await info.context.data_loaders.trace_fields.load(
|
|
142
|
-
(self.
|
|
142
|
+
(self.id, models.Trace.project_session_rowid),
|
|
143
143
|
)
|
|
144
144
|
if project_session_rowid is None:
|
|
145
145
|
return None
|
|
@@ -152,39 +152,40 @@ class Trace(Node):
|
|
|
152
152
|
self,
|
|
153
153
|
info: Info[Context, None],
|
|
154
154
|
) -> Union[Annotated["ProjectSession", lazy(".ProjectSession")], None]:
|
|
155
|
-
if self.
|
|
156
|
-
project_session_rowid = self.
|
|
155
|
+
if self.db_record:
|
|
156
|
+
project_session_rowid = self.db_record.project_session_rowid
|
|
157
157
|
else:
|
|
158
158
|
project_session_rowid = await info.context.data_loaders.trace_fields.load(
|
|
159
|
-
(self.
|
|
159
|
+
(self.id, models.Trace.project_session_rowid),
|
|
160
160
|
)
|
|
161
161
|
if project_session_rowid is None:
|
|
162
162
|
return None
|
|
163
|
-
from phoenix.server.api.types.ProjectSession import to_gql_project_session
|
|
164
163
|
|
|
165
164
|
stmt = select(models.ProjectSession).filter_by(id=project_session_rowid)
|
|
166
165
|
async with info.context.db() as session:
|
|
167
166
|
project_session = await session.scalar(stmt)
|
|
168
167
|
if project_session is None:
|
|
169
168
|
return None
|
|
170
|
-
|
|
169
|
+
from .ProjectSession import ProjectSession
|
|
170
|
+
|
|
171
|
+
return ProjectSession(id=project_session.id, db_record=project_session)
|
|
171
172
|
|
|
172
173
|
@strawberry.field
|
|
173
174
|
async def root_span(
|
|
174
175
|
self,
|
|
175
176
|
info: Info[Context, None],
|
|
176
177
|
) -> Optional[Span]:
|
|
177
|
-
span_rowid = await info.context.data_loaders.trace_root_spans.load(self.
|
|
178
|
+
span_rowid = await info.context.data_loaders.trace_root_spans.load(self.id)
|
|
178
179
|
if span_rowid is None:
|
|
179
180
|
return None
|
|
180
|
-
return Span(
|
|
181
|
+
return Span(id=span_rowid)
|
|
181
182
|
|
|
182
183
|
@strawberry.field
|
|
183
184
|
async def num_spans(
|
|
184
185
|
self,
|
|
185
186
|
info: Info[Context, None],
|
|
186
187
|
) -> int:
|
|
187
|
-
return await info.context.data_loaders.num_spans_per_trace.load(self.
|
|
188
|
+
return await info.context.data_loaders.num_spans_per_trace.load(self.id)
|
|
188
189
|
|
|
189
190
|
@strawberry.field
|
|
190
191
|
async def spans(
|
|
@@ -204,7 +205,7 @@ class Trace(Node):
|
|
|
204
205
|
stmt = (
|
|
205
206
|
select(models.Span.id)
|
|
206
207
|
.join(models.Trace)
|
|
207
|
-
.where(models.Trace.id == self.
|
|
208
|
+
.where(models.Trace.id == self.id)
|
|
208
209
|
# Sort descending because the root span tends to show up later
|
|
209
210
|
# in the ingestion process.
|
|
210
211
|
.order_by(desc(models.Span.id))
|
|
@@ -212,7 +213,7 @@ class Trace(Node):
|
|
|
212
213
|
)
|
|
213
214
|
async with info.context.db() as session:
|
|
214
215
|
span_rowids = await session.stream_scalars(stmt)
|
|
215
|
-
data = [Span(
|
|
216
|
+
data = [Span(id=span_rowid) async for span_rowid in span_rowids]
|
|
216
217
|
return connection_from_list(data=data, args=args)
|
|
217
218
|
|
|
218
219
|
@strawberry.field(description="Annotations associated with the trace.") # type: ignore
|
|
@@ -222,7 +223,7 @@ class Trace(Node):
|
|
|
222
223
|
sort: Optional[TraceAnnotationSort] = None,
|
|
223
224
|
) -> list[TraceAnnotation]:
|
|
224
225
|
async with info.context.db() as session:
|
|
225
|
-
stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.
|
|
226
|
+
stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.id)
|
|
226
227
|
if sort:
|
|
227
228
|
sort_col = getattr(models.TraceAnnotation, sort.col.value)
|
|
228
229
|
if sort.dir is SortDir.desc:
|
|
@@ -232,7 +233,9 @@ class Trace(Node):
|
|
|
232
233
|
else:
|
|
233
234
|
stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
|
|
234
235
|
annotations = await session.scalars(stmt)
|
|
235
|
-
return [
|
|
236
|
+
return [
|
|
237
|
+
TraceAnnotation(id=annotation.id, db_record=annotation) for annotation in annotations
|
|
238
|
+
]
|
|
236
239
|
|
|
237
240
|
@strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
|
|
238
241
|
async def trace_annotation_summaries(
|
|
@@ -257,9 +260,7 @@ class Trace(Node):
|
|
|
257
260
|
- data: A list of dictionaries with label statistics
|
|
258
261
|
"""
|
|
259
262
|
# Load all annotations for this span from the data loader
|
|
260
|
-
annotations = await info.context.data_loaders.trace_annotations_by_trace.load(
|
|
261
|
-
self.trace_rowid
|
|
262
|
-
)
|
|
263
|
+
annotations = await info.context.data_loaders.trace_annotations_by_trace.load(self.id)
|
|
263
264
|
|
|
264
265
|
# Apply filter if provided to narrow down the annotations
|
|
265
266
|
if filter:
|
|
@@ -296,7 +297,7 @@ class Trace(Node):
|
|
|
296
297
|
info: Info[Context, None],
|
|
297
298
|
) -> SpanCostSummary:
|
|
298
299
|
loader = info.context.data_loaders.span_cost_summary_by_trace
|
|
299
|
-
summary = await loader.load(self.
|
|
300
|
+
summary = await loader.load(self.id)
|
|
300
301
|
return SpanCostSummary(
|
|
301
302
|
prompt=CostBreakdown(
|
|
302
303
|
tokens=summary.prompt.tokens,
|
|
@@ -318,7 +319,7 @@ class Trace(Node):
|
|
|
318
319
|
info: Info[Context, None],
|
|
319
320
|
) -> list[SpanCostDetailSummaryEntry]:
|
|
320
321
|
loader = info.context.data_loaders.span_cost_detail_summary_entries_by_trace
|
|
321
|
-
entries = await loader.load(self.
|
|
322
|
+
entries = await loader.load(self.id)
|
|
322
323
|
return [
|
|
323
324
|
SpanCostDetailSummaryEntry(
|
|
324
325
|
token_type=entry.token_type,
|