arize-phoenix 4.35.2__py3-none-any.whl → 5.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/METADATA +10 -12
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/RECORD +92 -79
- phoenix/__init__.py +86 -0
- phoenix/auth.py +275 -14
- phoenix/config.py +369 -27
- phoenix/db/alembic.ini +0 -34
- phoenix/db/engines.py +27 -10
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +112 -0
- phoenix/db/insertion/dataset.py +0 -1
- phoenix/db/insertion/types.py +1 -1
- phoenix/db/migrate.py +3 -3
- phoenix/db/migrations/env.py +0 -7
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/models.py +145 -60
- phoenix/experiments/evaluators/code_evaluators.py +9 -3
- phoenix/experiments/functions.py +1 -4
- phoenix/inferences/fixtures.py +0 -1
- phoenix/inferences/inferences.py +0 -1
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +0 -1
- phoenix/otel/settings.py +4 -4
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +32 -0
- phoenix/server/api/context.py +50 -2
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +7 -0
- phoenix/server/api/mutations/__init__.py +0 -2
- phoenix/server/api/mutations/api_key_mutations.py +104 -86
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/project_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +282 -42
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +48 -39
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +456 -0
- phoenix/server/api/routers/v1/__init__.py +38 -16
- phoenix/server/api/routers/v1/datasets.py +0 -1
- phoenix/server/api/types/ApiKey.py +11 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/User.py +48 -4
- phoenix/server/api/types/UserApiKey.py +35 -1
- phoenix/server/api/types/UserRole.py +7 -0
- phoenix/server/app.py +105 -34
- phoenix/server/bearer_auth.py +161 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +26 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +6 -0
- phoenix/server/jwt_store.py +504 -0
- phoenix/server/main.py +61 -30
- phoenix/server/oauth2.py +51 -0
- phoenix/server/prometheus.py +20 -0
- phoenix/server/rate_limiters.py +191 -0
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
- phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
- phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
- phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
- phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
- phoenix/server/telemetry.py +2 -2
- phoenix/server/templates/index.html +1 -0
- phoenix/server/types.py +157 -1
- phoenix/services.py +0 -1
- phoenix/session/client.py +7 -3
- phoenix/session/evaluation.py +0 -1
- phoenix/session/session.py +0 -1
- phoenix/settings.py +9 -0
- phoenix/trace/exporter.py +0 -1
- phoenix/trace/fixtures.py +0 -2
- phoenix/utilities/client.py +16 -0
- phoenix/utilities/logging.py +9 -1
- phoenix/utilities/re.py +3 -3
- phoenix/version.py +1 -1
- phoenix/db/migrations/future_versions/README.md +0 -4
- phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
- phoenix/db/migrations/versions/.gitignore +0 -1
- phoenix/server/api/mutations/auth.py +0 -18
- phoenix/server/api/mutations/auth_mutations.py +0 -65
- phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -103
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -31
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,27 +1,79 @@
|
|
|
1
|
-
import
|
|
2
|
-
from
|
|
1
|
+
import secrets
|
|
2
|
+
from contextlib import AsyncExitStack
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import List, Literal, Optional, Tuple
|
|
3
5
|
|
|
4
6
|
import strawberry
|
|
5
|
-
from sqlalchemy import
|
|
7
|
+
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
|
|
8
|
+
from sqlalchemy.orm import joinedload
|
|
6
9
|
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
|
|
10
|
+
from strawberry import UNSET
|
|
11
|
+
from strawberry.relay import GlobalID
|
|
7
12
|
from strawberry.types import Info
|
|
8
13
|
|
|
9
|
-
from phoenix.auth import
|
|
10
|
-
|
|
14
|
+
from phoenix.auth import (
|
|
15
|
+
DEFAULT_ADMIN_EMAIL,
|
|
16
|
+
DEFAULT_ADMIN_USERNAME,
|
|
17
|
+
DEFAULT_SECRET_LENGTH,
|
|
18
|
+
PASSWORD_REQUIREMENTS,
|
|
19
|
+
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
|
|
20
|
+
PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
|
|
21
|
+
validate_email_format,
|
|
22
|
+
validate_password_format,
|
|
23
|
+
)
|
|
24
|
+
from phoenix.db import enums, models
|
|
25
|
+
from phoenix.server.api.auth import IsAdmin, IsNotReadOnly
|
|
11
26
|
from phoenix.server.api.context import Context
|
|
27
|
+
from phoenix.server.api.exceptions import Conflict, NotFound, Unauthorized
|
|
12
28
|
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
|
|
13
|
-
from phoenix.server.api.types.
|
|
14
|
-
from phoenix.server.api.types.
|
|
29
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
30
|
+
from phoenix.server.api.types.User import User, to_gql_user
|
|
31
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
32
|
+
from phoenix.server.types import AccessTokenId, ApiKeyId, PasswordResetTokenId, RefreshTokenId
|
|
15
33
|
|
|
16
34
|
|
|
17
35
|
@strawberry.input
|
|
18
36
|
class CreateUserInput:
|
|
19
37
|
email: str
|
|
20
|
-
username:
|
|
38
|
+
username: str
|
|
21
39
|
password: str
|
|
22
40
|
role: UserRoleInput
|
|
23
41
|
|
|
24
42
|
|
|
43
|
+
@strawberry.input
|
|
44
|
+
class PatchViewerInput:
|
|
45
|
+
new_username: Optional[str] = UNSET
|
|
46
|
+
new_password: Optional[str] = UNSET
|
|
47
|
+
current_password: Optional[str] = UNSET
|
|
48
|
+
|
|
49
|
+
def __post_init__(self) -> None:
|
|
50
|
+
if not self.new_username and not self.new_password:
|
|
51
|
+
raise ValueError("At least one field must be set")
|
|
52
|
+
if self.new_password and not self.current_password:
|
|
53
|
+
raise ValueError("current_password is required when modifying password")
|
|
54
|
+
if self.new_password:
|
|
55
|
+
PASSWORD_REQUIREMENTS.validate(self.new_password)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@strawberry.input
|
|
59
|
+
class PatchUserInput:
|
|
60
|
+
user_id: GlobalID
|
|
61
|
+
new_role: Optional[UserRoleInput] = UNSET
|
|
62
|
+
new_username: Optional[str] = UNSET
|
|
63
|
+
new_password: Optional[str] = UNSET
|
|
64
|
+
|
|
65
|
+
def __post_init__(self) -> None:
|
|
66
|
+
if not self.new_role and not self.new_username and not self.new_password:
|
|
67
|
+
raise ValueError("At least one field must be set")
|
|
68
|
+
if self.new_password:
|
|
69
|
+
PASSWORD_REQUIREMENTS.validate(self.new_password)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@strawberry.input
|
|
73
|
+
class DeleteUsersInput:
|
|
74
|
+
user_ids: List[GlobalID]
|
|
75
|
+
|
|
76
|
+
|
|
25
77
|
@strawberry.type
|
|
26
78
|
class UserMutationPayload:
|
|
27
79
|
user: User
|
|
@@ -29,7 +81,7 @@ class UserMutationPayload:
|
|
|
29
81
|
|
|
30
82
|
@strawberry.type
|
|
31
83
|
class UserMutationMixin:
|
|
32
|
-
@strawberry.mutation
|
|
84
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
|
|
33
85
|
async def create_user(
|
|
34
86
|
self,
|
|
35
87
|
info: Info[Context, None],
|
|
@@ -37,47 +89,235 @@ class UserMutationMixin:
|
|
|
37
89
|
) -> UserMutationPayload:
|
|
38
90
|
validate_email_format(email := input.email)
|
|
39
91
|
validate_password_format(password := input.password)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
92
|
+
salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
|
|
93
|
+
password_hash = await info.context.hash_password(password, salt)
|
|
94
|
+
user = models.User(
|
|
95
|
+
reset_password=True,
|
|
96
|
+
username=input.username,
|
|
97
|
+
email=email,
|
|
98
|
+
password_hash=password_hash,
|
|
99
|
+
password_salt=salt,
|
|
43
100
|
)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
101
|
+
async with AsyncExitStack() as stack:
|
|
102
|
+
session = await stack.enter_async_context(info.context.db())
|
|
103
|
+
user_role_id = await session.scalar(_select_role_id_by_name(input.role.value))
|
|
104
|
+
if user_role_id is None:
|
|
105
|
+
raise NotFound(f"Role {input.role.value} not found")
|
|
106
|
+
stack.enter_context(session.no_autoflush)
|
|
107
|
+
user.user_role_id = user_role_id
|
|
108
|
+
session.add(user)
|
|
109
|
+
try:
|
|
110
|
+
await session.flush()
|
|
111
|
+
except IntegrityError as error:
|
|
112
|
+
raise Conflict(_user_operation_error_message(error))
|
|
113
|
+
return UserMutationPayload(user=to_gql_user(user))
|
|
114
|
+
|
|
115
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
|
|
116
|
+
async def patch_user(
|
|
117
|
+
self,
|
|
118
|
+
info: Info[Context, None],
|
|
119
|
+
input: PatchUserInput,
|
|
120
|
+
) -> UserMutationPayload:
|
|
121
|
+
assert (request := info.context.request)
|
|
122
|
+
assert isinstance(request.user, PhoenixUser)
|
|
123
|
+
assert (requester_id := int(request.user.identity))
|
|
124
|
+
user_id = from_global_id_with_expected_type(input.user_id, expected_type_name=User.__name__)
|
|
125
|
+
async with AsyncExitStack() as stack:
|
|
126
|
+
session = await stack.enter_async_context(info.context.db())
|
|
127
|
+
requester = await session.scalar(_select_user_by_id(requester_id))
|
|
128
|
+
assert requester
|
|
129
|
+
if not (user := await session.scalar(_select_user_by_id(user_id))):
|
|
130
|
+
raise NotFound("User not found")
|
|
131
|
+
stack.enter_context(session.no_autoflush)
|
|
132
|
+
if input.new_role:
|
|
133
|
+
if user.email == DEFAULT_ADMIN_EMAIL:
|
|
134
|
+
raise Unauthorized("Cannot modify role for the default admin user")
|
|
135
|
+
user_role_id = await session.scalar(_select_role_id_by_name(input.new_role.value))
|
|
136
|
+
if user_role_id is None:
|
|
137
|
+
raise NotFound(f"Role {input.new_role.value} not found")
|
|
138
|
+
user.user_role_id = user_role_id
|
|
139
|
+
if password := input.new_password:
|
|
140
|
+
if user.auth_method != enums.AuthMethod.LOCAL.value:
|
|
141
|
+
raise Conflict("Cannot modify password for non-local user")
|
|
142
|
+
validate_password_format(password)
|
|
143
|
+
user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
|
|
144
|
+
user.password_hash = await info.context.hash_password(password, user.password_salt)
|
|
145
|
+
user.reset_password = True
|
|
146
|
+
if username := input.new_username:
|
|
147
|
+
user.username = username
|
|
148
|
+
assert user in session.dirty
|
|
149
|
+
try:
|
|
150
|
+
await session.flush()
|
|
151
|
+
except IntegrityError as error:
|
|
152
|
+
raise Conflict(_user_operation_error_message(error, "modify"))
|
|
153
|
+
assert user
|
|
154
|
+
if input.new_password:
|
|
155
|
+
await info.context.log_out(user.id)
|
|
156
|
+
return UserMutationPayload(user=to_gql_user(user))
|
|
157
|
+
|
|
158
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
159
|
+
async def patch_viewer(
|
|
160
|
+
self,
|
|
161
|
+
info: Info[Context, None],
|
|
162
|
+
input: PatchViewerInput,
|
|
163
|
+
) -> UserMutationPayload:
|
|
164
|
+
assert (request := info.context.request)
|
|
165
|
+
assert isinstance(user := request.user, PhoenixUser)
|
|
166
|
+
user_id = int(user.identity)
|
|
167
|
+
async with AsyncExitStack() as stack:
|
|
168
|
+
session = await stack.enter_async_context(info.context.db())
|
|
169
|
+
if not (user := await session.scalar(_select_user_by_id(user_id))):
|
|
170
|
+
raise NotFound("User not found")
|
|
171
|
+
stack.enter_context(session.no_autoflush)
|
|
172
|
+
if password := input.new_password:
|
|
173
|
+
if user.auth_method != enums.AuthMethod.LOCAL.value:
|
|
174
|
+
raise Conflict("Cannot modify password for non-local user")
|
|
175
|
+
if not (
|
|
176
|
+
current_password := input.current_password
|
|
177
|
+
) or not await info.context.is_valid_password(current_password, user):
|
|
178
|
+
raise Conflict("Valid current password is required to modify password")
|
|
179
|
+
validate_password_format(password)
|
|
180
|
+
user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
|
|
181
|
+
user.password_hash = await info.context.hash_password(password, user.password_salt)
|
|
182
|
+
user.reset_password = False
|
|
183
|
+
if username := input.new_username:
|
|
184
|
+
user.username = username
|
|
185
|
+
assert user in session.dirty
|
|
186
|
+
user.updated_at = datetime.now(timezone.utc)
|
|
187
|
+
try:
|
|
188
|
+
await session.flush()
|
|
189
|
+
except IntegrityError as error:
|
|
190
|
+
raise Conflict(_user_operation_error_message(error, "modify"))
|
|
191
|
+
assert user
|
|
192
|
+
if input.new_password:
|
|
193
|
+
await info.context.log_out(user.id)
|
|
194
|
+
response = info.context.get_response()
|
|
195
|
+
response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
|
|
196
|
+
response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
|
|
197
|
+
return UserMutationPayload(user=to_gql_user(user))
|
|
198
|
+
|
|
199
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
|
|
200
|
+
async def delete_users(
|
|
201
|
+
self,
|
|
202
|
+
info: Info[Context, None],
|
|
203
|
+
input: DeleteUsersInput,
|
|
204
|
+
) -> None:
|
|
205
|
+
assert (token_store := info.context.token_store) is not None
|
|
206
|
+
if not input.user_ids:
|
|
207
|
+
return
|
|
208
|
+
user_ids = tuple(
|
|
209
|
+
map(
|
|
210
|
+
lambda gid: from_global_id_with_expected_type(gid, User.__name__),
|
|
211
|
+
set(input.user_ids),
|
|
212
|
+
)
|
|
213
|
+
)
|
|
214
|
+
system_user_role_id = (
|
|
215
|
+
select(models.UserRole.id)
|
|
216
|
+
.where(models.UserRole.name == enums.UserRole.SYSTEM.value)
|
|
217
|
+
.scalar_subquery()
|
|
218
|
+
)
|
|
219
|
+
admin_user_role_id = (
|
|
220
|
+
select(models.UserRole.id)
|
|
221
|
+
.where(models.UserRole.name == enums.UserRole.ADMIN.value)
|
|
222
|
+
.scalar_subquery()
|
|
49
223
|
)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
user_role_id
|
|
56
|
-
|
|
57
|
-
email
|
|
58
|
-
auth_method="LOCAL",
|
|
59
|
-
password_hash=password_hash,
|
|
60
|
-
reset_password=True,
|
|
224
|
+
default_admin_user_id = (
|
|
225
|
+
select(models.User.id)
|
|
226
|
+
.where(
|
|
227
|
+
(
|
|
228
|
+
and_(
|
|
229
|
+
models.User.user_role_id == admin_user_role_id,
|
|
230
|
+
models.User.username == DEFAULT_ADMIN_USERNAME,
|
|
231
|
+
models.User.email == DEFAULT_ADMIN_EMAIL,
|
|
61
232
|
)
|
|
62
|
-
.returning(models.User)
|
|
63
233
|
)
|
|
64
|
-
assert user is not None
|
|
65
|
-
except IntegrityError as error:
|
|
66
|
-
raise ValueError(_get_user_create_error_message(error))
|
|
67
|
-
return UserMutationPayload(
|
|
68
|
-
user=User(
|
|
69
|
-
id_attr=user.id,
|
|
70
|
-
email=user.email,
|
|
71
|
-
username=user.username,
|
|
72
|
-
created_at=user.created_at,
|
|
73
|
-
role=UserRole(id_attr=user.user_role_id, name=role_name),
|
|
74
234
|
)
|
|
235
|
+
.scalar_subquery()
|
|
75
236
|
)
|
|
237
|
+
async with info.context.db() as session:
|
|
238
|
+
[
|
|
239
|
+
(
|
|
240
|
+
deletes_default_admin,
|
|
241
|
+
num_resolved_user_ids,
|
|
242
|
+
)
|
|
243
|
+
] = (
|
|
244
|
+
await session.execute(
|
|
245
|
+
select(
|
|
246
|
+
cast(
|
|
247
|
+
func.coalesce(
|
|
248
|
+
func.max(
|
|
249
|
+
case((models.User.id == default_admin_user_id, 1), else_=0)
|
|
250
|
+
),
|
|
251
|
+
0,
|
|
252
|
+
),
|
|
253
|
+
Boolean,
|
|
254
|
+
).label("deletes_default_admin"),
|
|
255
|
+
func.count(distinct(models.User.id)).label("num_resolved_user_ids"),
|
|
256
|
+
)
|
|
257
|
+
.select_from(models.User)
|
|
258
|
+
.where(
|
|
259
|
+
and_(
|
|
260
|
+
models.User.id.in_(user_ids),
|
|
261
|
+
models.User.user_role_id != system_user_role_id,
|
|
262
|
+
)
|
|
263
|
+
)
|
|
264
|
+
)
|
|
265
|
+
).all()
|
|
266
|
+
if deletes_default_admin:
|
|
267
|
+
raise Conflict("Cannot delete the default admin user")
|
|
268
|
+
if num_resolved_user_ids < len(user_ids):
|
|
269
|
+
raise NotFound("Some user IDs could not be found")
|
|
270
|
+
password_reset_token_ids = [
|
|
271
|
+
PasswordResetTokenId(id_)
|
|
272
|
+
async for id_ in await session.stream_scalars(
|
|
273
|
+
select(models.PasswordResetToken.id).where(
|
|
274
|
+
models.PasswordResetToken.user_id.in_(user_ids)
|
|
275
|
+
)
|
|
276
|
+
)
|
|
277
|
+
]
|
|
278
|
+
access_token_ids = [
|
|
279
|
+
AccessTokenId(id_)
|
|
280
|
+
async for id_ in await session.stream_scalars(
|
|
281
|
+
select(models.AccessToken.id).where(models.AccessToken.user_id.in_(user_ids))
|
|
282
|
+
)
|
|
283
|
+
]
|
|
284
|
+
refresh_token_ids = [
|
|
285
|
+
RefreshTokenId(id_)
|
|
286
|
+
async for id_ in await session.stream_scalars(
|
|
287
|
+
select(models.RefreshToken.id).where(models.RefreshToken.user_id.in_(user_ids))
|
|
288
|
+
)
|
|
289
|
+
]
|
|
290
|
+
api_key_ids = [
|
|
291
|
+
ApiKeyId(id_)
|
|
292
|
+
async for id_ in await session.stream_scalars(
|
|
293
|
+
select(models.ApiKey.id).where(models.ApiKey.user_id.in_(user_ids))
|
|
294
|
+
)
|
|
295
|
+
]
|
|
296
|
+
await session.execute(delete(models.User).where(models.User.id.in_(user_ids)))
|
|
297
|
+
await token_store.revoke(
|
|
298
|
+
*password_reset_token_ids,
|
|
299
|
+
*access_token_ids,
|
|
300
|
+
*refresh_token_ids,
|
|
301
|
+
*api_key_ids,
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:
|
|
306
|
+
return select(models.UserRole.id).where(models.UserRole.name == role_name)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]:
|
|
310
|
+
return (
|
|
311
|
+
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
|
|
312
|
+
)
|
|
76
313
|
|
|
77
314
|
|
|
78
|
-
def
|
|
315
|
+
def _user_operation_error_message(
|
|
316
|
+
error: IntegrityError,
|
|
317
|
+
operation: Literal["create", "modify"] = "create",
|
|
318
|
+
) -> str:
|
|
79
319
|
"""
|
|
80
|
-
|
|
320
|
+
User-facing error message to explain why user creation/modification failed.
|
|
81
321
|
"""
|
|
82
322
|
original_error_message = str(error)
|
|
83
323
|
username_already_exists = "users.username" in original_error_message
|
|
@@ -86,4 +326,4 @@ def _get_user_create_error_message(error: IntegrityError) -> str:
|
|
|
86
326
|
return "Username already exists"
|
|
87
327
|
elif email_already_exists:
|
|
88
328
|
return "Email already exists"
|
|
89
|
-
return "Failed to
|
|
329
|
+
return f"Failed to {operation} user"
|
|
@@ -2,11 +2,11 @@ from typing import Any, Dict
|
|
|
2
2
|
|
|
3
3
|
from fastapi.openapi.utils import get_openapi
|
|
4
4
|
|
|
5
|
-
from phoenix.server.api.routers.v1 import REST_API_VERSION
|
|
6
|
-
from phoenix.server.api.routers.v1 import router as v1_router
|
|
5
|
+
from phoenix.server.api.routers.v1 import REST_API_VERSION, create_v1_router
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
def get_openapi_schema() -> Dict[str, Any]:
|
|
9
|
+
v1_router = create_v1_router(authentication_enabled=False)
|
|
10
10
|
return get_openapi(
|
|
11
11
|
title="Arize-Phoenix REST API",
|
|
12
12
|
version=REST_API_VERSION,
|
phoenix/server/api/queries.py
CHANGED
|
@@ -7,12 +7,13 @@ import numpy.typing as npt
|
|
|
7
7
|
import strawberry
|
|
8
8
|
from sqlalchemy import and_, distinct, func, select
|
|
9
9
|
from sqlalchemy.orm import joinedload
|
|
10
|
+
from starlette.authentication import UnauthenticatedUser
|
|
10
11
|
from strawberry import ID, UNSET
|
|
11
12
|
from strawberry.relay import Connection, GlobalID, Node
|
|
12
13
|
from strawberry.types import Info
|
|
13
14
|
from typing_extensions import Annotated, TypeAlias
|
|
14
15
|
|
|
15
|
-
from phoenix.db import models
|
|
16
|
+
from phoenix.db import enums, models
|
|
16
17
|
from phoenix.db.models import (
|
|
17
18
|
DatasetExample as OrmExample,
|
|
18
19
|
)
|
|
@@ -32,8 +33,9 @@ from phoenix.db.models import (
|
|
|
32
33
|
Trace as OrmTrace,
|
|
33
34
|
)
|
|
34
35
|
from phoenix.pointcloud.clustering import Hdbscan
|
|
36
|
+
from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
|
|
35
37
|
from phoenix.server.api.context import Context
|
|
36
|
-
from phoenix.server.api.exceptions import NotFound
|
|
38
|
+
from phoenix.server.api.exceptions import NotFound, Unauthorized
|
|
37
39
|
from phoenix.server.api.helpers import ensure_list
|
|
38
40
|
from phoenix.server.api.input_types.ClusterInput import ClusterInput
|
|
39
41
|
from phoenix.server.api.input_types.Coordinates import (
|
|
@@ -69,14 +71,14 @@ from phoenix.server.api.types.SortDir import SortDir
|
|
|
69
71
|
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
70
72
|
from phoenix.server.api.types.SystemApiKey import SystemApiKey
|
|
71
73
|
from phoenix.server.api.types.Trace import Trace
|
|
72
|
-
from phoenix.server.api.types.User import User
|
|
73
|
-
from phoenix.server.api.types.UserApiKey import UserApiKey
|
|
74
|
+
from phoenix.server.api.types.User import User, to_gql_user
|
|
75
|
+
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
|
|
74
76
|
from phoenix.server.api.types.UserRole import UserRole
|
|
75
77
|
|
|
76
78
|
|
|
77
79
|
@strawberry.type
|
|
78
80
|
class Query:
|
|
79
|
-
@strawberry.field
|
|
81
|
+
@strawberry.field(permission_classes=[IsAdmin]) # type: ignore
|
|
80
82
|
async def users(
|
|
81
83
|
self,
|
|
82
84
|
info: Info[Context, None],
|
|
@@ -94,25 +96,13 @@ class Query:
|
|
|
94
96
|
stmt = (
|
|
95
97
|
select(models.User)
|
|
96
98
|
.join(models.UserRole)
|
|
97
|
-
.where(models.UserRole.name !=
|
|
99
|
+
.where(models.UserRole.name != enums.UserRole.SYSTEM.value)
|
|
98
100
|
.order_by(models.User.email)
|
|
99
101
|
.options(joinedload(models.User.role))
|
|
100
102
|
)
|
|
101
103
|
async with info.context.db() as session:
|
|
102
104
|
users = await session.stream_scalars(stmt)
|
|
103
|
-
data = [
|
|
104
|
-
User(
|
|
105
|
-
id_attr=user.id,
|
|
106
|
-
email=user.email,
|
|
107
|
-
username=user.username,
|
|
108
|
-
created_at=user.created_at,
|
|
109
|
-
role=UserRole(
|
|
110
|
-
id_attr=user.role.id,
|
|
111
|
-
name=user.role.name,
|
|
112
|
-
),
|
|
113
|
-
)
|
|
114
|
-
async for user in users
|
|
115
|
-
]
|
|
105
|
+
data = [to_gql_user(user) async for user in users]
|
|
116
106
|
return connection_from_list(data=data, args=args)
|
|
117
107
|
|
|
118
108
|
@strawberry.field
|
|
@@ -122,7 +112,7 @@ class Query:
|
|
|
122
112
|
) -> List[UserRole]:
|
|
123
113
|
async with info.context.db() as session:
|
|
124
114
|
roles = await session.scalars(
|
|
125
|
-
select(models.UserRole).where(models.UserRole.name !=
|
|
115
|
+
select(models.UserRole).where(models.UserRole.name != enums.UserRole.SYSTEM.value)
|
|
126
116
|
)
|
|
127
117
|
return [
|
|
128
118
|
UserRole(
|
|
@@ -132,37 +122,25 @@ class Query:
|
|
|
132
122
|
for role in roles
|
|
133
123
|
]
|
|
134
124
|
|
|
135
|
-
@strawberry.field
|
|
125
|
+
@strawberry.field(permission_classes=[IsAdmin]) # type: ignore
|
|
136
126
|
async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]:
|
|
137
|
-
# TODO(auth): add access control
|
|
138
127
|
stmt = (
|
|
139
|
-
select(models.
|
|
128
|
+
select(models.ApiKey)
|
|
140
129
|
.join(models.User)
|
|
141
130
|
.join(models.UserRole)
|
|
142
|
-
.where(models.UserRole.name !=
|
|
131
|
+
.where(models.UserRole.name != enums.UserRole.SYSTEM.value)
|
|
143
132
|
)
|
|
144
133
|
async with info.context.db() as session:
|
|
145
134
|
api_keys = await session.scalars(stmt)
|
|
146
|
-
return [
|
|
147
|
-
UserApiKey(
|
|
148
|
-
id_attr=api_key.id,
|
|
149
|
-
user_id=api_key.user_id,
|
|
150
|
-
name=api_key.name,
|
|
151
|
-
description=api_key.description,
|
|
152
|
-
created_at=api_key.created_at,
|
|
153
|
-
expires_at=api_key.expires_at,
|
|
154
|
-
)
|
|
155
|
-
for api_key in api_keys
|
|
156
|
-
]
|
|
135
|
+
return [to_gql_api_key(api_key) for api_key in api_keys]
|
|
157
136
|
|
|
158
|
-
@strawberry.field
|
|
137
|
+
@strawberry.field(permission_classes=[IsAdmin]) # type: ignore
|
|
159
138
|
async def system_api_keys(self, info: Info[Context, None]) -> List[SystemApiKey]:
|
|
160
|
-
# TODO(auth): add access control
|
|
161
139
|
stmt = (
|
|
162
|
-
select(models.
|
|
140
|
+
select(models.ApiKey)
|
|
163
141
|
.join(models.User)
|
|
164
142
|
.join(models.UserRole)
|
|
165
|
-
.where(models.UserRole.name ==
|
|
143
|
+
.where(models.UserRole.name == enums.UserRole.SYSTEM.value)
|
|
166
144
|
)
|
|
167
145
|
async with info.context.db() as session:
|
|
168
146
|
api_keys = await session.scalars(stmt)
|
|
@@ -488,8 +466,39 @@ class Query:
|
|
|
488
466
|
):
|
|
489
467
|
raise NotFound(f"Unknown experiment run: {id}")
|
|
490
468
|
return to_gql_experiment_run(run)
|
|
469
|
+
elif type_name == User.__name__:
|
|
470
|
+
if int((user := info.context.user).identity) != node_id and not user.is_admin:
|
|
471
|
+
raise Unauthorized(MSG_ADMIN_ONLY)
|
|
472
|
+
async with info.context.db() as session:
|
|
473
|
+
if not (
|
|
474
|
+
user := await session.scalar(
|
|
475
|
+
select(models.User).where(models.User.id == node_id)
|
|
476
|
+
)
|
|
477
|
+
):
|
|
478
|
+
raise NotFound(f"Unknown user: {id}")
|
|
479
|
+
return to_gql_user(user)
|
|
491
480
|
raise NotFound(f"Unknown node type: {type_name}")
|
|
492
481
|
|
|
482
|
+
@strawberry.field
|
|
483
|
+
async def viewer(self, info: Info[Context, None]) -> Optional[User]:
|
|
484
|
+
request = info.context.get_request()
|
|
485
|
+
try:
|
|
486
|
+
user = request.user
|
|
487
|
+
except AssertionError:
|
|
488
|
+
return None
|
|
489
|
+
if isinstance(user, UnauthenticatedUser):
|
|
490
|
+
return None
|
|
491
|
+
async with info.context.db() as session:
|
|
492
|
+
if (
|
|
493
|
+
user := await session.scalar(
|
|
494
|
+
select(models.User)
|
|
495
|
+
.where(models.User.id == int(user.identity))
|
|
496
|
+
.options(joinedload(models.User.role))
|
|
497
|
+
)
|
|
498
|
+
) is None:
|
|
499
|
+
return None
|
|
500
|
+
return to_gql_user(user)
|
|
501
|
+
|
|
493
502
|
@strawberry.field
|
|
494
503
|
def clusters(
|
|
495
504
|
self,
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .auth import router as auth_router
|
|
2
|
+
from .embeddings import create_embeddings_router
|
|
3
|
+
from .oauth2 import router as oauth2_router
|
|
4
|
+
from .v1 import create_v1_router
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"auth_router",
|
|
8
|
+
"create_embeddings_router",
|
|
9
|
+
"create_v1_router",
|
|
10
|
+
"oauth2_router",
|
|
11
|
+
]
|