arize-phoenix 4.36.0__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the package's registry page for more details.

Files changed (81):
  1. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/METADATA +10 -12
  2. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/RECORD +69 -60
  3. phoenix/__init__.py +86 -0
  4. phoenix/auth.py +275 -14
  5. phoenix/config.py +277 -25
  6. phoenix/db/enums.py +20 -0
  7. phoenix/db/facilitator.py +112 -0
  8. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  9. phoenix/db/models.py +145 -60
  10. phoenix/experiments/evaluators/code_evaluators.py +9 -3
  11. phoenix/experiments/functions.py +1 -4
  12. phoenix/server/api/README.md +28 -0
  13. phoenix/server/api/auth.py +32 -0
  14. phoenix/server/api/context.py +50 -2
  15. phoenix/server/api/dataloaders/__init__.py +4 -0
  16. phoenix/server/api/dataloaders/user_roles.py +30 -0
  17. phoenix/server/api/dataloaders/users.py +33 -0
  18. phoenix/server/api/exceptions.py +7 -0
  19. phoenix/server/api/mutations/__init__.py +0 -2
  20. phoenix/server/api/mutations/api_key_mutations.py +104 -86
  21. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  22. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  23. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  24. phoenix/server/api/mutations/project_mutations.py +3 -3
  25. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  26. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  27. phoenix/server/api/mutations/user_mutations.py +282 -42
  28. phoenix/server/api/openapi/schema.py +2 -2
  29. phoenix/server/api/queries.py +48 -39
  30. phoenix/server/api/routers/__init__.py +11 -0
  31. phoenix/server/api/routers/auth.py +284 -0
  32. phoenix/server/api/routers/embeddings.py +26 -0
  33. phoenix/server/api/routers/oauth2.py +456 -0
  34. phoenix/server/api/routers/v1/__init__.py +38 -16
  35. phoenix/server/api/types/ApiKey.py +11 -0
  36. phoenix/server/api/types/AuthMethod.py +9 -0
  37. phoenix/server/api/types/User.py +48 -4
  38. phoenix/server/api/types/UserApiKey.py +35 -1
  39. phoenix/server/api/types/UserRole.py +7 -0
  40. phoenix/server/app.py +103 -31
  41. phoenix/server/bearer_auth.py +161 -0
  42. phoenix/server/email/__init__.py +0 -0
  43. phoenix/server/email/sender.py +26 -0
  44. phoenix/server/email/templates/__init__.py +0 -0
  45. phoenix/server/email/templates/password_reset.html +19 -0
  46. phoenix/server/email/types.py +11 -0
  47. phoenix/server/grpc_server.py +6 -0
  48. phoenix/server/jwt_store.py +504 -0
  49. phoenix/server/main.py +40 -9
  50. phoenix/server/oauth2.py +51 -0
  51. phoenix/server/prometheus.py +20 -0
  52. phoenix/server/rate_limiters.py +191 -0
  53. phoenix/server/static/.vite/manifest.json +31 -31
  54. phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
  55. phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
  56. phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
  57. phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
  58. phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
  59. phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
  60. phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
  61. phoenix/server/templates/index.html +1 -0
  62. phoenix/server/types.py +157 -1
  63. phoenix/session/client.py +7 -2
  64. phoenix/trace/fixtures.py +24 -0
  65. phoenix/utilities/client.py +16 -0
  66. phoenix/version.py +1 -1
  67. phoenix/db/migrations/future_versions/README.md +0 -4
  68. phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
  69. phoenix/db/migrations/versions/.gitignore +0 -1
  70. phoenix/server/api/mutations/auth.py +0 -18
  71. phoenix/server/api/mutations/auth_mutations.py +0 -65
  72. phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
  73. phoenix/trace/langchain/__init__.py +0 -3
  74. phoenix/trace/langchain/instrumentor.py +0 -34
  75. phoenix/trace/llama_index/__init__.py +0 -3
  76. phoenix/trace/llama_index/callback.py +0 -102
  77. phoenix/trace/openai/__init__.py +0 -3
  78. phoenix/trace/openai/instrumentor.py +0 -30
  79. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/WHEEL +0 -0
  80. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/IP_NOTICE +0 -0
  81. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -6,11 +6,11 @@ from strawberry import UNSET
