arize-phoenix 4.36.0__py3-none-any.whl → 5.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/METADATA +10 -12
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/RECORD +69 -60
- phoenix/__init__.py +86 -0
- phoenix/auth.py +275 -14
- phoenix/config.py +277 -25
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +112 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/models.py +145 -60
- phoenix/experiments/evaluators/code_evaluators.py +9 -3
- phoenix/experiments/functions.py +1 -4
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +32 -0
- phoenix/server/api/context.py +50 -2
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +7 -0
- phoenix/server/api/mutations/__init__.py +0 -2
- phoenix/server/api/mutations/api_key_mutations.py +104 -86
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/project_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +282 -42
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +48 -39
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +456 -0
- phoenix/server/api/routers/v1/__init__.py +38 -16
- phoenix/server/api/types/ApiKey.py +11 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/User.py +48 -4
- phoenix/server/api/types/UserApiKey.py +35 -1
- phoenix/server/api/types/UserRole.py +7 -0
- phoenix/server/app.py +103 -31
- phoenix/server/bearer_auth.py +161 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +26 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +6 -0
- phoenix/server/jwt_store.py +504 -0
- phoenix/server/main.py +40 -9
- phoenix/server/oauth2.py +51 -0
- phoenix/server/prometheus.py +20 -0
- phoenix/server/rate_limiters.py +191 -0
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
- phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
- phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
- phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
- phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
- phoenix/server/templates/index.html +1 -0
- phoenix/server/types.py +157 -1
- phoenix/session/client.py +7 -2
- phoenix/trace/fixtures.py +24 -0
- phoenix/utilities/client.py +16 -0
- phoenix/version.py +1 -1
- phoenix/db/migrations/future_versions/README.md +0 -4
- phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
- phoenix/db/migrations/versions/.gitignore +0 -1
- phoenix/server/api/mutations/auth.py +0 -18
- phoenix/server/api/mutations/auth_mutations.py +0 -65
- phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -34
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from asyncio import create_task, gather, sleep
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from dataclasses import replace
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from functools import cached_property, singledispatchmethod
|
|
8
|
+
from typing import Any, Callable, Coroutine, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
|
9
|
+
|
|
10
|
+
from authlib.jose import jwt
|
|
11
|
+
from authlib.jose.errors import JoseError
|
|
12
|
+
from sqlalchemy import Select, delete, select
|
|
13
|
+
|
|
14
|
+
from phoenix.auth import (
|
|
15
|
+
JWT_ALGORITHM,
|
|
16
|
+
ClaimSet,
|
|
17
|
+
Token,
|
|
18
|
+
)
|
|
19
|
+
from phoenix.config import get_env_enable_prometheus
|
|
20
|
+
from phoenix.db import models
|
|
21
|
+
from phoenix.db.enums import UserRole
|
|
22
|
+
from phoenix.server.types import (
|
|
23
|
+
AccessToken,
|
|
24
|
+
AccessTokenAttributes,
|
|
25
|
+
AccessTokenClaims,
|
|
26
|
+
AccessTokenId,
|
|
27
|
+
ApiKey,
|
|
28
|
+
ApiKeyAttributes,
|
|
29
|
+
ApiKeyClaims,
|
|
30
|
+
ApiKeyId,
|
|
31
|
+
DaemonTask,
|
|
32
|
+
DbSessionFactory,
|
|
33
|
+
PasswordResetToken,
|
|
34
|
+
PasswordResetTokenAttributes,
|
|
35
|
+
PasswordResetTokenClaims,
|
|
36
|
+
PasswordResetTokenId,
|
|
37
|
+
RefreshToken,
|
|
38
|
+
RefreshTokenAttributes,
|
|
39
|
+
RefreshTokenClaims,
|
|
40
|
+
RefreshTokenId,
|
|
41
|
+
TokenId,
|
|
42
|
+
UserId,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class JwtStore:
    """Single entry point for creating, reading and revoking signed tokens.

    Owns one private store per token kind (password-reset token, access
    token, refresh token, API key). All four are built from the same
    database session factory, secret, algorithm and refresh interval, and
    every public method routes to the store that matches the token id type.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        secret: str,
        algorithm: str = JWT_ALGORITHM,
        sleep_seconds: int = 10,
        **kwargs: Any,
    ) -> None:
        """Build the four per-kind stores.

        Args:
            db: Factory used to open database sessions.
            secret: Non-empty key used to sign and decode tokens.
            algorithm: Value written to the ``alg`` header by the stores.
            sleep_seconds: Refresh interval handed to every store.
            **kwargs: Forwarded unchanged to each store's constructor.
        """
        assert secret
        # NOTE(review): no base class is declared on JwtStore in this file, so
        # non-empty kwargs would end up in object.__init__ -- confirm whether
        # callers ever pass extra keyword arguments.
        super().__init__(**kwargs)
        self._db = db
        self._secret = secret
        args = (db, secret, algorithm, sleep_seconds)
        self._password_reset_token_store = _PasswordResetTokenStore(*args, **kwargs)
        self._access_token_store = _AccessTokenStore(*args, **kwargs)
        self._refresh_token_store = _RefreshTokenStore(*args, **kwargs)
        self._api_key_store = _ApiKeyStore(*args, **kwargs)

    @cached_property
    def _stores(self) -> Tuple[DaemonTask, ...]:
        """Every `_Store` instance currently held as an instance attribute."""
        return tuple(dt for dt in self.__dict__.values() if isinstance(dt, _Store))

    async def __aenter__(self) -> None:
        """Enter all underlying stores concurrently."""
        await gather(*(s.__aenter__() for s in self._stores))

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        """Exit all underlying stores concurrently, forwarding the arguments."""
        await gather(*(s.__aexit__(*args, **kwargs) for s in self._stores))

    async def read(self, token: Token) -> Optional[ClaimSet]:
        """Return the stored claims for ``token``, or ``None``.

        ``None`` is returned when decoding raises ``JoseError``, when the
        payload carries no ``jti``, when the ``jti`` cannot be parsed into a
        token id, or when the matching store has no entry for that id.
        """
        try:
            payload = jwt.decode(
                s=token,
                key=self._secret,
            )
        except JoseError:
            return None
        if (jti := payload.get("jti")) is None:
            return None
        if (token_id := TokenId.parse(jti)) is None:
            return None
        return await self._get(token_id)

    # Lookup dispatch: the fallback handles unrecognised id types, and the
    # registered overloads are selected by the annotation on ``token_id``.
    @singledispatchmethod
    async def _get(self, _: TokenId) -> Optional[ClaimSet]:
        return None

    @_get.register
    async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
        return await self._password_reset_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
        return await self._access_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]:
        return await self._refresh_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
        return await self._api_key_store.get(token_id)

    # Cache-eviction dispatch, structured exactly like ``_get`` above.
    @singledispatchmethod
    async def _evict(self, _: TokenId) -> Optional[ClaimSet]:
        return None

    @_evict.register
    async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
        return await self._password_reset_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
        return await self._access_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]:
        return await self._refresh_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
        return await self._api_key_store.evict(token_id)

    async def create_password_reset_token(
        self,
        claim: PasswordResetTokenClaims,
    ) -> Tuple[PasswordResetToken, PasswordResetTokenId]:
        """Persist ``claim`` and return the signed token together with its id."""
        return await self._password_reset_token_store.create(claim)

    async def create_access_token(
        self,
        claim: AccessTokenClaims,
    ) -> Tuple[AccessToken, AccessTokenId]:
        """Persist ``claim`` and return the signed token together with its id."""
        return await self._access_token_store.create(claim)

    async def create_refresh_token(
        self,
        claim: RefreshTokenClaims,
    ) -> Tuple[RefreshToken, RefreshTokenId]:
        """Persist ``claim`` and return the signed token together with its id."""
        return await self._refresh_token_store.create(claim)

    async def create_api_key(
        self,
        claim: ApiKeyClaims,
    ) -> Tuple[ApiKey, ApiKeyId]:
        """Persist ``claim`` and return the signed API key together with its id."""
        return await self._api_key_store.create(claim)

    async def revoke(self, *token_ids: TokenId) -> None:
        """Revoke the given tokens, whatever mix of kinds they are.

        The ids are grouped by kind and each non-empty group is revoked by
        its own store; the per-store revocations are awaited concurrently.
        Ids of an unrecognised type are ignored.
        """
        if not token_ids:
            return
        password_reset_token_ids: List[PasswordResetTokenId] = []
        access_token_ids: List[AccessTokenId] = []
        refresh_token_ids: List[RefreshTokenId] = []
        api_key_ids: List[ApiKeyId] = []
        for token_id in token_ids:
            # One mutually exclusive chain: each id lands in at most one bucket.
            if isinstance(token_id, PasswordResetTokenId):
                password_reset_token_ids.append(token_id)
            elif isinstance(token_id, AccessTokenId):
                access_token_ids.append(token_id)
            elif isinstance(token_id, RefreshTokenId):
                refresh_token_ids.append(token_id)
            elif isinstance(token_id, ApiKeyId):
                api_key_ids.append(token_id)
        coroutines: List[Coroutine[None, None, None]] = []
        if password_reset_token_ids:
            coroutines.append(self._password_reset_token_store.revoke(*password_reset_token_ids))
        if access_token_ids:
            coroutines.append(self._access_token_store.revoke(*access_token_ids))
        if refresh_token_ids:
            coroutines.append(self._refresh_token_store.revoke(*refresh_token_ids))
        if api_key_ids:
            coroutines.append(self._api_key_store.revoke(*api_key_ids))
        await gather(*coroutines)

    async def log_out(self, user_id: UserId) -> None:
        """Delete the user's access and refresh tokens and drop them from cache.

        For each of the two token kinds the rows owned by ``user_id`` are
        deleted, and every id returned by the delete is evicted from the
        matching in-memory store.
        """
        for cls in (AccessTokenId, RefreshTokenId):
            table = cls.table
            stmt = delete(table).where(table.user_id == int(user_id)).returning(table.id)
            async with self._db() as session:
                async for id_ in await session.stream_scalars(stmt):
                    await self._evict(cls(id_))
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
_TokenT = TypeVar("_TokenT", bound=Token)
|
|
192
|
+
_TokenIdT = TypeVar("_TokenIdT", bound=TokenId)
|
|
193
|
+
_ClaimSetT = TypeVar("_ClaimSetT", bound=ClaimSet)
|
|
194
|
+
_RecordT = TypeVar(
|
|
195
|
+
"_RecordT",
|
|
196
|
+
models.PasswordResetToken,
|
|
197
|
+
models.AccessToken,
|
|
198
|
+
models.RefreshToken,
|
|
199
|
+
models.ApiKey,
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class _Claims(Generic[_TokenIdT, _ClaimSetT]):
    """In-memory mapping from token id to claim set with copy-on-access.

    Claim sets are deep-copied both when stored and when handed out, so a
    caller mutating a claim set it put in or got back does not alter the
    cached copy.
    """

    def __init__(self) -> None:
        self._cache: Dict[_TokenIdT, _ClaimSetT] = {}

    def __getitem__(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """Same as `get`: a missing id yields ``None`` instead of raising."""
        return self.get(token_id)

    def __setitem__(self, token_id: _TokenIdT, claim: _ClaimSetT) -> None:
        """Store a private deep copy of ``claim`` under ``token_id``."""
        self._cache[token_id] = deepcopy(claim)

    def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """Return a deep copy of the cached claim set, or ``None`` if absent."""
        claim = self._cache.get(token_id)
        # Compare against None explicitly so that a claim set which happens to
        # evaluate as falsy is still reported as a hit.
        return None if claim is None else deepcopy(claim)

    def pop(
        self, token_id: _TokenIdT, default: Optional[_ClaimSetT] = None
    ) -> Optional[_ClaimSetT]:
        """Remove ``token_id`` and return a deep copy of what was removed.

        When the id is not cached, a deep copy of ``default`` is returned
        (``None`` when no default is supplied).
        """
        claim = self._cache.pop(token_id, default)
        return None if claim is None else deepcopy(claim)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC):
    """Base class for a token store backed by one table plus a claims cache.

    Subclasses set ``_table``, ``_token_id`` and ``_token`` and implement the
    record <-> claim-set conversions. The store answers lookups from the
    cache, falls back to the database on a miss, and rebuilds the cache from
    the database every ``sleep_seconds`` while it is running.
    """

    # ORM class of the rows this store manages.
    _table: Type[_RecordT]
    # Wraps an integer primary key into this store's token id type.
    _token_id: Callable[[int], _TokenIdT]
    # Wraps an encoded JWT string into this store's token type.
    _token: Callable[[str], _TokenT]

    def __init__(
        self,
        db: DbSessionFactory,
        secret: str,
        algorithm: str = JWT_ALGORITHM,
        sleep_seconds: int = 10,
        **kwargs: Any,
    ) -> None:
        """Record the configuration and start with an empty cache."""
        assert secret
        super().__init__(**kwargs)
        self._db = db
        self._secret = secret
        self._algorithm = algorithm
        self._seconds = sleep_seconds
        self._claims: _Claims[_TokenIdT, _ClaimSetT] = _Claims()

    def _encode(self, claim: ClaimSet) -> str:
        """Encode a JWT whose only payload entry is the claim's token id."""
        encoded: bytes = jwt.encode(
            header={"alg": self._algorithm},
            payload={"jti": claim.token_id},
            key=self._secret,
        )
        return encoded.decode()

    async def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """Return the claims for ``token_id`` from cache or, failing that, the DB.

        A claim set loaded from the database is cached before it is returned.
        ``None`` means no matching row exists.
        """
        cached = self._claims.get(token_id)
        if cached:
            return cached
        query = self._update_stmt.where(self._table.id == int(token_id))
        async with self._db() as session:
            row = (await session.execute(query)).first()
        if not row:
            return None
        db_record, role_name = row
        _, loaded = self._from_db(db_record, UserRole(role_name))
        self._claims[token_id] = loaded
        return loaded

    async def evict(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """Drop ``token_id`` from the cache only; the database is untouched."""
        return self._claims.pop(token_id, None)

    async def revoke(self, *token_ids: _TokenIdT) -> None:
        """Evict the given ids from the cache, then delete their rows."""
        if not token_ids:
            return
        for stale_id in token_ids:
            await self.evict(stale_id)
        removal = delete(self._table).where(self._table.id.in_(map(int, token_ids)))
        async with self._db() as session:
            await session.execute(removal)

    @abstractmethod
    def _from_db(self, record: _RecordT, role: UserRole) -> Tuple[_TokenIdT, _ClaimSetT]:
        """Convert a database row and the owner's role into (token id, claims)."""
        ...

    @abstractmethod
    def _to_db(self, claims: _ClaimSetT) -> _RecordT:
        """Convert a claim set into a new, not yet persisted, database row."""
        ...

    async def create(self, claim: _ClaimSetT) -> Tuple[_TokenT, _TokenIdT]:
        """Persist ``claim``, cache it under its new id and return the token.

        The id comes from the flushed row; the returned token encodes a copy
        of the claim carrying that id.
        """
        row = self._to_db(claim)
        async with self._db() as session:
            session.add(row)
            # Flush so the row's primary key is available below.
            await session.flush()
        new_id = self._token_id(row.id)
        stamped = replace(claim, token_id=new_id)
        self._claims[new_id] = stamped
        return self._token(self._encode(stamped)), new_id

    async def _update(self) -> None:
        """Purge expired rows, then rebuild the whole cache from the database."""
        fresh: _Claims[_TokenIdT, _ClaimSetT] = _Claims()
        async with self._db() as session:
            async with session.begin_nested():
                await self._delete_expired_tokens(session)
            async with session.begin_nested():
                rows = await session.stream(self._update_stmt)
                async for db_record, role_name in rows:
                    key, value = self._from_db(db_record, UserRole(role_name))
                    fresh[key] = value
        # Swap in the new cache only once it has been fully populated.
        self._claims = fresh

    @cached_property
    def _update_stmt(self) -> Select[Tuple[_RecordT, str]]:
        """Select every token row together with the owning user's role name."""
        base = select(self._table, models.UserRole.name)
        with_user = base.join_from(self._table, models.User)
        return with_user.join_from(models.User, models.UserRole)

    async def _delete_expired_tokens(self, session: Any) -> None:
        """Delete rows whose ``expires_at`` is earlier than the current UTC time."""
        cutoff = datetime.now(timezone.utc)
        expired = delete(self._table).where(self._table.expires_at < cutoff)
        await session.execute(expired)

    async def _run(self) -> None:
        """Alternate between refreshing the cache and sleeping while running.

        Each phase runs as its own task that sits in ``self._tasks`` for as
        long as it is being awaited (presumably so the daemon machinery can
        cancel it -- confirm against DaemonTask).
        """
        phases = (self._update, lambda: sleep(self._seconds))
        while self._running:
            for start_phase in phases:
                self._tasks.append(create_task(start_phase()))
                await self._tasks[-1]
                self._tasks.pop()
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class _PasswordResetTokenStore(
    _Store[
        PasswordResetTokenClaims,
        PasswordResetToken,
        PasswordResetTokenId,
        models.PasswordResetToken,
    ]
):
    """Store specialised for password-reset tokens."""

    _table = models.PasswordResetToken
    _token_id = PasswordResetTokenId
    _token = PasswordResetToken

    def _from_db(
        self,
        record: models.PasswordResetToken,
        user_role: UserRole,
    ) -> Tuple[PasswordResetTokenId, PasswordResetTokenClaims]:
        """Build the token id and claim set described by a database row."""
        token_id = PasswordResetTokenId(record.id)
        claims = PasswordResetTokenClaims(
            token_id=token_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=PasswordResetTokenAttributes(user_role=user_role),
        )
        return token_id, claims

    def _to_db(self, claim: PasswordResetTokenClaims) -> models.PasswordResetToken:
        """Build an unsaved row; the claim must carry an expiry and a subject."""
        assert claim.expiration_time
        assert claim.subject
        return models.PasswordResetToken(
            user_id=int(claim.subject),
            created_at=claim.issued_at,
            expires_at=claim.expiration_time,
        )
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
class _AccessTokenStore(
    _Store[
        AccessTokenClaims,
        AccessToken,
        AccessTokenId,
        models.AccessToken,
    ]
):
    """Store specialised for access tokens, each linked to a refresh token."""

    _table = models.AccessToken
    _token_id = AccessTokenId
    _token = AccessToken

    def _from_db(
        self,
        record: models.AccessToken,
        user_role: UserRole,
    ) -> Tuple[AccessTokenId, AccessTokenClaims]:
        """Build the token id and claim set described by a database row."""
        token_id = AccessTokenId(record.id)
        attributes = AccessTokenAttributes(
            user_role=user_role,
            refresh_token_id=RefreshTokenId(record.refresh_token_id),
        )
        claims = AccessTokenClaims(
            token_id=token_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=attributes,
        )
        return token_id, claims

    def _to_db(self, claim: AccessTokenClaims) -> models.AccessToken:
        """Build an unsaved row.

        The claim must carry an expiry, a subject and attributes that name
        the refresh token this access token belongs to.
        """
        assert claim.expiration_time
        assert claim.subject
        owner = int(claim.subject)
        assert claim.attributes
        parent = int(claim.attributes.refresh_token_id)
        return models.AccessToken(
            user_id=owner,
            created_at=claim.issued_at,
            expires_at=claim.expiration_time,
            refresh_token_id=parent,
        )
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class _RefreshTokenStore(
    _Store[
        RefreshTokenClaims,
        RefreshToken,
        RefreshTokenId,
        models.RefreshToken,
    ]
):
    """Store specialised for refresh tokens; also reports a cache-size gauge."""

    _table = models.RefreshToken
    _token_id = RefreshTokenId
    _token = RefreshToken

    def _from_db(
        self,
        record: models.RefreshToken,
        user_role: UserRole,
    ) -> Tuple[RefreshTokenId, RefreshTokenClaims]:
        """Build the token id and claim set described by a database row."""
        token_id = RefreshTokenId(record.id)
        claims = RefreshTokenClaims(
            token_id=token_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=RefreshTokenAttributes(user_role=user_role),
        )
        return token_id, claims

    def _to_db(self, claims: RefreshTokenClaims) -> models.RefreshToken:
        """Build an unsaved row; the claims must carry an expiry and a subject."""
        assert claims.expiration_time
        assert claims.subject
        return models.RefreshToken(
            user_id=int(claims.subject),
            created_at=claims.issued_at,
            expires_at=claims.expiration_time,
        )

    async def _update(self) -> None:
        """Refresh the cache, then publish its size when Prometheus is enabled."""
        await super()._update()
        if not get_env_enable_prometheus():
            return
        # Imported here so the metrics module is only loaded when enabled.
        from phoenix.server.prometheus import JWT_STORE_TOKENS_ACTIVE

        JWT_STORE_TOKENS_ACTIVE.set(len(self._claims._cache))
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
class _ApiKeyStore(
    _Store[
        ApiKeyClaims,
        ApiKey,
        ApiKeyId,
        models.ApiKey,
    ]
):
    """Store specialised for API keys; also reports a cache-size gauge."""

    _table = models.ApiKey
    _token_id = ApiKeyId
    _token = ApiKey

    def _from_db(
        self,
        record: models.ApiKey,
        user_role: UserRole,
    ) -> Tuple[ApiKeyId, ApiKeyClaims]:
        """Build the key id and claim set described by a database row."""
        token_id = ApiKeyId(record.id)
        attributes = ApiKeyAttributes(
            user_role=user_role,
            name=record.name,
            description=record.description,
        )
        claims = ApiKeyClaims(
            token_id=token_id,
            subject=UserId(record.user_id),
            issued_at=record.created_at,
            expiration_time=record.expires_at,
            attributes=attributes,
        )
        return token_id, claims

    def _to_db(self, claims: ApiKeyClaims) -> models.ApiKey:
        """Build an unsaved row.

        The claims must carry a subject and attributes with a non-empty name.
        A falsy description or expiry is stored as ``None``.
        """
        assert claims.attributes
        assert claims.attributes.name
        assert claims.subject
        owner = int(claims.subject)
        details = claims.attributes
        return models.ApiKey(
            user_id=owner,
            name=details.name,
            description=details.description or None,
            created_at=claims.issued_at,
            expires_at=claims.expiration_time or None,
        )

    async def _update(self) -> None:
        """Refresh the cache, then publish its size when Prometheus is enabled."""
        await super()._update()
        if not get_env_enable_prometheus():
            return
        # Imported here so the metrics module is only loaded when enabled.
        from phoenix.server.prometheus import JWT_STORE_API_KEYS_ACTIVE

        JWT_STORE_API_KEYS_ACTIVE.set(len(self._claims._cache))
|
phoenix/server/main.py
CHANGED
|
@@ -10,13 +10,15 @@ from time import sleep, time
|
|
|
10
10
|
from typing import List, Optional
|
|
11
11
|
from urllib.parse import urljoin
|
|
12
12
|
|
|
13
|
+
from fastapi_mail import ConnectionConfig
|
|
13
14
|
from jinja2 import BaseLoader, Environment
|
|
14
15
|
from uvicorn import Config, Server
|
|
15
16
|
|
|
16
17
|
import phoenix.trace.v1 as pb
|
|
17
18
|
from phoenix.config import (
|
|
18
19
|
EXPORT_DIR,
|
|
19
|
-
|
|
20
|
+
get_env_access_token_expiry,
|
|
21
|
+
get_env_auth_settings,
|
|
20
22
|
get_env_database_connection_str,
|
|
21
23
|
get_env_database_schema,
|
|
22
24
|
get_env_db_logging_level,
|
|
@@ -27,7 +29,16 @@ from phoenix.config import (
|
|
|
27
29
|
get_env_log_migrations,
|
|
28
30
|
get_env_logging_level,
|
|
29
31
|
get_env_logging_mode,
|
|
32
|
+
get_env_oauth2_settings,
|
|
33
|
+
get_env_password_reset_token_expiry,
|
|
30
34
|
get_env_port,
|
|
35
|
+
get_env_refresh_token_expiry,
|
|
36
|
+
get_env_smtp_hostname,
|
|
37
|
+
get_env_smtp_mail_from,
|
|
38
|
+
get_env_smtp_password,
|
|
39
|
+
get_env_smtp_port,
|
|
40
|
+
get_env_smtp_username,
|
|
41
|
+
get_env_smtp_validate_certs,
|
|
31
42
|
get_pids_path,
|
|
32
43
|
)
|
|
33
44
|
from phoenix.core.model_schema_adapter import create_model_from_inferences
|
|
@@ -48,6 +59,7 @@ from phoenix.server.app import (
|
|
|
48
59
|
create_engine_and_run_migrations,
|
|
49
60
|
instrument_engine_if_enabled,
|
|
50
61
|
)
|
|
62
|
+
from phoenix.server.email.sender import EMAIL_TEMPLATE_FOLDER, FastMailSender
|
|
51
63
|
from phoenix.server.types import DbSessionFactory
|
|
52
64
|
from phoenix.settings import Settings
|
|
53
65
|
from phoenix.trace.fixtures import (
|
|
@@ -83,6 +95,7 @@ _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
|
|
|
83
95
|
|
|
|
84
96
|
| 🚀 Phoenix Server 🚀
|
|
85
97
|
| Phoenix UI: {{ ui_path }}
|
|
98
|
+
| Authentication: {{ auth_enabled }}
|
|
86
99
|
| Log traces:
|
|
87
100
|
| - gRPC: {{ grpc_path }}
|
|
88
101
|
| - HTTP: {{ http_path }}
|
|
@@ -92,11 +105,6 @@ _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
|
|
|
92
105
|
{% endif -%}
|
|
93
106
|
""")
|
|
94
107
|
|
|
95
|
-
_EXPERIMENTAL_WARNING = """
|
|
96
|
-
🚨 WARNING: Phoenix is running in experimental mode. 🚨
|
|
97
|
-
| Authentication enabled: {auth_enabled}
|
|
98
|
-
"""
|
|
99
|
-
|
|
100
108
|
|
|
101
109
|
def _write_pid_file_when_ready(
|
|
102
110
|
server: Server,
|
|
@@ -298,7 +306,7 @@ def main() -> None:
|
|
|
298
306
|
reference_inferences,
|
|
299
307
|
)
|
|
300
308
|
|
|
301
|
-
authentication_enabled, secret =
|
|
309
|
+
authentication_enabled, secret = get_env_auth_settings()
|
|
302
310
|
|
|
303
311
|
fixture_spans: List[Span] = []
|
|
304
312
|
fixture_evals: List[pb.Evaluation] = []
|
|
@@ -345,9 +353,8 @@ def main() -> None:
|
|
|
345
353
|
http_path=urljoin(root_path, "v1/traces"),
|
|
346
354
|
storage=get_printable_db_url(db_connection_str),
|
|
347
355
|
schema=get_env_database_schema(),
|
|
356
|
+
auth_enabled=authentication_enabled,
|
|
348
357
|
)
|
|
349
|
-
if authentication_enabled:
|
|
350
|
-
msg += _EXPERIMENTAL_WARNING.format(auth_enabled=True)
|
|
351
358
|
if sys.platform.startswith("win"):
|
|
352
359
|
msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip()
|
|
353
360
|
scaffolder_config = ScaffolderConfig(
|
|
@@ -357,6 +364,25 @@ def main() -> None:
|
|
|
357
364
|
scaffold_datasets=scaffold_datasets,
|
|
358
365
|
phoenix_url=root_path,
|
|
359
366
|
)
|
|
367
|
+
email_sender = None
|
|
368
|
+
if mail_sever := get_env_smtp_hostname():
|
|
369
|
+
assert (mail_username := get_env_smtp_username()), "SMTP username is required"
|
|
370
|
+
assert (mail_password := get_env_smtp_password()), "SMTP password is required"
|
|
371
|
+
assert (mail_from := get_env_smtp_mail_from()), "SMTP mail_from is required"
|
|
372
|
+
email_sender = FastMailSender(
|
|
373
|
+
ConnectionConfig(
|
|
374
|
+
MAIL_USERNAME=mail_username,
|
|
375
|
+
MAIL_PASSWORD=mail_password,
|
|
376
|
+
MAIL_FROM=mail_from,
|
|
377
|
+
MAIL_SERVER=mail_sever,
|
|
378
|
+
MAIL_PORT=get_env_smtp_port(),
|
|
379
|
+
VALIDATE_CERTS=get_env_smtp_validate_certs(),
|
|
380
|
+
USE_CREDENTIALS=True,
|
|
381
|
+
MAIL_STARTTLS=True,
|
|
382
|
+
MAIL_SSL_TLS=False,
|
|
383
|
+
TEMPLATE_FOLDER=EMAIL_TEMPLATE_FOLDER,
|
|
384
|
+
)
|
|
385
|
+
)
|
|
360
386
|
app = create_app(
|
|
361
387
|
db=factory,
|
|
362
388
|
export_path=export_path,
|
|
@@ -374,7 +400,12 @@ def main() -> None:
|
|
|
374
400
|
startup_callbacks=[lambda: print(msg)],
|
|
375
401
|
shutdown_callbacks=instrumentation_cleanups,
|
|
376
402
|
secret=secret,
|
|
403
|
+
password_reset_token_expiry=get_env_password_reset_token_expiry(),
|
|
404
|
+
access_token_expiry=get_env_access_token_expiry(),
|
|
405
|
+
refresh_token_expiry=get_env_refresh_token_expiry(),
|
|
377
406
|
scaffolder_config=scaffolder_config,
|
|
407
|
+
email_sender=email_sender,
|
|
408
|
+
oauth2_client_configs=get_env_oauth2_settings(),
|
|
378
409
|
)
|
|
379
410
|
server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore
|
|
380
411
|
Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
|
phoenix/server/oauth2.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import Any, Dict, Iterable
|
|
2
|
+
|
|
3
|
+
from authlib.integrations.base_client import BaseApp
|
|
4
|
+
from authlib.integrations.base_client.async_app import AsyncOAuth2Mixin
|
|
5
|
+
from authlib.integrations.base_client.async_openid import AsyncOpenIDMixin
|
|
6
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncHttpxOAuth2Client
|
|
7
|
+
|
|
8
|
+
from phoenix.config import OAuth2ClientConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp):  # type:ignore[misc]
    """
    OAuth2 client with OpenID Connect support that does not depend on a web
    framework integration. Adapted from authlib's `StarletteOAuth2App` so it
    is usable without Starlette.

    https://github.com/lepture/authlib/blob/904d66bebd79bf39fb8814353a22bab7d3e092c4/authlib/integrations/starlette_client/apps.py#L58
    """

    # HTTP client implementation used by the authlib mixins.
    client_cls = AsyncHttpxOAuth2Client

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # No framework integration is supplied; everything else is forwarded.
        super().__init__(*args, framework=None, **kwargs)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OAuth2Clients:
    """Registry of `OAuth2Client` instances keyed by identity-provider name."""

    def __init__(self) -> None:
        self._clients: Dict[str, OAuth2Client] = {}

    def add_client(self, config: OAuth2ClientConfig) -> None:
        """Create and register a client for ``config``.

        Raises:
            ValueError: If a client with the same ``idp_name`` already exists.
        """
        idp_name = config.idp_name
        if idp_name in self._clients:
            raise ValueError(f"oauth client already registered: {idp_name}")
        new_client = OAuth2Client(
            client_id=config.client_id,
            client_secret=config.client_secret,
            server_metadata_url=config.oidc_config_url,
            client_kwargs={"scope": "openid email profile"},
        )
        assert isinstance(new_client, OAuth2Client)
        self._clients[idp_name] = new_client

    def get_client(self, idp_name: str) -> OAuth2Client:
        """Return the client registered under ``idp_name``.

        Raises:
            ValueError: If no client is registered under that name.
        """
        found = self._clients.get(idp_name)
        if found is None:
            raise ValueError(f"unknown or unregistered OAuth2 client: {idp_name}")
        return found

    @classmethod
    def from_configs(cls, configs: Iterable[OAuth2ClientConfig]) -> "OAuth2Clients":
        """Build a registry holding one client per entry of ``configs``."""
        registry = cls()
        for client_config in configs:
            registry.add_client(client_config)
        return registry
|
phoenix/server/prometheus.py
CHANGED
|
@@ -50,6 +50,26 @@ BULK_LOADER_EXCEPTIONS = Counter(
|
|
|
50
50
|
documentation="Total count of bulk loader exceptions",
|
|
51
51
|
)
|
|
52
52
|
|
|
53
|
+
RATE_LIMITER_CACHE_SIZE = Gauge(
|
|
54
|
+
name="rate_limiter_cache_size",
|
|
55
|
+
documentation="Current size of the rate limiter cache",
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
RATE_LIMITER_THROTTLES = Counter(
|
|
59
|
+
name="rate_limiter_throttles_total",
|
|
60
|
+
documentation="Total count of rate limiter throttles",
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
JWT_STORE_TOKENS_ACTIVE = Gauge(
|
|
64
|
+
name="jwt_store_tokens_active",
|
|
65
|
+
documentation="Current number of refresh tokens in the JWT store",
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
JWT_STORE_API_KEYS_ACTIVE = Gauge(
|
|
69
|
+
name="jwt_store_api_keys_active",
|
|
70
|
+
documentation="Current number of API keys in the JWT store",
|
|
71
|
+
)
|
|
72
|
+
|
|
53
73
|
|
|
54
74
|
class PrometheusMiddleware(BaseHTTPMiddleware):
|
|
55
75
|
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|