arize-phoenix 12.7.1__py3-none-any.whl → 12.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
- phoenix/config.py +131 -9
- phoenix/db/engines.py +127 -14
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/pg_config.py +10 -0
- phoenix/server/api/context.py +23 -0
- phoenix/server/api/dataloaders/__init__.py +6 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/helpers/playground_clients.py +3 -3
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
- phoenix/server/api/mutations/api_key_mutations.py +2 -15
- phoenix/server/api/mutations/chat_mutations.py +3 -2
- phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
- phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
- phoenix/server/api/mutations/prompt_mutations.py +24 -117
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
- phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/queries.py +80 -213
- phoenix/server/api/subscriptions.py +4 -4
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/Dataset.py +88 -48
- phoenix/server/api/types/DatasetExample.py +34 -30
- phoenix/server/api/types/DatasetLabel.py +47 -13
- phoenix/server/api/types/DatasetSplit.py +87 -21
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +182 -62
- phoenix/server/api/types/Experiment.py +146 -55
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
- phoenix/server/api/types/ExperimentRun.py +118 -61
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +70 -75
- phoenix/server/api/types/ProjectSession.py +69 -37
- phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +82 -44
- phoenix/server/api/types/PromptLabel.py +47 -13
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +116 -115
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +45 -44
- phoenix/server/api/types/TraceAnnotation.py +144 -48
- phoenix/server/api/types/User.py +103 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/app.py +29 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
- phoenix/server/static/.vite/manifest.json +43 -43
- phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
- phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
- phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
- phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
- phoenix/version.py +1 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
|
+
from math import isfinite
|
|
1
2
|
from typing import TYPE_CHECKING, Annotated, Optional
|
|
2
3
|
|
|
3
4
|
import strawberry
|
|
4
|
-
from strawberry import Private
|
|
5
5
|
from strawberry.relay import Node, NodeID
|
|
6
6
|
from strawberry.scalars import JSON
|
|
7
7
|
from strawberry.types import Info
|
|
@@ -11,61 +11,157 @@ from phoenix.server.api.context import Context
|
|
|
11
11
|
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
12
12
|
|
|
13
13
|
from .AnnotationSource import AnnotationSource
|
|
14
|
-
from .User import User, to_gql_user
|
|
15
14
|
|
|
16
15
|
if TYPE_CHECKING:
|
|
17
16
|
from .Trace import Trace
|
|
17
|
+
from .User import User
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@strawberry.type
|
|
21
21
|
class TraceAnnotation(Node):
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
22
|
+
id: NodeID[int]
|
|
23
|
+
db_record: strawberry.Private[Optional[models.TraceAnnotation]] = None
|
|
24
|
+
|
|
25
|
+
def __post_init__(self) -> None:
|
|
26
|
+
if self.db_record and self.id != self.db_record.id:
|
|
27
|
+
raise ValueError("TraceAnnotation ID mismatch")
|
|
28
|
+
|
|
29
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
30
|
+
async def name(
|
|
31
|
+
self,
|
|
32
|
+
info: Info[Context, None],
|
|
33
|
+
) -> str:
|
|
34
|
+
if self.db_record:
|
|
35
|
+
val = self.db_record.name
|
|
36
|
+
else:
|
|
37
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
38
|
+
(self.id, models.TraceAnnotation.name),
|
|
39
|
+
)
|
|
40
|
+
return val
|
|
41
|
+
|
|
42
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
43
|
+
async def annotator_kind(
|
|
44
|
+
self,
|
|
45
|
+
info: Info[Context, None],
|
|
46
|
+
) -> AnnotatorKind:
|
|
47
|
+
if self.db_record:
|
|
48
|
+
val = self.db_record.annotator_kind
|
|
49
|
+
else:
|
|
50
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
51
|
+
(self.id, models.TraceAnnotation.annotator_kind),
|
|
52
|
+
)
|
|
53
|
+
return AnnotatorKind(val)
|
|
54
|
+
|
|
55
|
+
@strawberry.field(
|
|
56
|
+
description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
|
|
57
|
+
) # type: ignore
|
|
58
|
+
async def label(
|
|
59
|
+
self,
|
|
60
|
+
info: Info[Context, None],
|
|
61
|
+
) -> Optional[str]:
|
|
62
|
+
if self.db_record:
|
|
63
|
+
val = self.db_record.label
|
|
64
|
+
else:
|
|
65
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
66
|
+
(self.id, models.TraceAnnotation.label),
|
|
67
|
+
)
|
|
68
|
+
return val
|
|
69
|
+
|
|
70
|
+
@strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
|
|
71
|
+
async def score(
|
|
72
|
+
self,
|
|
73
|
+
info: Info[Context, None],
|
|
74
|
+
) -> Optional[float]:
|
|
75
|
+
if self.db_record:
|
|
76
|
+
val = self.db_record.score
|
|
77
|
+
else:
|
|
78
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
79
|
+
(self.id, models.TraceAnnotation.score),
|
|
80
|
+
)
|
|
81
|
+
return val if val is not None and isfinite(val) else None
|
|
82
|
+
|
|
83
|
+
@strawberry.field(
|
|
84
|
+
description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
|
|
85
|
+
) # type: ignore
|
|
86
|
+
async def explanation(
|
|
87
|
+
self,
|
|
88
|
+
info: Info[Context, None],
|
|
89
|
+
) -> Optional[str]:
|
|
90
|
+
if self.db_record:
|
|
91
|
+
val = self.db_record.explanation
|
|
92
|
+
else:
|
|
93
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
94
|
+
(self.id, models.TraceAnnotation.explanation),
|
|
95
|
+
)
|
|
96
|
+
return val
|
|
97
|
+
|
|
98
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
99
|
+
async def metadata(
|
|
100
|
+
self,
|
|
101
|
+
info: Info[Context, None],
|
|
102
|
+
) -> JSON:
|
|
103
|
+
if self.db_record:
|
|
104
|
+
val = self.db_record.metadata_
|
|
105
|
+
else:
|
|
106
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
107
|
+
(self.id, models.TraceAnnotation.metadata_),
|
|
108
|
+
)
|
|
109
|
+
return val
|
|
110
|
+
|
|
111
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
112
|
+
async def identifier(
|
|
113
|
+
self,
|
|
114
|
+
info: Info[Context, None],
|
|
115
|
+
) -> str:
|
|
116
|
+
if self.db_record:
|
|
117
|
+
val = self.db_record.identifier
|
|
118
|
+
else:
|
|
119
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
120
|
+
(self.id, models.TraceAnnotation.identifier),
|
|
121
|
+
)
|
|
122
|
+
return val
|
|
123
|
+
|
|
124
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
125
|
+
async def source(
|
|
126
|
+
self,
|
|
127
|
+
info: Info[Context, None],
|
|
128
|
+
) -> AnnotationSource:
|
|
129
|
+
if self.db_record:
|
|
130
|
+
val = self.db_record.source
|
|
131
|
+
else:
|
|
132
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
133
|
+
(self.id, models.TraceAnnotation.source),
|
|
134
|
+
)
|
|
135
|
+
return AnnotationSource(val)
|
|
136
|
+
|
|
137
|
+
@strawberry.field(description="The trace associated with the annotation.") # type: ignore
|
|
138
|
+
async def trace(
|
|
139
|
+
self,
|
|
140
|
+
info: Info[Context, None],
|
|
141
|
+
) -> Annotated["Trace", strawberry.lazy(".Trace")]:
|
|
142
|
+
if self.db_record:
|
|
143
|
+
trace_rowid = self.db_record.trace_rowid
|
|
144
|
+
else:
|
|
145
|
+
trace_rowid = await info.context.data_loaders.trace_annotation_fields.load(
|
|
146
|
+
(self.id, models.TraceAnnotation.trace_rowid),
|
|
147
|
+
)
|
|
148
|
+
from .Trace import Trace
|
|
149
|
+
|
|
150
|
+
return Trace(id=trace_rowid)
|
|
151
|
+
|
|
152
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
41
153
|
async def user(
|
|
42
154
|
self,
|
|
43
155
|
info: Info[Context, None],
|
|
44
|
-
) -> Optional[User]:
|
|
45
|
-
if self.
|
|
156
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
157
|
+
if self.db_record:
|
|
158
|
+
user_id = self.db_record.user_id
|
|
159
|
+
else:
|
|
160
|
+
user_id = await info.context.data_loaders.trace_annotation_fields.load(
|
|
161
|
+
(self.id, models.TraceAnnotation.user_id),
|
|
162
|
+
)
|
|
163
|
+
if user_id is None:
|
|
46
164
|
return None
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
return to_gql_user(user)
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def to_gql_trace_annotation(
|
|
54
|
-
annotation: models.TraceAnnotation,
|
|
55
|
-
) -> TraceAnnotation:
|
|
56
|
-
"""
|
|
57
|
-
Converts an ORM trace annotation to a GraphQL TraceAnnotation.
|
|
58
|
-
"""
|
|
59
|
-
return TraceAnnotation(
|
|
60
|
-
id_attr=annotation.id,
|
|
61
|
-
user_id=annotation.user_id,
|
|
62
|
-
trace_rowid=annotation.trace_rowid,
|
|
63
|
-
name=annotation.name,
|
|
64
|
-
annotator_kind=AnnotatorKind(annotation.annotator_kind),
|
|
65
|
-
label=annotation.label,
|
|
66
|
-
score=annotation.score,
|
|
67
|
-
explanation=annotation.explanation,
|
|
68
|
-
metadata=annotation.metadata_,
|
|
69
|
-
identifier=annotation.identifier,
|
|
70
|
-
source=AnnotationSource(annotation.source),
|
|
71
|
-
)
|
|
165
|
+
from .User import User
|
|
166
|
+
|
|
167
|
+
return User(id=user_id)
|
phoenix/server/api/types/User.py
CHANGED
|
@@ -3,7 +3,6 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from sqlalchemy import select
|
|
6
|
-
from strawberry import Private
|
|
7
6
|
from strawberry.relay import Node, NodeID
|
|
8
7
|
from strawberry.types import Info
|
|
9
8
|
|
|
@@ -12,59 +11,130 @@ from phoenix.db import models
|
|
|
12
11
|
from phoenix.server.api.context import Context
|
|
13
12
|
from phoenix.server.api.exceptions import NotFound
|
|
14
13
|
from phoenix.server.api.types.AuthMethod import AuthMethod
|
|
15
|
-
from phoenix.server.api.types.UserApiKey import UserApiKey
|
|
14
|
+
from phoenix.server.api.types.UserApiKey import UserApiKey
|
|
16
15
|
|
|
17
16
|
from .UserRole import UserRole, to_gql_user_role
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
@strawberry.type
|
|
21
20
|
class User(Node):
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
21
|
+
id: NodeID[int]
|
|
22
|
+
db_record: strawberry.Private[Optional[models.User]] = None
|
|
23
|
+
|
|
24
|
+
def __post_init__(self) -> None:
|
|
25
|
+
if self.db_record and self.id != self.db_record.id:
|
|
26
|
+
raise ValueError("User ID mismatch")
|
|
27
|
+
|
|
28
|
+
@strawberry.field
|
|
29
|
+
async def password_needs_reset(
|
|
30
|
+
self,
|
|
31
|
+
info: Info[Context, None],
|
|
32
|
+
) -> bool:
|
|
33
|
+
if self.db_record:
|
|
34
|
+
val = self.db_record.reset_password
|
|
35
|
+
else:
|
|
36
|
+
val = await info.context.data_loaders.user_fields.load(
|
|
37
|
+
(self.id, models.User.reset_password),
|
|
38
|
+
)
|
|
39
|
+
return val
|
|
40
|
+
|
|
41
|
+
@strawberry.field
|
|
42
|
+
async def email(
|
|
43
|
+
self,
|
|
44
|
+
info: Info[Context, None],
|
|
45
|
+
) -> str:
|
|
46
|
+
if self.db_record:
|
|
47
|
+
val = self.db_record.email
|
|
48
|
+
else:
|
|
49
|
+
val = await info.context.data_loaders.user_fields.load(
|
|
50
|
+
(self.id, models.User.email),
|
|
51
|
+
)
|
|
52
|
+
return val
|
|
53
|
+
|
|
54
|
+
@strawberry.field
|
|
55
|
+
async def username(
|
|
56
|
+
self,
|
|
57
|
+
info: Info[Context, None],
|
|
58
|
+
) -> str:
|
|
59
|
+
if self.db_record:
|
|
60
|
+
val = self.db_record.username
|
|
61
|
+
else:
|
|
62
|
+
val = await info.context.data_loaders.user_fields.load(
|
|
63
|
+
(self.id, models.User.username),
|
|
64
|
+
)
|
|
65
|
+
return val
|
|
66
|
+
|
|
67
|
+
@strawberry.field
|
|
68
|
+
async def profile_picture_url(
|
|
69
|
+
self,
|
|
70
|
+
info: Info[Context, None],
|
|
71
|
+
) -> Optional[str]:
|
|
72
|
+
if self.db_record:
|
|
73
|
+
val = self.db_record.profile_picture_url
|
|
74
|
+
else:
|
|
75
|
+
val = await info.context.data_loaders.user_fields.load(
|
|
76
|
+
(self.id, models.User.profile_picture_url),
|
|
77
|
+
)
|
|
78
|
+
return val
|
|
79
|
+
|
|
80
|
+
@strawberry.field
|
|
81
|
+
async def created_at(
|
|
82
|
+
self,
|
|
83
|
+
info: Info[Context, None],
|
|
84
|
+
) -> datetime:
|
|
85
|
+
if self.db_record:
|
|
86
|
+
val = self.db_record.created_at
|
|
87
|
+
else:
|
|
88
|
+
val = await info.context.data_loaders.user_fields.load(
|
|
89
|
+
(self.id, models.User.created_at),
|
|
90
|
+
)
|
|
91
|
+
return val
|
|
92
|
+
|
|
93
|
+
@strawberry.field
|
|
94
|
+
async def auth_method(
|
|
95
|
+
self,
|
|
96
|
+
info: Info[Context, None],
|
|
97
|
+
) -> AuthMethod:
|
|
98
|
+
if self.db_record:
|
|
99
|
+
val = self.db_record.auth_method
|
|
100
|
+
else:
|
|
101
|
+
val = await info.context.data_loaders.user_fields.load(
|
|
102
|
+
(self.id, models.User.auth_method),
|
|
103
|
+
)
|
|
104
|
+
return AuthMethod(val)
|
|
30
105
|
|
|
31
106
|
@strawberry.field
|
|
32
107
|
async def role(self, info: Info[Context, None]) -> UserRole:
|
|
33
|
-
|
|
108
|
+
if self.db_record:
|
|
109
|
+
user_role_id = self.db_record.user_role_id
|
|
110
|
+
else:
|
|
111
|
+
user_role_id = await info.context.data_loaders.user_fields.load(
|
|
112
|
+
(self.id, models.User.user_role_id),
|
|
113
|
+
)
|
|
114
|
+
role = await info.context.data_loaders.user_roles.load(user_role_id)
|
|
34
115
|
if role is None:
|
|
35
|
-
raise NotFound(f"User role with id {
|
|
116
|
+
raise NotFound(f"User role with id {user_role_id} not found")
|
|
36
117
|
return to_gql_user_role(role)
|
|
37
118
|
|
|
38
119
|
@strawberry.field
|
|
39
120
|
async def api_keys(self, info: Info[Context, None]) -> list[UserApiKey]:
|
|
40
121
|
async with info.context.db() as session:
|
|
41
122
|
api_keys = await session.scalars(
|
|
42
|
-
select(models.ApiKey).where(models.ApiKey.user_id == self.
|
|
123
|
+
select(models.ApiKey).where(models.ApiKey.user_id == self.id)
|
|
43
124
|
)
|
|
44
|
-
return [
|
|
125
|
+
return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
|
|
45
126
|
|
|
46
127
|
@strawberry.field
|
|
47
|
-
async def is_management_user(self) -> bool:
|
|
128
|
+
async def is_management_user(self, info: Info[Context, None]) -> bool:
|
|
48
129
|
initial_admins = get_env_admins()
|
|
49
130
|
# this field is only visible to initial admins as they are the ones likely to have access to
|
|
50
131
|
# a management interface / the phoenix environment.
|
|
51
|
-
if self.
|
|
132
|
+
if self.db_record:
|
|
133
|
+
email = self.db_record.email
|
|
134
|
+
else:
|
|
135
|
+
email = await info.context.data_loaders.user_fields.load(
|
|
136
|
+
(self.id, models.User.email),
|
|
137
|
+
)
|
|
138
|
+
if email in initial_admins or email == "admin@localhost":
|
|
52
139
|
return True
|
|
53
140
|
return False
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def to_gql_user(user: models.User, api_keys: Optional[list[models.ApiKey]] = None) -> User:
|
|
57
|
-
"""
|
|
58
|
-
Converts an ORM user to a GraphQL user.
|
|
59
|
-
"""
|
|
60
|
-
assert user.auth_method is not None
|
|
61
|
-
return User(
|
|
62
|
-
id_attr=user.id,
|
|
63
|
-
password_needs_reset=user.reset_password,
|
|
64
|
-
username=user.username,
|
|
65
|
-
email=user.email,
|
|
66
|
-
profile_picture_url=user.profile_picture_url,
|
|
67
|
-
created_at=user.created_at,
|
|
68
|
-
user_role_id=user.user_role_id,
|
|
69
|
-
auth_method=AuthMethod(user.auth_method),
|
|
70
|
-
)
|
|
@@ -1,14 +1,13 @@
|
|
|
1
|
-
from
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
2
3
|
|
|
3
4
|
import strawberry
|
|
4
|
-
from strawberry import Private
|
|
5
5
|
from strawberry.relay import Node, NodeID
|
|
6
6
|
from strawberry.types import Info
|
|
7
7
|
from typing_extensions import Annotated
|
|
8
8
|
|
|
9
9
|
from phoenix.db.models import ApiKey as OrmApiKey
|
|
10
10
|
from phoenix.server.api.context import Context
|
|
11
|
-
from phoenix.server.api.exceptions import NotFound
|
|
12
11
|
|
|
13
12
|
from .ApiKey import ApiKey
|
|
14
13
|
|
|
@@ -18,28 +17,76 @@ if TYPE_CHECKING:
|
|
|
18
17
|
|
|
19
18
|
@strawberry.type
|
|
20
19
|
class UserApiKey(ApiKey, Node):
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
id: NodeID[int]
|
|
21
|
+
db_record: strawberry.Private[Optional[OrmApiKey]] = None
|
|
22
|
+
|
|
23
|
+
def __post_init__(self) -> None:
|
|
24
|
+
if self.db_record and self.id != self.db_record.id:
|
|
25
|
+
raise ValueError("UserApiKey ID mismatch")
|
|
26
|
+
|
|
27
|
+
@strawberry.field(description="Name of the API key.") # type: ignore
|
|
28
|
+
async def name(
|
|
29
|
+
self,
|
|
30
|
+
info: Info[Context, None],
|
|
31
|
+
) -> str:
|
|
32
|
+
if self.db_record:
|
|
33
|
+
val = self.db_record.name
|
|
34
|
+
else:
|
|
35
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
36
|
+
(self.id, OrmApiKey.name),
|
|
37
|
+
)
|
|
38
|
+
return val
|
|
39
|
+
|
|
40
|
+
@strawberry.field(description="Description of the API key.") # type: ignore
|
|
41
|
+
async def description(
|
|
42
|
+
self,
|
|
43
|
+
info: Info[Context, None],
|
|
44
|
+
) -> Optional[str]:
|
|
45
|
+
if self.db_record:
|
|
46
|
+
val = self.db_record.description
|
|
47
|
+
else:
|
|
48
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
49
|
+
(self.id, OrmApiKey.description),
|
|
50
|
+
)
|
|
51
|
+
return val
|
|
52
|
+
|
|
53
|
+
@strawberry.field(description="The date and time the API key was created.") # type: ignore
|
|
54
|
+
async def created_at(
|
|
55
|
+
self,
|
|
56
|
+
info: Info[Context, None],
|
|
57
|
+
) -> datetime:
|
|
58
|
+
if self.db_record:
|
|
59
|
+
val = self.db_record.created_at
|
|
60
|
+
else:
|
|
61
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
62
|
+
(self.id, OrmApiKey.created_at),
|
|
63
|
+
)
|
|
64
|
+
return val
|
|
65
|
+
|
|
66
|
+
@strawberry.field(description="The date and time the API key will expire.") # type: ignore
|
|
67
|
+
async def expires_at(
|
|
68
|
+
self,
|
|
69
|
+
info: Info[Context, None],
|
|
70
|
+
) -> Optional[datetime]:
|
|
71
|
+
if self.db_record:
|
|
72
|
+
val = self.db_record.expires_at
|
|
73
|
+
else:
|
|
74
|
+
val = await info.context.data_loaders.user_api_key_fields.load(
|
|
75
|
+
(self.id, OrmApiKey.expires_at),
|
|
76
|
+
)
|
|
77
|
+
return val
|
|
23
78
|
|
|
24
79
|
@strawberry.field
|
|
25
|
-
async def user(
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
return UserApiKey(
|
|
39
|
-
id_attr=api_key.id,
|
|
40
|
-
user_id=api_key.user_id,
|
|
41
|
-
name=api_key.name,
|
|
42
|
-
description=api_key.description,
|
|
43
|
-
created_at=api_key.created_at,
|
|
44
|
-
expires_at=api_key.expires_at,
|
|
45
|
-
)
|
|
80
|
+
async def user(
|
|
81
|
+
self,
|
|
82
|
+
info: Info[Context, None],
|
|
83
|
+
) -> Annotated["User", strawberry.lazy(".User")]:
|
|
84
|
+
if self.db_record:
|
|
85
|
+
user_id = self.db_record.user_id
|
|
86
|
+
else:
|
|
87
|
+
user_id = await info.context.data_loaders.user_api_key_fields.load(
|
|
88
|
+
(self.id, OrmApiKey.user_id),
|
|
89
|
+
)
|
|
90
|
+
from .User import User
|
|
91
|
+
|
|
92
|
+
return User(id=user_id)
|
phoenix/server/app.py
CHANGED
|
@@ -103,6 +103,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
103
103
|
ExperimentRepeatedRunGroupsDataLoader,
|
|
104
104
|
ExperimentRunAnnotations,
|
|
105
105
|
ExperimentRunCountsDataLoader,
|
|
106
|
+
ExperimentRunsByExperimentAndExampleDataLoader,
|
|
106
107
|
ExperimentSequenceNumberDataLoader,
|
|
107
108
|
LastUsedTimesByGenerativeModelIdDataLoader,
|
|
108
109
|
LatencyMsQuantileDataLoader,
|
|
@@ -139,6 +140,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
139
140
|
SpanProjectsDataLoader,
|
|
140
141
|
TableFieldsDataLoader,
|
|
141
142
|
TokenCountDataLoader,
|
|
143
|
+
TokenPricesByModelDataLoader,
|
|
142
144
|
TraceAnnotationsByTraceDataLoader,
|
|
143
145
|
TraceByTraceIdsDataLoader,
|
|
144
146
|
TraceRetentionPolicyIdByProjectIdDataLoader,
|
|
@@ -712,13 +714,18 @@ def create_graphql_router(
|
|
|
712
714
|
),
|
|
713
715
|
average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
|
|
714
716
|
dataset_dataset_splits=DatasetDatasetSplitsDataLoader(db),
|
|
717
|
+
dataset_example_fields=TableFieldsDataLoader(db, models.DatasetExample),
|
|
715
718
|
dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
|
|
716
719
|
dataset_example_spans=DatasetExampleSpansDataLoader(db),
|
|
717
720
|
dataset_examples_and_versions_by_experiment_run=DatasetExamplesAndVersionsByExperimentRunDataLoader(
|
|
718
721
|
db
|
|
719
722
|
),
|
|
720
723
|
dataset_example_splits=DatasetExampleSplitsDataLoader(db),
|
|
724
|
+
dataset_fields=TableFieldsDataLoader(db, models.Dataset),
|
|
725
|
+
dataset_split_fields=TableFieldsDataLoader(db, models.DatasetSplit),
|
|
726
|
+
dataset_version_fields=TableFieldsDataLoader(db, models.DatasetVersion),
|
|
721
727
|
dataset_labels=DatasetLabelsDataLoader(db),
|
|
728
|
+
dataset_label_fields=TableFieldsDataLoader(db, models.DatasetLabel),
|
|
722
729
|
document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
|
|
723
730
|
db,
|
|
724
731
|
cache_map=(
|
|
@@ -727,6 +734,7 @@ def create_graphql_router(
|
|
|
727
734
|
else None
|
|
728
735
|
),
|
|
729
736
|
),
|
|
737
|
+
document_annotation_fields=TableFieldsDataLoader(db, models.DocumentAnnotation),
|
|
730
738
|
document_evaluations=DocumentEvaluationsDataLoader(db),
|
|
731
739
|
document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
|
|
732
740
|
annotation_summaries=AnnotationSummaryDataLoader(
|
|
@@ -738,13 +746,22 @@ def create_graphql_router(
|
|
|
738
746
|
experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(db),
|
|
739
747
|
experiment_dataset_splits=ExperimentDatasetSplitsDataLoader(db),
|
|
740
748
|
experiment_error_rates=ExperimentErrorRatesDataLoader(db),
|
|
749
|
+
experiment_fields=TableFieldsDataLoader(db, models.Experiment),
|
|
741
750
|
experiment_repeated_run_group_annotation_summaries=ExperimentRepeatedRunGroupAnnotationSummariesDataLoader(
|
|
742
751
|
db
|
|
743
752
|
),
|
|
744
753
|
experiment_repeated_run_groups=ExperimentRepeatedRunGroupsDataLoader(db),
|
|
754
|
+
experiment_run_annotation_fields=TableFieldsDataLoader(
|
|
755
|
+
db, models.ExperimentRunAnnotation
|
|
756
|
+
),
|
|
745
757
|
experiment_run_annotations=ExperimentRunAnnotations(db),
|
|
746
758
|
experiment_run_counts=ExperimentRunCountsDataLoader(db),
|
|
759
|
+
experiment_run_fields=TableFieldsDataLoader(db, models.ExperimentRun),
|
|
760
|
+
experiment_runs_by_experiment_and_example=ExperimentRunsByExperimentAndExampleDataLoader(
|
|
761
|
+
db
|
|
762
|
+
),
|
|
747
763
|
experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
|
|
764
|
+
generative_model_fields=TableFieldsDataLoader(db, models.GenerativeModel),
|
|
748
765
|
last_used_times_by_generative_model_id=LastUsedTimesByGenerativeModelIdDataLoader(
|
|
749
766
|
db
|
|
750
767
|
),
|
|
@@ -768,7 +785,14 @@ def create_graphql_router(
|
|
|
768
785
|
projects_by_trace_retention_policy_id=ProjectIdsByTraceRetentionPolicyIdDataLoader(
|
|
769
786
|
db
|
|
770
787
|
),
|
|
788
|
+
prompt_fields=TableFieldsDataLoader(db, models.Prompt),
|
|
789
|
+
prompt_label_fields=TableFieldsDataLoader(db, models.PromptLabel),
|
|
771
790
|
prompt_version_sequence_number=PromptVersionSequenceNumberDataLoader(db),
|
|
791
|
+
prompt_version_tag_fields=TableFieldsDataLoader(db, models.PromptVersionTag),
|
|
792
|
+
project_session_annotation_fields=TableFieldsDataLoader(
|
|
793
|
+
db, models.ProjectSessionAnnotation
|
|
794
|
+
),
|
|
795
|
+
project_session_fields=TableFieldsDataLoader(db, models.ProjectSession),
|
|
772
796
|
record_counts=RecordCountDataLoader(
|
|
773
797
|
db,
|
|
774
798
|
cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
|
|
@@ -780,6 +804,7 @@ def create_graphql_router(
|
|
|
780
804
|
session_num_traces_with_error=SessionNumTracesWithErrorDataLoader(db),
|
|
781
805
|
session_token_usages=SessionTokenUsagesDataLoader(db),
|
|
782
806
|
session_trace_latency_ms_quantile=SessionTraceLatencyMsQuantileDataLoader(db),
|
|
807
|
+
span_annotation_fields=TableFieldsDataLoader(db, models.SpanAnnotation),
|
|
783
808
|
span_annotations=SpanAnnotationsDataLoader(db),
|
|
784
809
|
span_fields=TableFieldsDataLoader(db, models.Span),
|
|
785
810
|
span_by_id=SpanByIdDataLoader(db),
|
|
@@ -820,6 +845,8 @@ def create_graphql_router(
|
|
|
820
845
|
db,
|
|
821
846
|
cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
|
|
822
847
|
),
|
|
848
|
+
token_prices_by_model=TokenPricesByModelDataLoader(db),
|
|
849
|
+
trace_annotation_fields=TableFieldsDataLoader(db, models.TraceAnnotation),
|
|
823
850
|
trace_annotations_by_trace=TraceAnnotationsByTraceDataLoader(db),
|
|
824
851
|
trace_by_trace_ids=TraceByTraceIdsDataLoader(db),
|
|
825
852
|
trace_fields=TableFieldsDataLoader(db, models.Trace),
|
|
@@ -832,6 +859,8 @@ def create_graphql_router(
|
|
|
832
859
|
trace_root_spans=TraceRootSpansDataLoader(db),
|
|
833
860
|
project_by_name=ProjectByNameDataLoader(db),
|
|
834
861
|
users=UsersDataLoader(db),
|
|
862
|
+
user_api_key_fields=TableFieldsDataLoader(db, models.ApiKey),
|
|
863
|
+
user_fields=TableFieldsDataLoader(db, models.User),
|
|
835
864
|
user_roles=UserRolesDataLoader(db),
|
|
836
865
|
),
|
|
837
866
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
@@ -784,7 +784,7 @@
|
|
|
784
784
|
"token_type": "output"
|
|
785
785
|
},
|
|
786
786
|
{
|
|
787
|
-
"base_rate":
|
|
787
|
+
"base_rate": 3e-8,
|
|
788
788
|
"is_prompt": true,
|
|
789
789
|
"token_type": "cache_read"
|
|
790
790
|
},
|
|
@@ -1027,7 +1027,7 @@
|
|
|
1027
1027
|
"token_type": "output"
|
|
1028
1028
|
},
|
|
1029
1029
|
{
|
|
1030
|
-
"base_rate":
|
|
1030
|
+
"base_rate": 1.25e-7,
|
|
1031
1031
|
"is_prompt": true,
|
|
1032
1032
|
"token_type": "cache_read"
|
|
1033
1033
|
}
|