6
6
  from strawberry.types import Info
7
7
 
8
8
  from phoenix.db import models
9
+ from phoenix.server.api.auth import IsNotReadOnly
9
10
  from phoenix.server.api.context import Context
10
11
  from phoenix.server.api.input_types.CreateTraceAnnotationInput import CreateTraceAnnotationInput
11
12
  from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
12
13
  from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
13
- from phoenix.server.api.mutations.auth import IsAuthenticated
14
14
  from phoenix.server.api.queries import Query
15
15
  from phoenix.server.api.types.node import from_global_id_with_expected_type
16
16
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
@@ -25,7 +25,7 @@ class TraceAnnotationMutationPayload:
25
25
 
26
26
  @strawberry.type
27
27
  class TraceAnnotationMutationMixin:
28
- @strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
28
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
29
29
  async def create_trace_annotations(
30
30
  self, info: Info[Context, None], input: List[CreateTraceAnnotationInput]
31
31
  ) -> TraceAnnotationMutationPayload:
@@ -59,7 +59,7 @@ class TraceAnnotationMutationMixin:
59
59
  query=Query(),
60
60
  )
61
61
 
62
- @strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
62
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
63
63
  async def patch_trace_annotations(
64
64
  self, info: Info[Context, None], input: List[PatchAnnotationInput]
65
65
  ) -> TraceAnnotationMutationPayload:
@@ -98,7 +98,7 @@ class TraceAnnotationMutationMixin:
98
98
  info.context.event_queue.put(TraceAnnotationInsertEvent((trace_annotation.id,)))
99
99
  return TraceAnnotationMutationPayload(trace_annotations=patched_annotations, query=Query())
100
100
 
101
- @strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
101
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
102
102
  async def delete_trace_annotations(
103
103
  self, info: Info[Context, None], input: DeleteAnnotationsInput
104
104
  ) -> TraceAnnotationMutationPayload:
@@ -1,27 +1,79 @@
1
- import asyncio
2
- from typing import Optional
1
+ import secrets
2
+ from contextlib import AsyncExitStack
3
+ from datetime import datetime, timezone
4
+ from typing import List, Literal, Optional, Tuple
3
5
 
4
6
  import strawberry
5
- from sqlalchemy import insert, select
7
+ from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
8
+ from sqlalchemy.orm import joinedload
6
9
  from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
10
+ from strawberry import UNSET
11
+ from strawberry.relay import GlobalID
7
12
  from strawberry.types import Info
8
13
 
9
- from phoenix.auth import compute_password_hash, validate_email_format, validate_password_format
10
- from phoenix.db import models
14
+ from phoenix.auth import (
15
+ DEFAULT_ADMIN_EMAIL,
16
+ DEFAULT_ADMIN_USERNAME,
17
+ DEFAULT_SECRET_LENGTH,
18
+ PASSWORD_REQUIREMENTS,
19
+ PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
20
+ PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
21
+ validate_email_format,
22
+ validate_password_format,
23
+ )
24
+ from phoenix.db import enums, models
25
+ from phoenix.server.api.auth import IsAdmin, IsNotReadOnly
11
26
  from phoenix.server.api.context import Context
27
+ from phoenix.server.api.exceptions import Conflict, NotFound, Unauthorized
12
28
  from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
13
- from phoenix.server.api.types.User import User
14
- from phoenix.server.api.types.UserRole import UserRole
29
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
30
+ from phoenix.server.api.types.User import User, to_gql_user
31
+ from phoenix.server.bearer_auth import PhoenixUser
32
+ from phoenix.server.types import AccessTokenId, ApiKeyId, PasswordResetTokenId, RefreshTokenId
15
33
 
16
34
 
17
35
  @strawberry.input
18
36
  class CreateUserInput:
19
37
  email: str
