arize-phoenix 4.36.0__py3-none-any.whl → 5.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (80)
  1. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/METADATA +10 -12
  2. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/RECORD +68 -59
  3. phoenix/__init__.py +86 -0
  4. phoenix/auth.py +275 -14
  5. phoenix/config.py +277 -25
  6. phoenix/db/enums.py +20 -0
  7. phoenix/db/facilitator.py +112 -0
  8. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  9. phoenix/db/models.py +145 -60
  10. phoenix/experiments/evaluators/code_evaluators.py +9 -3
  11. phoenix/experiments/functions.py +1 -4
  12. phoenix/server/api/README.md +28 -0
  13. phoenix/server/api/auth.py +32 -0
  14. phoenix/server/api/context.py +50 -2
  15. phoenix/server/api/dataloaders/__init__.py +4 -0
  16. phoenix/server/api/dataloaders/user_roles.py +30 -0
  17. phoenix/server/api/dataloaders/users.py +33 -0
  18. phoenix/server/api/exceptions.py +7 -0
  19. phoenix/server/api/mutations/__init__.py +0 -2
  20. phoenix/server/api/mutations/api_key_mutations.py +104 -86
  21. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  22. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  23. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  24. phoenix/server/api/mutations/project_mutations.py +3 -3
  25. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  26. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  27. phoenix/server/api/mutations/user_mutations.py +282 -42
  28. phoenix/server/api/openapi/schema.py +2 -2
  29. phoenix/server/api/queries.py +48 -39
  30. phoenix/server/api/routers/__init__.py +11 -0
  31. phoenix/server/api/routers/auth.py +284 -0
  32. phoenix/server/api/routers/embeddings.py +26 -0
  33. phoenix/server/api/routers/oauth2.py +456 -0
  34. phoenix/server/api/routers/v1/__init__.py +38 -16
  35. phoenix/server/api/types/ApiKey.py +11 -0
  36. phoenix/server/api/types/AuthMethod.py +9 -0
  37. phoenix/server/api/types/User.py +48 -4
  38. phoenix/server/api/types/UserApiKey.py +35 -1
  39. phoenix/server/api/types/UserRole.py +7 -0
  40. phoenix/server/app.py +103 -31
  41. phoenix/server/bearer_auth.py +161 -0
  42. phoenix/server/email/__init__.py +0 -0
  43. phoenix/server/email/sender.py +26 -0
  44. phoenix/server/email/templates/__init__.py +0 -0
  45. phoenix/server/email/templates/password_reset.html +19 -0
  46. phoenix/server/email/types.py +11 -0
  47. phoenix/server/grpc_server.py +6 -0
  48. phoenix/server/jwt_store.py +504 -0
  49. phoenix/server/main.py +40 -9
  50. phoenix/server/oauth2.py +51 -0
  51. phoenix/server/prometheus.py +20 -0
  52. phoenix/server/rate_limiters.py +191 -0
  53. phoenix/server/static/.vite/manifest.json +31 -31
  54. phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
  55. phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
  56. phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
  57. phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
  58. phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
  59. phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
  60. phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
  61. phoenix/server/templates/index.html +1 -0
  62. phoenix/server/types.py +157 -1
  63. phoenix/session/client.py +7 -2
  64. phoenix/utilities/client.py +16 -0
  65. phoenix/version.py +1 -1
  66. phoenix/db/migrations/future_versions/README.md +0 -4
  67. phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
  68. phoenix/db/migrations/versions/.gitignore +0 -1
  69. phoenix/server/api/mutations/auth.py +0 -18
  70. phoenix/server/api/mutations/auth_mutations.py +0 -65
  71. phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
  72. phoenix/trace/langchain/__init__.py +0 -3
  73. phoenix/trace/langchain/instrumentor.py +0 -34
  74. phoenix/trace/llama_index/__init__.py +0 -3
  75. phoenix/trace/llama_index/callback.py +0 -102
  76. phoenix/trace/openai/__init__.py +0 -3
  77. phoenix/trace/openai/instrumentor.py +0 -30
  78. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/WHEEL +0 -0
  79. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  80. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/jwt_store.py ADDED