20
- username: Optional[str]
38
+ username: str
21
39
  password: str
22
40
  role: UserRoleInput
23
41
 
24
42
 
43
+ @strawberry.input
44
+ class PatchViewerInput:
45
+ new_username: Optional[str] = UNSET
46
+ new_password: Optional[str] = UNSET
47
+ current_password: Optional[str] = UNSET
48
+
49
+ def __post_init__(self) -> None:
50
+ if not self.new_username and not self.new_password:
51
+ raise ValueError("At least one field must be set")
52
+ if self.new_password and not self.current_password:
53
+ raise ValueError("current_password is required when modifying password")
54
+ if self.new_password:
55
+ PASSWORD_REQUIREMENTS.validate(self.new_password)
56
+
57
+
58
+ @strawberry.input
59
+ class PatchUserInput:
60
+ user_id: GlobalID
61
+ new_role: Optional[UserRoleInput] = UNSET
62
+ new_username: Optional[str] = UNSET
63
+ new_password: Optional[str] = UNSET
64
+
65
+ def __post_init__(self) -> None:
66
+ if not self.new_role and not self.new_username and not self.new_password:
67
+ raise ValueError("At least one field must be set")
68
+ if self.new_password:
69
+ PASSWORD_REQUIREMENTS.validate(self.new_password)
70
+
71
+
72
+ @strawberry.input
73
+ class DeleteUsersInput:
74
+ user_ids: List[GlobalID]
75
+
76
+
25
77
  @strawberry.type
26
78
  class UserMutationPayload:
27
79
  user: User
@@ -29,7 +81,7 @@ class UserMutationPayload:
29
81
 
30
82
  @strawberry.type
31
83
  class UserMutationMixin:
32
- @strawberry.mutation
84
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
33
85
  async def create_user(
34
86
  self,
35
87
  info: Info[Context, None],
@@ -37,47 +89,235 @@ class UserMutationMixin:
37
89
  ) -> UserMutationPayload:
38
90
  validate_email_format(email := input.email)
39
91
  validate_password_format(password := input.password)
40
- role_name = input.role.value
41
- user_role_id = (
42
- select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
92
+ salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
93
+ password_hash = await info.context.hash_password(password, salt)
94
+ user = models.User(
95
+ reset_password=True,
96
+ username=input.username,
97
+ email=email,
98
+ password_hash=password_hash,
99
+ password_salt=salt,
43
100
  )
44
- secret = info.context.get_secret()
45
- loop = asyncio.get_running_loop()
46
- password_hash = await loop.run_in_executor(
47
- executor=None,
48
- func=lambda: compute_password_hash(password=password, salt=secret),
101
+ async with AsyncExitStack() as stack:
102
+ session = await stack.enter_async_context(info.context.db())
103
+ user_role_id = await session.scalar(_select_role_id_by_name(input.role.value))
104
+ if user_role_id is None:
105
+ raise NotFound(f"Role {input.role.value} not found")
106
+ stack.enter_context(session.no_autoflush)
107
+ user.user_role_id = user_role_id
108
+ session.add(user)
109
+ try:
110
+ await session.flush()
111
+ except IntegrityError as error:
112
+ raise Conflict(_user_operation_error_message(error))
113
+ return UserMutationPayload(user=to_gql_user(user))
114
+
115
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
116
+ async def patch_user(
117
+ self,
118
+ info: Info[Context, None],
119
+ input: PatchUserInput,
120
+ ) -> UserMutationPayload:
121
+ assert (request := info.context.request)
122
+ assert isinstance(request.user, PhoenixUser)
123
+ assert (requester_id := int(request.user.identity))
124
+ user_id = from_global_id_with_expected_type(input.user_id, expected_type_name=User.__name__)
125
+ async with AsyncExitStack() as stack:
126
+ session = await stack.enter_async_context(info.context.db())
127
+ requester = await session.scalar(_select_user_by_id(requester_id))
128
+ assert requester
129
+ if not (user := await session.scalar(_select_user_by_id(user_id))):
130
+ raise NotFound("User not found")
131
+ stack.enter_context(session.no_autoflush)
132
+ if input.new_role:
133
+ if user.email == DEFAULT_ADMIN_EMAIL:
134
+ raise Unauthorized("Cannot modify role for the default admin user")
135
+ user_role_id = await session.scalar(_select_role_id_by_name(input.new_role.value))
136
+ if user_role_id is None:
137
+ raise NotFound(f"Role {input.new_role.value} not found")
138
+ user.user_role_id = user_role_id
139
+ if password := input.new_password:
140
+ if user.auth_method != enums.AuthMethod.LOCAL.value:
141
+ raise Conflict("Cannot modify password for non-local user")
142
+ validate_password_format(password)
143
+ user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
144
+ user.password_hash = await info.context.hash_password(password, user.password_salt)
145
+ user.reset_password = True
146
+ if username := input.new_username:
147
+ user.username = username
148
+ assert user in session.dirty
149
+ try:
150
+ await session.flush()
151
+ except IntegrityError as error:
152
+ raise Conflict(_user_operation_error_message(error, "modify"))
153
+ assert user
154
+ if input.new_password:
155
+ await info.context.log_out(user.id)
156
+ return UserMutationPayload(user=to_gql_user(user))
157
+
158
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
159
+ async def patch_viewer(
160
+ self,
161
+ info: Info[Context, None],
162
+ input: PatchViewerInput,
163
+ ) -> UserMutationPayload:
164
+ assert (request := info.context.request)
165
+ assert isinstance(user := request.user, PhoenixUser)
166
+ user_id = int(user.identity)
167
+ async with AsyncExitStack() as stack:
168
+ session = await stack.enter_async_context(info.context.db())
169
+ if not (user := await session.scalar(_select_user_by_id(user_id))):
170
+ raise NotFound("User not found")
171
+ stack.enter_context(session.no_autoflush)
172
+ if password := input.new_password:
173
+ if user.auth_method != enums.AuthMethod.LOCAL.value:
174
+ raise Conflict("Cannot modify password for non-local user")
175
+ if not (
176
+ current_password := input.current_password
177
+ ) or not await info.context.is_valid_password(current_password, user):
178
+ raise Conflict("Valid current password is required to modify password")
179
+ validate_password_format(password)
180
+ user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
181
+ user.password_hash = await info.context.hash_password(password, user.password_salt)
182
+ user.reset_password = False
183
+ if username := input.new_username:
184
+ user.username = username
185
+ assert user in session.dirty
186
+ user.updated_at = datetime.now(timezone.utc)
187
+ try:
188
+ await session.flush()
189
+ except IntegrityError as error:
190
+ raise Conflict(_user_operation_error_message(error, "modify"))
191
+ assert user
192
+ if input.new_password:
193
+ await info.context.log_out(user.id)
194
+ response = info.context.get_response()
195
+ response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
196
+ response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
197
+ return UserMutationPayload(user=to_gql_user(user))
198
+
199
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
200
+ async def delete_users(
201
+ self,
202
+ info: Info[Context, None],
203
+ input: DeleteUsersInput,
204
+ ) -> None:
205
+ assert (token_store := info.context.token_store) is not None
206
+ if not input.user_ids:
207
+ return
208
+ user_ids = tuple(
209
+ map(
210
+ lambda gid: from_global_id_with_expected_type(gid, User.__name__),
211
+ set(input.user_ids),
212
+ )
213
+ )
214
+ system_user_role_id = (
215
+ select(models.UserRole.id)
216
+ .where(models.UserRole.name == enums.UserRole.SYSTEM.value)
217
+ .scalar_subquery()
218
+ )
219
+ admin_user_role_id = (
220
+ select(models.UserRole.id)
221
+ .where(models.UserRole.name == enums.UserRole.ADMIN.value)
222
+ .scalar_subquery()
49
223
  )
50
- try:
51
- async with info.context.db() as session:
52
- user = await session.scalar(
53
- insert(models.User)
54
- .values(
55
- user_role_id=user_role_id,
56
- username=input.username,
57
- email=email,
58
- auth_method="LOCAL",
59
- password_hash=password_hash,
60
- reset_password=True,
224
+ default_admin_user_id = (
225
+ select(models.User.id)
226
+ .where(
227
+ (
228
+ and_(
229
+ models.User.user_role_id == admin_user_role_id,
230
+ models.User.username == DEFAULT_ADMIN_USERNAME,
231
+ models.User.email == DEFAULT_ADMIN_EMAIL,
61
232
  )
62
- .returning(models.User)
63
233
  )
64
- assert user is not None
65
- except IntegrityError as error:
66
- raise ValueError(_get_user_create_error_message(error))
67
- return UserMutationPayload(
68
- user=User(
69
- id_attr=user.id,
70
- email=user.email,
71
- username=user.username,
72
- created_at=user.created_at,
73
- role=UserRole(id_attr=user.user_role_id, name=role_name),
74
234
  )
235
+ .scalar_subquery()
75
236
  )
237
+ async with info.context.db() as session:
238
+ [
239
+ (
240
+ deletes_default_admin,
241
+ num_resolved_user_ids,
242
+ )
243
+ ] = (
244
+ await session.execute(
245
+ select(
246
+ cast(
247
+ func.coalesce(
248
+ func.max(
249
+ case((models.User.id == default_admin_user_id, 1), else_=0)
250
+ ),
251
+ 0,
252
+ ),
253
+ Boolean,
254
+ ).label("deletes_default_admin"),
255
+ func.count(distinct(models.User.id)).label("num_resolved_user_ids"),
256
+ )
257
+ .select_from(models.User)
258
+ .where(
259
+ and_(
260
+ models.User.id.in_(user_ids),
261
+ models.User.user_role_id != system_user_role_id,
262
+ )
263
+ )
264
+ )
265
+ ).all()
266
+ if deletes_default_admin:
267
+ raise Conflict("Cannot delete the default admin user")
268
+ if num_resolved_user_ids < len(user_ids):
269
+ raise NotFound("Some user IDs could not be found")
270
+ password_reset_token_ids = [
271
+ PasswordResetTokenId(id_)
272
+ async for id_ in await session.stream_scalars(
273
+ select(models.PasswordResetToken.id).where(
274
+ models.PasswordResetToken.user_id.in_(user_ids)
275
+ )
276
+ )
277
+ ]
278
+ access_token_ids = [
279
+ AccessTokenId(id_)
280
+ async for id_ in await session.stream_scalars(
281
+ select(models.AccessToken.id).where(models.AccessToken.user_id.in_(user_ids))
282
+ )
283
+ ]
284
+ refresh_token_ids = [
285
+ RefreshTokenId(id_)
286
+ async for id_ in await session.stream_scalars(
287
+ select(models.RefreshToken.id).where(models.RefreshToken.user_id.in_(user_ids))
288
+ )
289
+ ]
290
+ api_key_ids = [
291
+ ApiKeyId(id_)
292
+ async for id_ in await session.stream_scalars(
293
+ select(models.ApiKey.id).where(models.ApiKey.user_id.in_(user_ids))
294
+ )
295
+ ]
296
+ await session.execute(delete(models.User).where(models.User.id.in_(user_ids)))
297
+ await token_store.revoke(
298
+ *password_reset_token_ids,
299
+ *access_token_ids,
300
+ *refresh_token_ids,
301
+ *api_key_ids,
302
+ )
303
+
304
+
305
+ def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:
306
+ return select(models.UserRole.id).where(models.UserRole.name == role_name)
307
+
308
+
309
+ def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]:
310
+ return (
311
+ select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
312
+ )
76
313
 
77
314
 
78
- def _get_user_create_error_message(error: IntegrityError) -> str:
315
+ def _user_operation_error_message(
316
+ error: IntegrityError,
317
+ operation: Literal["create", "modify"] = "create",
318
+ ) -> str:
79
319
  """
80
- Gets a user-facing error message to explain why user creation failed.
320
+ User-facing error message to explain why user creation/modification failed.
81
321
  """
82
322
  original_error_message = str(error)
83
323
  username_already_exists = "users.username" in original_error_message
@@ -86,4 +326,4 @@ def _get_user_create_error_message(error: IntegrityError) -> str:
86
326
  return "Username already exists"
87
327
  elif email_already_exists:
88
328
  return "Email already exists"
89
- return "Failed to create user"
329
+ return f"Failed to {operation} user"
@@ -2,11 +2,11 @@ from typing import Any, Dict
2
2
 
3
3
  from fastapi.openapi.utils import get_openapi
4
4
 
5
- from phoenix.server.api.routers.v1 import REST_API_VERSION
6
- from phoenix.server.api.routers.v1 import router as v1_router
5
+ from phoenix.server.api.routers.v1 import REST_API_VERSION, create_v1_router
7
6
 
8
7
 
9
8
  def get_openapi_schema() -> Dict[str, Any]:
9
+ v1_router = create_v1_router(authentication_enabled=False)
10
10
  return get_openapi(
11
11
  title="Arize-Phoenix REST API",
12
12
  version=REST_API_VERSION,
@@ -7,12 +7,13 @@ import numpy.typing as npt
7
7
  import strawberry
8
8
  from sqlalchemy import and_, distinct, func, select
9
9
  from sqlalchemy.orm import joinedload
10
+ from starlette.authentication import UnauthenticatedUser
10
11
  from strawberry import ID, UNSET
11
12
  from strawberry.relay import Connection, GlobalID, Node
12
13
  from strawberry.types import Info
13
14
  from typing_extensions import Annotated, TypeAlias
14
15
 
15
- from phoenix.db import models
16
+ from phoenix.db import enums, models
16
17
  from phoenix.db.models import (
17
18
  DatasetExample as OrmExample,
18
19
  )
@@ -32,8 +33,9 @@ from phoenix.db.models import (
32
33
  Trace as OrmTrace,
33
34
  )
34
35
  from phoenix.pointcloud.clustering import Hdbscan
36
+ from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
35
37
  from phoenix.server.api.context import Context
36
- from phoenix.server.api.exceptions import NotFound
38
+ from phoenix.server.api.exceptions import NotFound, Unauthorized
37
39
  from phoenix.server.api.helpers import ensure_list
38
40
  from phoenix.server.api.input_types.ClusterInput import ClusterInput
39
41
  from phoenix.server.api.input_types.Coordinates import (
@@ -69,14 +71,14 @@ from phoenix.server.api.types.SortDir import SortDir
69
71
  from phoenix.server.api.types.Span import Span, to_gql_span
70
72
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
71
73
  from phoenix.server.api.types.Trace import Trace
72
- from phoenix.server.api.types.User import User
73
- from phoenix.server.api.types.UserApiKey import UserApiKey
74
+ from phoenix.server.api.types.User import User, to_gql_user
75
+ from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
74
76
  from phoenix.server.api.types.UserRole import UserRole
75
77
 
76
78
 
77
79
  @strawberry.type
78
80
  class Query:
79
- @strawberry.field
81
+ @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
80
82
  async def users(
81
83
  self,
82
84
  info: Info[Context, None],
@@ -94,25 +96,13 @@ class Query:
94
96
  stmt = (
95
97
  select(models.User)
96
98
  .join(models.UserRole)
97
- .where(models.UserRole.name != "SYSTEM")
99
+ .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
98
100
  .order_by(models.User.email)
99
101
  .options(joinedload(models.User.role))
100
102
  )
101
103
  async with info.context.db() as session:
102
104
  users = await session.stream_scalars(stmt)
103
- data = [
104
- User(
105
- id_attr=user.id,
106
- email=user.email,
107
- username=user.username,
108
- created_at=user.created_at,
109
- role=UserRole(
110
- id_attr=user.role.id,
111
- name=user.role.name,
112
- ),
113
- )
114
- async for user in users
115
- ]
105
+ data = [to_gql_user(user) async for user in users]
116
106
  return connection_from_list(data=data, args=args)
117
107
 
118
108
  @strawberry.field
@@ -122,7 +112,7 @@ class Query:
122
112
  ) -> List[UserRole]:
123
113
  async with info.context.db() as session:
124
114
  roles = await session.scalars(
125
- select(models.UserRole).where(models.UserRole.name != "SYSTEM")
115
+ select(models.UserRole).where(models.UserRole.name != enums.UserRole.SYSTEM.value)
126
116
  )
127
117
  return [
128
118
  UserRole(
@@ -132,37 +122,25 @@ class Query:
132
122
  for role in roles
133
123
  ]
134
124
 
135
- @strawberry.field
125
+ @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
136
126
  async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]:
137
- # TODO(auth): add access control
138
127
  stmt = (
139
- select(models.APIKey)
128
+ select(models.ApiKey)
140
129
  .join(models.User)
141
130
  .join(models.UserRole)
142
- .where(models.UserRole.name != "SYSTEM")
131
+ .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
143
132
  )
144
133
  async with info.context.db() as session:
145
134
  api_keys = await session.scalars(stmt)
146
- return [
147
- UserApiKey(
148
- id_attr=api_key.id,
149
- user_id=api_key.user_id,
150
- name=api_key.name,
151
- description=api_key.description,
152
- created_at=api_key.created_at,
153
- expires_at=api_key.expires_at,
154
- )
155
- for api_key in api_keys
156
- ]
135
+ return [to_gql_api_key(api_key) for api_key in api_keys]
157
136
 
158
- @strawberry.field
137
+ @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
159
138
  async def system_api_keys(self, info: Info[Context, None]) -> List[SystemApiKey]:
160
- # TODO(auth): add access control
161
139
  stmt = (
162
- select(models.APIKey)
140
+ select(models.ApiKey)
163
141
  .join(models.User)
164
142
  .join(models.UserRole)
165
- .where(models.UserRole.name == "SYSTEM")
143
+ .where(models.UserRole.name == enums.UserRole.SYSTEM.value)
166
144
  )
167
145
  async with info.context.db() as session:
168
146
  api_keys = await session.scalars(stmt)
@@ -488,8 +466,39 @@ class Query:
488
466
  ):
489
467
  raise NotFound(f"Unknown experiment run: {id}")
490
468
  return to_gql_experiment_run(run)
469
+ elif type_name == User.__name__:
470
+ if int((user := info.context.user).identity) != node_id and not user.is_admin:
471
+ raise Unauthorized(MSG_ADMIN_ONLY)
472
+ async with info.context.db() as session:
473
+ if not (
474
+ user := await session.scalar(
475
+ select(models.User).where(models.User.id == node_id)
476
+ )
477
+ ):
478
+ raise NotFound(f"Unknown user: {id}")
479
+ return to_gql_user(user)
491
480
  raise NotFound(f"Unknown node type: {type_name}")
492
481
 
482
+ @strawberry.field
483
+ async def viewer(self, info: Info[Context, None]) -> Optional[User]:
484
+ request = info.context.get_request()
485
+ try:
486
+ user = request.user
487
+ except AssertionError:
488
+ return None
489
+ if isinstance(user, UnauthenticatedUser):
490
+ return None
491
+ async with info.context.db() as session:
492
+ if (
493
+ user := await session.scalar(
494
+ select(models.User)
495
+ .where(models.User.id == int(user.identity))
496
+ .options(joinedload(models.User.role))
497
+ )
498
+ ) is None:
499
+ return None
500
+ return to_gql_user(user)
501
+
493
502
  @strawberry.field
494
503
  def clusters(
495
504
  self,
@@ -0,0 +1,11 @@
1
+ from .auth import router as auth_router
2
+ from .embeddings import create_embeddings_router
3
+ from .oauth2 import router as oauth2_router
4
+ from .v1 import create_v1_router
5
+
6
+ __all__ = [
7
+ "auth_router",
8
+ "create_embeddings_router",
9
+ "create_v1_router",
10
+ "oauth2_router",
11
+ ]