@@ -0,0 +1,504 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from asyncio import create_task, gather, sleep
4
+ from copy import deepcopy
5
+ from dataclasses import replace
6
+ from datetime import datetime, timezone
7
+ from functools import cached_property, singledispatchmethod
8
+ from typing import Any, Callable, Coroutine, Dict, Generic, List, Optional, Tuple, Type, TypeVar
9
+
10
+ from authlib.jose import jwt
11
+ from authlib.jose.errors import JoseError
12
+ from sqlalchemy import Select, delete, select
13
+
14
+ from phoenix.auth import (
15
+ JWT_ALGORITHM,
16
+ ClaimSet,
17
+ Token,
18
+ )
19
+ from phoenix.config import get_env_enable_prometheus
20
+ from phoenix.db import models
21
+ from phoenix.db.enums import UserRole
22
+ from phoenix.server.types import (
23
+ AccessToken,
24
+ AccessTokenAttributes,
25
+ AccessTokenClaims,
26
+ AccessTokenId,
27
+ ApiKey,
28
+ ApiKeyAttributes,
29
+ ApiKeyClaims,
30
+ ApiKeyId,
31
+ DaemonTask,
32
+ DbSessionFactory,
33
+ PasswordResetToken,
34
+ PasswordResetTokenAttributes,
35
+ PasswordResetTokenClaims,
36
+ PasswordResetTokenId,
37
+ RefreshToken,
38
+ RefreshTokenAttributes,
39
+ RefreshTokenClaims,
40
+ RefreshTokenId,
41
+ TokenId,
42
+ UserId,
43
+ )
44
+
45
# Module-level logger for this file.
logger = logging.getLogger(__name__)
46
+
47
+
48
class JwtStore:
    """
    Front door for every token type the server issues.

    Owns one store per token type (password reset, access, refresh, API key)
    and routes each operation to the store that matches the class of the token
    id. Used as an async context manager, it enters and exits all of those
    stores concurrently.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        secret: str,
        algorithm: str = JWT_ALGORITHM,
        sleep_seconds: int = 10,
        **kwargs: Any,
    ) -> None:
        assert secret
        super().__init__(**kwargs)
        self._db = db
        self._secret = secret
        # Remembered so `read` only accepts tokens whose header names the same
        # algorithm the stores sign with.
        self._algorithm = algorithm
        args = (db, secret, algorithm, sleep_seconds)
        self._password_reset_token_store = _PasswordResetTokenStore(*args, **kwargs)
        self._access_token_store = _AccessTokenStore(*args, **kwargs)
        self._refresh_token_store = _RefreshTokenStore(*args, **kwargs)
        self._api_key_store = _ApiKeyStore(*args, **kwargs)

    @cached_property
    def _stores(self) -> Tuple[DaemonTask, ...]:
        # Every instance attribute that is a `_Store` (the stores built above).
        return tuple(dt for dt in self.__dict__.values() if isinstance(dt, _Store))

    async def __aenter__(self) -> None:
        await gather(*(s.__aenter__() for s in self._stores))

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        await gather(*(s.__aexit__(*args, **kwargs) for s in self._stores))

    async def read(self, token: Token) -> Optional[ClaimSet]:
        """
        Verify `token` and return its claim set, or None if it is not acceptable.

        None is returned when signature verification fails, when the header
        names a different algorithm than the one configured for this store,
        when the payload carries no parsable `jti`, or when the owning store
        has no claim set for that id.
        """
        try:
            payload = jwt.decode(
                s=token,
                key=self._secret,
            )
        except JoseError:
            return None
        # Defense in depth: never let the (attacker-controlled) token header
        # choose the verification algorithm.
        if payload.header.get("alg") != self._algorithm:
            return None
        if (jti := payload.get("jti")) is None:
            return None
        if (token_id := TokenId.parse(jti)) is None:
            return None
        return await self._get(token_id)

    @singledispatchmethod
    async def _get(self, _: TokenId) -> Optional[ClaimSet]:
        # Fallback for token id classes without a registered store.
        return None

    @_get.register
    async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
        return await self._password_reset_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
        return await self._access_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]:
        return await self._refresh_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
        return await self._api_key_store.get(token_id)

    @singledispatchmethod
    async def _evict(self, _: TokenId) -> Optional[ClaimSet]:
        # Fallback for token id classes without a registered store.
        return None

    @_evict.register
    async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
        return await self._password_reset_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
        return await self._access_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]:
        return await self._refresh_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
        return await self._api_key_store.evict(token_id)

    async def create_password_reset_token(
        self,
        claim: PasswordResetTokenClaims,
    ) -> Tuple[PasswordResetToken, PasswordResetTokenId]:
        return await self._password_reset_token_store.create(claim)

    async def create_access_token(
        self,
        claim: AccessTokenClaims,
    ) -> Tuple[AccessToken, AccessTokenId]:
        return await self._access_token_store.create(claim)

    async def create_refresh_token(
        self,
        claim: RefreshTokenClaims,
    ) -> Tuple[RefreshToken, RefreshTokenId]:
        return await self._refresh_token_store.create(claim)

    async def create_api_key(
        self,
        claim: ApiKeyClaims,
    ) -> Tuple[ApiKey, ApiKeyId]:
        return await self._api_key_store.create(claim)

    async def revoke(self, *token_ids: TokenId) -> None:
        """Revoke the given tokens, grouping them per store and revoking concurrently."""
        if not token_ids:
            return
        password_reset_token_ids: List[PasswordResetTokenId] = []
        access_token_ids: List[AccessTokenId] = []
        refresh_token_ids: List[RefreshTokenId] = []
        api_key_ids: List[ApiKeyId] = []
        for token_id in token_ids:
            # One exclusive dispatch chain: each id belongs to exactly one store.
            if isinstance(token_id, PasswordResetTokenId):
                password_reset_token_ids.append(token_id)
            elif isinstance(token_id, AccessTokenId):
                access_token_ids.append(token_id)
            elif isinstance(token_id, RefreshTokenId):
                refresh_token_ids.append(token_id)
            elif isinstance(token_id, ApiKeyId):
                api_key_ids.append(token_id)
        coroutines: List[Coroutine[None, None, None]] = []
        if password_reset_token_ids:
            coroutines.append(self._password_reset_token_store.revoke(*password_reset_token_ids))
        if access_token_ids:
            coroutines.append(self._access_token_store.revoke(*access_token_ids))
        if refresh_token_ids:
            coroutines.append(self._refresh_token_store.revoke(*refresh_token_ids))
        if api_key_ids:
            coroutines.append(self._api_key_store.revoke(*api_key_ids))
        await gather(*coroutines)

    async def log_out(self, user_id: UserId) -> None:
        """
        Delete every access and refresh token row owned by `user_id`.

        Each deleted id is also evicted from its store's cache.
        """
        for cls in (AccessTokenId, RefreshTokenId):
            table = cls.table
            stmt = delete(table).where(table.user_id == int(user_id)).returning(table.id)
            async with self._db() as session:
                async for id_ in await session.stream_scalars(stmt):
                    await self._evict(cls(id_))
189
+
190
+
191
# Type parameters used to make the token stores generic.
_TokenT = TypeVar("_TokenT", bound=Token)  # encoded token returned by `create`
_TokenIdT = TypeVar("_TokenIdT", bound=TokenId)  # typed wrapper around a row id
_ClaimSetT = TypeVar("_ClaimSetT", bound=ClaimSet)  # claim set cached per token id
# Constrained (not bound): must be exactly one of the four ORM token models.
_RecordT = TypeVar(
    "_RecordT",
    models.PasswordResetToken,
    models.AccessToken,
    models.RefreshToken,
    models.ApiKey,
)
201
+
202
+
203
class _Claims(Generic[_TokenIdT, _ClaimSetT]):
    """
    In-memory cache of claim sets keyed by token id.

    Values are deep-copied both when stored and when returned, so callers can
    never mutate the cached copy. Unlike `dict`, indexing a missing id returns
    None instead of raising `KeyError`.
    """

    def __init__(self) -> None:
        self._cache: Dict[_TokenIdT, _ClaimSetT] = {}

    def __getitem__(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        # Same contract as `get`: a missing id yields None.
        return self.get(token_id)

    def __setitem__(self, token_id: _TokenIdT, claim: _ClaimSetT) -> None:
        self._cache[token_id] = deepcopy(claim)

    def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """Return a copy of the cached claim set for `token_id`, or None."""
        claim = self._cache.get(token_id)
        # `is not None`, not truthiness: a cached value must never be reported
        # as missing just because it happens to be falsy.
        return None if claim is None else deepcopy(claim)

    def pop(
        self, token_id: _TokenIdT, default: Optional[_ClaimSetT] = None
    ) -> Optional[_ClaimSetT]:
        """Remove `token_id` and return a copy of its claim set (or of `default`)."""
        claim = self._cache.pop(token_id, default)
        return None if claim is None else deepcopy(claim)
223
+
224
+
225
class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC):
    """
    Database-backed store for one kind of token, fronted by an in-memory cache.

    Subclasses bind `_table`, `_token_id` and `_token` and implement the
    `_from_db` / `_to_db` conversions. While the daemon task runs, `_update`
    is repeated every `sleep_seconds`. Each pass deletes rows whose
    `expires_at` is in the past and rebuilds the claim cache from the rows
    that remain.
    """

    _table: Type[_RecordT]  # ORM model that holds this token type
    _token_id: Callable[[int], _TokenIdT]  # wraps a row's integer primary key
    _token: Callable[[str], _TokenT]  # wraps an encoded JWT string

    def __init__(
        self,
        db: DbSessionFactory,
        secret: str,
        algorithm: str = JWT_ALGORITHM,
        sleep_seconds: int = 10,
        **kwargs: Any,
    ) -> None:
        assert secret
        super().__init__(**kwargs)
        self._db = db
        self._seconds = sleep_seconds
        self._claims: _Claims[_TokenIdT, _ClaimSetT] = _Claims()
        self._secret = secret
        self._algorithm = algorithm

    def _encode(self, claim: ClaimSet) -> str:
        """Sign a JWT whose only payload claim is `jti` (the token id)."""
        payload: Dict[str, Any] = dict(jti=claim.token_id)
        header = {"alg": self._algorithm}
        jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=self._secret)
        return jwt_bytes.decode()

    async def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """
        Return the claim set for `token_id`, or None if no row exists for it.

        Served from the cache when possible. On a miss, the row is loaded from
        the database together with its user's role name and then cached.
        """
        if (claims := self._claims.get(token_id)) is not None:
            return claims
        stmt = self._update_stmt.where(self._table.id == int(token_id))
        async with self._db() as session:
            record = (await session.execute(stmt)).first()
        if not record:
            return None
        token, role = record
        _, claims = self._from_db(token, UserRole(role))
        self._claims[token_id] = claims
        return claims

    async def evict(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """Drop `token_id` from the cache only (the database is untouched)."""
        return self._claims.pop(token_id, None)

    async def revoke(self, *token_ids: _TokenIdT) -> None:
        """Delete the given tokens from the database, then evict them from the cache."""
        if not token_ids:
            return
        # Delete first, evict second. If the cache were cleared first, a
        # concurrent `get` could reload the still-present row and re-cache a
        # revoked token until the next `_update` rebuild.
        stmt = delete(self._table).where(
            self._table.id.in_([int(token_id) for token_id in token_ids])
        )
        async with self._db() as session:
            await session.execute(stmt)
        for token_id in token_ids:
            await self.evict(token_id)

    @abstractmethod
    def _from_db(self, record: _RecordT, role: UserRole) -> Tuple[_TokenIdT, _ClaimSetT]: ...

    @abstractmethod
    def _to_db(self, claims: _ClaimSetT) -> _RecordT: ...

    async def create(self, claim: _ClaimSetT) -> Tuple[_TokenT, _TokenIdT]:
        """Persist `claim`, cache it under its new id, and return the signed token and id."""
        record = self._to_db(claim)
        async with self._db() as session:
            session.add(record)
            # Flush so the database assigns the primary key used as the token id.
            await session.flush()
            token_id = self._token_id(record.id)
        claim = replace(claim, token_id=token_id)
        self._claims[token_id] = claim
        token = self._token(self._encode(claim))
        return token, token_id

    async def _update(self) -> None:
        """Purge expired rows, then rebuild the whole cache from the database."""
        claims: _Claims[_TokenIdT, _ClaimSetT] = _Claims()
        async with self._db() as session:
            async with session.begin_nested():
                await self._delete_expired_tokens(session)
            async with session.begin_nested():
                async for record, role in await session.stream(self._update_stmt):
                    token_id, claim_set = self._from_db(record, UserRole(role))
                    claims[token_id] = claim_set
        # Swap in the fresh cache in a single assignment.
        self._claims = claims

    @cached_property
    def _update_stmt(self) -> Select[Tuple[_RecordT, str]]:
        # Token rows joined through their user to that user's role name.
        return (
            select(self._table, models.UserRole.name)
            .join_from(self._table, models.User)
            .join_from(models.User, models.UserRole)
        )

    async def _delete_expired_tokens(self, session: Any) -> None:
        now = datetime.now(timezone.utc)
        await session.execute(delete(self._table).where(self._table.expires_at < now))

    async def _run(self) -> None:
        # Alternate refresh and sleep; each step is tracked in `self._tasks`.
        while self._running:
            self._tasks.append(create_task(self._update()))
            await self._tasks[-1]
            self._tasks.pop()
            self._tasks.append(create_task(sleep(self._seconds)))
            await self._tasks[-1]
            self._tasks.pop()
325
+
326
+
327
class _PasswordResetTokenStore(
    _Store[
        PasswordResetTokenClaims,
        PasswordResetToken,
        PasswordResetTokenId,
        models.PasswordResetToken,
    ]
):
    """Store for password-reset tokens, persisted in `models.PasswordResetToken`."""

    _table = models.PasswordResetToken
    _token_id = PasswordResetTokenId
    _token = PasswordResetToken

    def _from_db(
        self,
        record: models.PasswordResetToken,
        user_role: UserRole,
    ) -> Tuple[PasswordResetTokenId, PasswordResetTokenClaims]:
        """Convert a row (plus its user's role) into an (id, claim set) pair."""
        reset_id = PasswordResetTokenId(record.id)
        claim_set = PasswordResetTokenClaims(
            token_id=reset_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=PasswordResetTokenAttributes(user_role=user_role),
        )
        return reset_id, claim_set

    def _to_db(self, claim: PasswordResetTokenClaims) -> models.PasswordResetToken:
        """Build an unsaved row from a claim set; expiry and subject are mandatory."""
        assert claim.expiration_time
        assert claim.subject
        owner_id = int(claim.subject)
        row = models.PasswordResetToken(
            user_id=owner_id,
            created_at=claim.issued_at,
            expires_at=claim.expiration_time,
        )
        return row
364
+
365
+
366
class _AccessTokenStore(
    _Store[
        AccessTokenClaims,
        AccessToken,
        AccessTokenId,
        models.AccessToken,
    ]
):
    """Store for access tokens, each linked to the refresh token that minted it."""

    _table = models.AccessToken
    _token_id = AccessTokenId
    _token = AccessToken

    def _from_db(
        self,
        record: models.AccessToken,
        user_role: UserRole,
    ) -> Tuple[AccessTokenId, AccessTokenClaims]:
        """Convert a row (plus its user's role) into an (id, claim set) pair."""
        access_id = AccessTokenId(record.id)
        parent_refresh_id = RefreshTokenId(record.refresh_token_id)
        attributes = AccessTokenAttributes(
            user_role=user_role,
            refresh_token_id=parent_refresh_id,
        )
        claim_set = AccessTokenClaims(
            token_id=access_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=attributes,
        )
        return access_id, claim_set

    def _to_db(self, claim: AccessTokenClaims) -> models.AccessToken:
        """Build an unsaved row; expiry, subject and attributes are mandatory."""
        assert claim.expiration_time
        assert claim.subject
        owner_id = int(claim.subject)
        assert claim.attributes
        parent_refresh_id = int(claim.attributes.refresh_token_id)
        row = models.AccessToken(
            user_id=owner_id,
            created_at=claim.issued_at,
            expires_at=claim.expiration_time,
            refresh_token_id=parent_refresh_id,
        )
        return row
408
+
409
+
410
class _RefreshTokenStore(
    _Store[
        RefreshTokenClaims,
        RefreshToken,
        RefreshTokenId,
        models.RefreshToken,
    ]
):
    """Store for refresh tokens; also reports the cache size to Prometheus."""

    _table = models.RefreshToken
    _token_id = RefreshTokenId
    _token = RefreshToken

    def _from_db(
        self,
        record: models.RefreshToken,
        user_role: UserRole,
    ) -> Tuple[RefreshTokenId, RefreshTokenClaims]:
        """Convert a row (plus its user's role) into an (id, claim set) pair."""
        refresh_id = RefreshTokenId(record.id)
        claim_set = RefreshTokenClaims(
            token_id=refresh_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=RefreshTokenAttributes(user_role=user_role),
        )
        return refresh_id, claim_set

    def _to_db(self, claims: RefreshTokenClaims) -> models.RefreshToken:
        """Build an unsaved row from a claim set; expiry and subject are mandatory."""
        assert claims.expiration_time
        assert claims.subject
        owner_id = int(claims.subject)
        row = models.RefreshToken(
            user_id=owner_id,
            created_at=claims.issued_at,
            expires_at=claims.expiration_time,
        )
        return row

    async def _update(self) -> None:
        """Refresh the cache, then publish its size when Prometheus is enabled."""
        await super()._update()
        if not get_env_enable_prometheus():
            return
        # Imported lazily so the metrics module is only loaded when enabled.
        from phoenix.server.prometheus import JWT_STORE_TOKENS_ACTIVE

        JWT_STORE_TOKENS_ACTIVE.set(len(self._claims._cache))
454
+
455
+
456
class _ApiKeyStore(
    _Store[
        ApiKeyClaims,
        ApiKey,
        ApiKeyId,
        models.ApiKey,
    ]
):
    """Store for named API keys; also reports the cache size to Prometheus."""

    _table = models.ApiKey
    _token_id = ApiKeyId
    _token = ApiKey

    def _from_db(
        self,
        record: models.ApiKey,
        user_role: UserRole,
    ) -> Tuple[ApiKeyId, ApiKeyClaims]:
        """Convert a row (plus its user's role) into an (id, claim set) pair."""
        key_id = ApiKeyId(record.id)
        key_attributes = ApiKeyAttributes(
            user_role=user_role,
            name=record.name,
            description=record.description,
        )
        claim_set = ApiKeyClaims(
            token_id=key_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=key_attributes,
        )
        return key_id, claim_set

    def _to_db(self, claims: ApiKeyClaims) -> models.ApiKey:
        """Build an unsaved row; attributes, a name and a subject are mandatory."""
        assert claims.attributes
        assert claims.attributes.name
        assert claims.subject
        owner_id = int(claims.subject)
        row = models.ApiKey(
            user_id=owner_id,
            name=claims.attributes.name,
            # Falsy description / expiry are stored as NULL.
            description=claims.attributes.description or None,
            created_at=claims.issued_at,
            expires_at=claims.expiration_time or None,
        )
        return row

    async def _update(self) -> None:
        """Refresh the cache, then publish its size when Prometheus is enabled."""
        await super()._update()
        if not get_env_enable_prometheus():
            return
        # Imported lazily so the metrics module is only loaded when enabled.
        from phoenix.server.prometheus import JWT_STORE_API_KEYS_ACTIVE

        JWT_STORE_API_KEYS_ACTIVE.set(len(self._claims._cache))
phoenix/server/main.py CHANGED
@@ -10,13 +10,15 @@ from time import sleep, time
10
10
  from typing import List, Optional
11
11
  from urllib.parse import urljoin
12
12
 
13
+ from fastapi_mail import ConnectionConfig
13
14
  from jinja2 import BaseLoader, Environment
14
15
  from uvicorn import Config, Server
15
16
 
16
17
  import phoenix.trace.v1 as pb
17
18
  from phoenix.config import (
18
19
  EXPORT_DIR,
19
- get_auth_settings,
20
+ get_env_access_token_expiry,
21
+ get_env_auth_settings,
20
22
  get_env_database_connection_str,
21
23
  get_env_database_schema,
22
24
  get_env_db_logging_level,
@@ -27,7 +29,16 @@ from phoenix.config import (
27
29
  get_env_log_migrations,
28
30
  get_env_logging_level,
29
31
  get_env_logging_mode,
32
+ get_env_oauth2_settings,
33
+ get_env_password_reset_token_expiry,
30
34
  get_env_port,
35
+ get_env_refresh_token_expiry,
36
+ get_env_smtp_hostname,
37
+ get_env_smtp_mail_from,
38
+ get_env_smtp_password,
39
+ get_env_smtp_port,
40
+ get_env_smtp_username,
41
+ get_env_smtp_validate_certs,
31
42
  get_pids_path,
32
43
  )
33
44
  from phoenix.core.model_schema_adapter import create_model_from_inferences
@@ -48,6 +59,7 @@ from phoenix.server.app import (
48
59
  create_engine_and_run_migrations,
49
60
  instrument_engine_if_enabled,
50
61
  )
62
+ from phoenix.server.email.sender import EMAIL_TEMPLATE_FOLDER, FastMailSender
51
63
  from phoenix.server.types import DbSessionFactory
52
64
  from phoenix.settings import Settings
53
65
  from phoenix.trace.fixtures import (
@@ -83,6 +95,7 @@ _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
83
95
  |
84
96
  | 🚀 Phoenix Server 🚀
85
97
  | Phoenix UI: {{ ui_path }}
98
+ | Authentication: {{ auth_enabled }}
86
99
  | Log traces:
87
100
  | - gRPC: {{ grpc_path }}
88
101
  | - HTTP: {{ http_path }}
@@ -92,11 +105,6 @@ _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
92
105
  {% endif -%}
93
106
  """)
94
107
 
95
- _EXPERIMENTAL_WARNING = """
96
- 🚨 WARNING: Phoenix is running in experimental mode. 🚨
97
- | Authentication enabled: {auth_enabled}
98
- """
99
-
100
108
 
101
109
  def _write_pid_file_when_ready(
102
110
  server: Server,
@@ -298,7 +306,7 @@ def main() -> None:
298
306
  reference_inferences,
299
307
  )
300
308
 
301
- authentication_enabled, secret = get_auth_settings()
309
+ authentication_enabled, secret = get_env_auth_settings()
302
310
 
303
311
  fixture_spans: List[Span] = []
304
312
  fixture_evals: List[pb.Evaluation] = []
@@ -345,9 +353,8 @@ def main() -> None:
345
353
  http_path=urljoin(root_path, "v1/traces"),
346
354
  storage=get_printable_db_url(db_connection_str),
347
355
  schema=get_env_database_schema(),
356
+ auth_enabled=authentication_enabled,
348
357
  )
349
- if authentication_enabled:
350
- msg += _EXPERIMENTAL_WARNING.format(auth_enabled=True)
351
358
  if sys.platform.startswith("win"):
352
359
  msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip()
353
360
  scaffolder_config = ScaffolderConfig(
@@ -357,6 +364,25 @@ def main() -> None:
357
364
  scaffold_datasets=scaffold_datasets,
358
365
  phoenix_url=root_path,
359
366
  )
367
+ email_sender = None
368
+ if mail_sever := get_env_smtp_hostname():
369
+ assert (mail_username := get_env_smtp_username()), "SMTP username is required"
370
+ assert (mail_password := get_env_smtp_password()), "SMTP password is required"
371
+ assert (mail_from := get_env_smtp_mail_from()), "SMTP mail_from is required"
372
+ email_sender = FastMailSender(
373
+ ConnectionConfig(
374
+ MAIL_USERNAME=mail_username,
375
+ MAIL_PASSWORD=mail_password,
376
+ MAIL_FROM=mail_from,
377
+ MAIL_SERVER=mail_sever,
378
+ MAIL_PORT=get_env_smtp_port(),
379
+ VALIDATE_CERTS=get_env_smtp_validate_certs(),
380
+ USE_CREDENTIALS=True,
381
+ MAIL_STARTTLS=True,
382
+ MAIL_SSL_TLS=False,
383
+ TEMPLATE_FOLDER=EMAIL_TEMPLATE_FOLDER,
384
+ )
385
+ )
360
386
  app = create_app(
361
387
  db=factory,
362
388
  export_path=export_path,
@@ -374,7 +400,12 @@ def main() -> None:
374
400
  startup_callbacks=[lambda: print(msg)],
375
401
  shutdown_callbacks=instrumentation_cleanups,
376
402
  secret=secret,
403
+ password_reset_token_expiry=get_env_password_reset_token_expiry(),
404
+ access_token_expiry=get_env_access_token_expiry(),
405
+ refresh_token_expiry=get_env_refresh_token_expiry(),
377
406
  scaffolder_config=scaffolder_config,
407
+ email_sender=email_sender,
408
+ oauth2_client_configs=get_env_oauth2_settings(),
378
409
  )
379
410
  server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore
380
411
  Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
phoenix/server/oauth2.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import Any, Dict, Iterable
2
+
3
+ from authlib.integrations.base_client import BaseApp
4
+ from authlib.integrations.base_client.async_app import AsyncOAuth2Mixin
5
+ from authlib.integrations.base_client.async_openid import AsyncOpenIDMixin
6
+ from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncHttpxOAuth2Client
7
+
8
+ from phoenix.config import OAuth2ClientConfig
9
+
10
+
11
class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp):  # type:ignore[misc]
    """
    An OAuth2 client class that supports OpenID Connect. Adapted from authlib's
    `StarletteOAuth2App` to be usable without integration with Starlette.

    https://github.com/lepture/authlib/blob/904d66bebd79bf39fb8814353a22bab7d3e092c4/authlib/integrations/starlette_client/apps.py#L58
    """

    # HTTP client class used for the OAuth2 exchanges (httpx-based, async).
    client_cls = AsyncHttpxOAuth2Client

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # No web-framework integration, so `framework=None`. The keyword is
        # placed after `*args` (bugbear B026); positional arguments are
        # forwarded first either way, so the call is unchanged.
        super().__init__(*args, framework=None, **kwargs)
23
+
24
+
25
class OAuth2Clients:
    """Registry of OAuth2 clients keyed by identity-provider (IdP) name."""

    def __init__(self) -> None:
        self._clients: Dict[str, OAuth2Client] = {}

    def add_client(self, config: OAuth2ClientConfig) -> None:
        """
        Build and register a client from `config`.

        Raises ValueError if a client for the same IdP name already exists.
        """
        name = config.idp_name
        if name in self._clients:
            raise ValueError(f"oauth client already registered: {name}")
        new_client = OAuth2Client(
            client_id=config.client_id,
            client_secret=config.client_secret,
            server_metadata_url=config.oidc_config_url,
            client_kwargs={"scope": "openid email profile"},
        )
        assert isinstance(new_client, OAuth2Client)
        self._clients[name] = new_client

    def get_client(self, idp_name: str) -> OAuth2Client:
        """Return the client registered for `idp_name`; raise ValueError if none."""
        found = self._clients.get(idp_name)
        if found is not None:
            return found
        raise ValueError(f"unknown or unregistered OAuth2 client: {idp_name}")

    @classmethod
    def from_configs(cls, configs: Iterable[OAuth2ClientConfig]) -> "OAuth2Clients":
        """Create a registry holding one client per config, in iteration order."""
        registry = cls()
        for cfg in configs:
            registry.add_client(cfg)
        return registry
phoenix/server/prometheus.py CHANGED
@@ -50,6 +50,26 @@ BULK_LOADER_EXCEPTIONS = Counter(
50
50
  documentation="Total count of bulk loader exceptions",
51
51
  )
52
52
 
53
# Gauge: current number of entries in the rate limiter's cache.
RATE_LIMITER_CACHE_SIZE = Gauge(
    name="rate_limiter_cache_size",
    documentation="Current size of the rate limiter cache",
)

# Counter: cumulative number of requests throttled by the rate limiter.
RATE_LIMITER_THROTTLES = Counter(
    name="rate_limiter_throttles_total",
    documentation="Total count of rate limiter throttles",
)

# Gauge: cached refresh-token count. In this release it is set from the
# refresh-token store's `_update` in phoenix/server/jwt_store.py.
JWT_STORE_TOKENS_ACTIVE = Gauge(
    name="jwt_store_tokens_active",
    documentation="Current number of refresh tokens in the JWT store",
)

# Gauge: cached API-key count, set from the API-key store's `_update` in
# phoenix/server/jwt_store.py.
JWT_STORE_API_KEYS_ACTIVE = Gauge(
    name="jwt_store_api_keys_active",
    documentation="Current number of API keys in the JWT store",
)
72
+
53
73
 
54
74
  class PrometheusMiddleware(BaseHTTPMiddleware):
55
75
  async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: