arize-phoenix 4.36.0__py3-none-any.whl → 5.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/METADATA +10 -12
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/RECORD +69 -60
- phoenix/__init__.py +86 -0
- phoenix/auth.py +275 -14
- phoenix/config.py +277 -25
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +112 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/models.py +145 -60
- phoenix/experiments/evaluators/code_evaluators.py +9 -3
- phoenix/experiments/functions.py +1 -4
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +32 -0
- phoenix/server/api/context.py +50 -2
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +7 -0
- phoenix/server/api/mutations/__init__.py +0 -2
- phoenix/server/api/mutations/api_key_mutations.py +104 -86
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/project_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +282 -42
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +48 -39
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +456 -0
- phoenix/server/api/routers/v1/__init__.py +38 -16
- phoenix/server/api/types/ApiKey.py +11 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/User.py +48 -4
- phoenix/server/api/types/UserApiKey.py +35 -1
- phoenix/server/api/types/UserRole.py +7 -0
- phoenix/server/app.py +103 -31
- phoenix/server/bearer_auth.py +161 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +26 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +6 -0
- phoenix/server/jwt_store.py +504 -0
- phoenix/server/main.py +40 -9
- phoenix/server/oauth2.py +51 -0
- phoenix/server/prometheus.py +20 -0
- phoenix/server/rate_limiters.py +191 -0
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
- phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
- phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
- phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
- phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
- phoenix/server/templates/index.html +1 -0
- phoenix/server/types.py +157 -1
- phoenix/session/client.py +7 -2
- phoenix/trace/fixtures.py +24 -0
- phoenix/utilities/client.py +16 -0
- phoenix/version.py +1 -1
- phoenix/db/migrations/future_versions/README.md +0 -4
- phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
- phoenix/db/migrations/versions/.gitignore +0 -1
- phoenix/server/api/mutations/auth.py +0 -18
- phoenix/server/api/mutations/auth_mutations.py +0 -65
- phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -34
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py
CHANGED
|
@@ -20,21 +20,23 @@ from typing import (
|
|
|
20
20
|
List,
|
|
21
21
|
NamedTuple,
|
|
22
22
|
Optional,
|
|
23
|
+
Sequence,
|
|
23
24
|
Tuple,
|
|
25
|
+
TypedDict,
|
|
24
26
|
Union,
|
|
25
27
|
cast,
|
|
26
28
|
)
|
|
27
29
|
|
|
28
30
|
import strawberry
|
|
29
|
-
from fastapi import APIRouter, FastAPI
|
|
31
|
+
from fastapi import APIRouter, Depends, FastAPI
|
|
30
32
|
from fastapi.middleware.gzip import GZipMiddleware
|
|
31
|
-
from fastapi.responses import FileResponse
|
|
32
33
|
from fastapi.utils import is_body_allowed_for_status_code
|
|
33
34
|
from sqlalchemy import select
|
|
34
35
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
35
36
|
from starlette.datastructures import State as StarletteState
|
|
36
37
|
from starlette.exceptions import HTTPException
|
|
37
38
|
from starlette.middleware import Middleware
|
|
39
|
+
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
38
40
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
39
41
|
from starlette.requests import Request
|
|
40
42
|
from starlette.responses import PlainTextResponse, Response
|
|
@@ -50,6 +52,7 @@ import phoenix.trace.v1 as pb
|
|
|
50
52
|
from phoenix.config import (
|
|
51
53
|
DEFAULT_PROJECT_NAME,
|
|
52
54
|
SERVER_DIR,
|
|
55
|
+
OAuth2ClientConfig,
|
|
53
56
|
get_env_host,
|
|
54
57
|
get_env_port,
|
|
55
58
|
server_instrumentation_is_enabled,
|
|
@@ -58,6 +61,7 @@ from phoenix.core.model_schema import Model
|
|
|
58
61
|
from phoenix.db import models
|
|
59
62
|
from phoenix.db.bulk_inserter import BulkInserter
|
|
60
63
|
from phoenix.db.engines import create_engine
|
|
64
|
+
from phoenix.db.facilitator import Facilitator
|
|
61
65
|
from phoenix.db.helpers import SupportedSQLDialect
|
|
62
66
|
from phoenix.exceptions import PhoenixMigrationError
|
|
63
67
|
from phoenix.pointcloud.umap_parameters import UMAPParameters
|
|
@@ -86,13 +90,24 @@ from phoenix.server.api.dataloaders import (
|
|
|
86
90
|
SpanProjectsDataLoader,
|
|
87
91
|
TokenCountDataLoader,
|
|
88
92
|
TraceRowIdsDataLoader,
|
|
93
|
+
UserRolesDataLoader,
|
|
94
|
+
UsersDataLoader,
|
|
95
|
+
)
|
|
96
|
+
from phoenix.server.api.routers import (
|
|
97
|
+
auth_router,
|
|
98
|
+
create_embeddings_router,
|
|
99
|
+
create_v1_router,
|
|
100
|
+
oauth2_router,
|
|
89
101
|
)
|
|
90
102
|
from phoenix.server.api.routers.v1 import REST_API_VERSION
|
|
91
|
-
from phoenix.server.api.routers.v1 import router as v1_router
|
|
92
103
|
from phoenix.server.api.schema import schema
|
|
104
|
+
from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
|
|
93
105
|
from phoenix.server.dml_event import DmlEvent
|
|
94
106
|
from phoenix.server.dml_event_handler import DmlEventHandler
|
|
107
|
+
from phoenix.server.email.types import EmailSender
|
|
95
108
|
from phoenix.server.grpc_server import GrpcServer
|
|
109
|
+
from phoenix.server.jwt_store import JwtStore
|
|
110
|
+
from phoenix.server.oauth2 import OAuth2Clients
|
|
96
111
|
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
|
|
97
112
|
from phoenix.server.types import (
|
|
98
113
|
CanGetLastUpdatedAt,
|
|
@@ -100,6 +115,7 @@ from phoenix.server.types import (
|
|
|
100
115
|
DaemonTask,
|
|
101
116
|
DbSessionFactory,
|
|
102
117
|
LastUpdatedAt,
|
|
118
|
+
TokenStore,
|
|
103
119
|
)
|
|
104
120
|
from phoenix.trace.fixtures import (
|
|
105
121
|
TracesFixture,
|
|
@@ -135,6 +151,11 @@ ProjectName: TypeAlias = str
|
|
|
135
151
|
_Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]]
|
|
136
152
|
|
|
137
153
|
|
|
154
|
+
class OAuth2Idp(TypedDict):
|
|
155
|
+
name: str
|
|
156
|
+
displayName: str
|
|
157
|
+
|
|
158
|
+
|
|
138
159
|
class AppConfig(NamedTuple):
|
|
139
160
|
has_inferences: bool
|
|
140
161
|
""" Whether the model has inferences (e.g. a primary dataset) """
|
|
@@ -146,6 +167,7 @@ class AppConfig(NamedTuple):
|
|
|
146
167
|
web_manifest_path: Path
|
|
147
168
|
authentication_enabled: bool
|
|
148
169
|
""" Whether authentication is enabled """
|
|
170
|
+
oauth2_idps: Sequence[OAuth2Idp]
|
|
149
171
|
|
|
150
172
|
|
|
151
173
|
class Static(StaticFiles):
|
|
@@ -194,6 +216,7 @@ class Static(StaticFiles):
|
|
|
194
216
|
"is_development": self._app_config.is_development,
|
|
195
217
|
"manifest": self._web_manifest,
|
|
196
218
|
"authentication_enabled": self._app_config.authentication_enabled,
|
|
219
|
+
"oauth2_idps": self._app_config.oauth2_idps,
|
|
197
220
|
},
|
|
198
221
|
)
|
|
199
222
|
except Exception as e:
|
|
@@ -218,18 +241,6 @@ class HeadersMiddleware(BaseHTTPMiddleware):
|
|
|
218
241
|
ProjectRowId: TypeAlias = int
|
|
219
242
|
|
|
220
243
|
|
|
221
|
-
@router.get("/exports")
|
|
222
|
-
async def download_exported_file(request: Request, filename: str) -> FileResponse:
|
|
223
|
-
file = request.app.state.export_path / (filename + ".parquet")
|
|
224
|
-
if not file.is_file():
|
|
225
|
-
raise HTTPException(status_code=404)
|
|
226
|
-
return FileResponse(
|
|
227
|
-
path=file,
|
|
228
|
-
filename=file.name,
|
|
229
|
-
media_type="application/x-octet-stream",
|
|
230
|
-
)
|
|
231
|
-
|
|
232
|
-
|
|
233
244
|
@router.get("/arize_phoenix_version")
|
|
234
245
|
async def version() -> PlainTextResponse:
|
|
235
246
|
return PlainTextResponse(f"{phoenix.__version__}")
|
|
@@ -238,13 +249,15 @@ async def version() -> PlainTextResponse:
|
|
|
238
249
|
DB_MUTEX: Optional[asyncio.Lock] = None
|
|
239
250
|
|
|
240
251
|
|
|
241
|
-
def _db(
|
|
252
|
+
def _db(
|
|
253
|
+
engine: AsyncEngine, bypass_lock: bool = False
|
|
254
|
+
) -> Callable[[], AsyncContextManager[AsyncSession]]:
|
|
242
255
|
Session = async_sessionmaker(engine, expire_on_commit=False)
|
|
243
256
|
|
|
244
257
|
@contextlib.asynccontextmanager
|
|
245
258
|
async def factory() -> AsyncIterator[AsyncSession]:
|
|
246
259
|
async with contextlib.AsyncExitStack() as stack:
|
|
247
|
-
if DB_MUTEX:
|
|
260
|
+
if not bypass_lock and DB_MUTEX:
|
|
248
261
|
await stack.enter_async_context(DB_MUTEX)
|
|
249
262
|
yield await stack.enter_async_context(Session.begin())
|
|
250
263
|
|
|
@@ -283,9 +296,6 @@ class Scaffolder(DaemonTask):
|
|
|
283
296
|
return
|
|
284
297
|
await self.start()
|
|
285
298
|
|
|
286
|
-
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
|
|
287
|
-
await self.stop()
|
|
288
|
-
|
|
289
299
|
async def _run(self) -> None:
|
|
290
300
|
"""
|
|
291
301
|
Main entry point for Scaffolder.
|
|
@@ -374,6 +384,7 @@ def _lifespan(
|
|
|
374
384
|
db: DbSessionFactory,
|
|
375
385
|
bulk_inserter: BulkInserter,
|
|
376
386
|
dml_event_handler: DmlEventHandler,
|
|
387
|
+
token_store: Optional[TokenStore] = None,
|
|
377
388
|
tracer_provider: Optional["TracerProvider"] = None,
|
|
378
389
|
enable_prometheus: bool = False,
|
|
379
390
|
startup_callbacks: Iterable[_Callback] = (),
|
|
@@ -400,6 +411,7 @@ def _lifespan(
|
|
|
400
411
|
disabled=read_only,
|
|
401
412
|
tracer_provider=tracer_provider,
|
|
402
413
|
enable_prometheus=enable_prometheus,
|
|
414
|
+
token_store=token_store,
|
|
403
415
|
)
|
|
404
416
|
await stack.enter_async_context(grpc_server)
|
|
405
417
|
await stack.enter_async_context(dml_event_handler)
|
|
@@ -410,6 +422,8 @@ def _lifespan(
|
|
|
410
422
|
queue_evaluation=queue_evaluation,
|
|
411
423
|
)
|
|
412
424
|
await stack.enter_async_context(scaffolder)
|
|
425
|
+
if isinstance(token_store, AsyncContextManager):
|
|
426
|
+
await stack.enter_async_context(token_store)
|
|
413
427
|
yield {
|
|
414
428
|
"event_queue": dml_event_handler,
|
|
415
429
|
"enqueue": enqueue,
|
|
@@ -436,11 +450,13 @@ def create_graphql_router(
|
|
|
436
450
|
model: Model,
|
|
437
451
|
export_path: Path,
|
|
438
452
|
last_updated_at: CanGetLastUpdatedAt,
|
|
453
|
+
authentication_enabled: bool,
|
|
439
454
|
corpus: Optional[Model] = None,
|
|
440
455
|
cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
|
|
441
456
|
event_queue: CanPutItem[DmlEvent],
|
|
442
457
|
read_only: bool = False,
|
|
443
458
|
secret: Optional[str] = None,
|
|
459
|
+
token_store: Optional[TokenStore] = None,
|
|
444
460
|
) -> GraphQLRouter: # type: ignore[type-arg]
|
|
445
461
|
"""Creates the GraphQL router.
|
|
446
462
|
|
|
@@ -450,6 +466,7 @@ def create_graphql_router(
|
|
|
450
466
|
model (Model): The Model representing inferences (legacy)
|
|
451
467
|
export_path (Path): the file path to export data to for download (legacy)
|
|
452
468
|
last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
|
|
469
|
+
authentication_enabled (bool): Whether authentication is enabled.
|
|
453
470
|
event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
|
|
454
471
|
corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
|
|
455
472
|
cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
|
|
@@ -521,18 +538,23 @@ def create_graphql_router(
|
|
|
521
538
|
),
|
|
522
539
|
trace_row_ids=TraceRowIdsDataLoader(db),
|
|
523
540
|
project_by_name=ProjectByNameDataLoader(db),
|
|
541
|
+
users=UsersDataLoader(db),
|
|
542
|
+
user_roles=UserRolesDataLoader(db),
|
|
524
543
|
),
|
|
525
544
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
526
545
|
read_only=read_only,
|
|
546
|
+
auth_enabled=authentication_enabled,
|
|
527
547
|
secret=secret,
|
|
548
|
+
token_store=token_store,
|
|
528
549
|
)
|
|
529
550
|
|
|
530
551
|
return GraphQLRouter(
|
|
531
552
|
schema,
|
|
532
|
-
graphiql
|
|
553
|
+
graphql_ide="graphiql",
|
|
533
554
|
context_getter=get_context,
|
|
534
555
|
include_in_schema=False,
|
|
535
556
|
prefix="/graphql",
|
|
557
|
+
dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
|
|
536
558
|
)
|
|
537
559
|
|
|
538
560
|
|
|
@@ -600,11 +622,19 @@ def create_app(
|
|
|
600
622
|
startup_callbacks: Iterable[_Callback] = (),
|
|
601
623
|
shutdown_callbacks: Iterable[_Callback] = (),
|
|
602
624
|
secret: Optional[str] = None,
|
|
625
|
+
password_reset_token_expiry: Optional[timedelta] = None,
|
|
626
|
+
access_token_expiry: Optional[timedelta] = None,
|
|
627
|
+
refresh_token_expiry: Optional[timedelta] = None,
|
|
603
628
|
scaffolder_config: Optional[ScaffolderConfig] = None,
|
|
629
|
+
email_sender: Optional[EmailSender] = None,
|
|
630
|
+
oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None,
|
|
631
|
+
bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
|
|
604
632
|
) -> FastAPI:
|
|
605
633
|
logger.info(f"Server umap params: {umap_params}")
|
|
634
|
+
bulk_inserter_factory = bulk_inserter_factory or BulkInserter
|
|
606
635
|
startup_callbacks_list: List[_Callback] = list(startup_callbacks)
|
|
607
636
|
shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
|
|
637
|
+
startup_callbacks_list.append(Facilitator(db=db))
|
|
608
638
|
initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
|
|
609
639
|
()
|
|
610
640
|
if initial_spans is None
|
|
@@ -619,12 +649,22 @@ def create_app(
|
|
|
619
649
|
)
|
|
620
650
|
last_updated_at = LastUpdatedAt()
|
|
621
651
|
middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
|
|
652
|
+
if authentication_enabled and secret:
|
|
653
|
+
token_store = JwtStore(db, secret)
|
|
654
|
+
middlewares.append(
|
|
655
|
+
Middleware(
|
|
656
|
+
AuthenticationMiddleware,
|
|
657
|
+
backend=BearerTokenAuthBackend(token_store),
|
|
658
|
+
)
|
|
659
|
+
)
|
|
660
|
+
else:
|
|
661
|
+
token_store = None
|
|
622
662
|
dml_event_handler = DmlEventHandler(
|
|
623
663
|
db=db,
|
|
624
664
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
625
665
|
last_updated_at=last_updated_at,
|
|
626
666
|
)
|
|
627
|
-
bulk_inserter =
|
|
667
|
+
bulk_inserter = bulk_inserter_factory(
|
|
628
668
|
db,
|
|
629
669
|
enable_prometheus=enable_prometheus,
|
|
630
670
|
event_queue=dml_event_handler,
|
|
@@ -662,12 +702,14 @@ def create_app(
|
|
|
662
702
|
),
|
|
663
703
|
model=model,
|
|
664
704
|
corpus=corpus,
|
|
705
|
+
authentication_enabled=authentication_enabled,
|
|
665
706
|
export_path=export_path,
|
|
666
707
|
last_updated_at=last_updated_at,
|
|
667
708
|
event_queue=dml_event_handler,
|
|
668
709
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
669
710
|
read_only=read_only,
|
|
670
711
|
secret=secret,
|
|
712
|
+
token_store=token_store,
|
|
671
713
|
)
|
|
672
714
|
if enable_prometheus:
|
|
673
715
|
from phoenix.server.prometheus import PrometheusMiddleware
|
|
@@ -681,6 +723,7 @@ def create_app(
|
|
|
681
723
|
read_only=read_only,
|
|
682
724
|
bulk_inserter=bulk_inserter,
|
|
683
725
|
dml_event_handler=dml_event_handler,
|
|
726
|
+
token_store=token_store,
|
|
684
727
|
tracer_provider=tracer_provider,
|
|
685
728
|
enable_prometheus=enable_prometheus,
|
|
686
729
|
shutdown_callbacks=shutdown_callbacks_list,
|
|
@@ -694,14 +737,20 @@ def create_app(
|
|
|
694
737
|
"defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
|
|
695
738
|
},
|
|
696
739
|
)
|
|
697
|
-
app.
|
|
698
|
-
app.
|
|
699
|
-
app.include_router(v1_router)
|
|
740
|
+
app.include_router(create_v1_router(authentication_enabled))
|
|
741
|
+
app.include_router(create_embeddings_router(authentication_enabled))
|
|
700
742
|
app.include_router(router)
|
|
701
743
|
app.include_router(graphql_router)
|
|
744
|
+
if authentication_enabled:
|
|
745
|
+
app.include_router(auth_router)
|
|
746
|
+
app.include_router(oauth2_router)
|
|
702
747
|
app.add_middleware(GZipMiddleware)
|
|
703
748
|
web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json"
|
|
704
749
|
if serve_ui and web_manifest_path.is_file():
|
|
750
|
+
oauth2_idps = [
|
|
751
|
+
OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name)
|
|
752
|
+
for config in oauth2_client_configs or []
|
|
753
|
+
]
|
|
705
754
|
app.mount(
|
|
706
755
|
"/",
|
|
707
756
|
app=Static(
|
|
@@ -715,11 +764,21 @@ def create_app(
|
|
|
715
764
|
is_development=dev,
|
|
716
765
|
authentication_enabled=authentication_enabled,
|
|
717
766
|
web_manifest_path=web_manifest_path,
|
|
767
|
+
oauth2_idps=oauth2_idps,
|
|
718
768
|
),
|
|
719
769
|
),
|
|
720
770
|
name="static",
|
|
721
771
|
)
|
|
722
|
-
app =
|
|
772
|
+
app.state.read_only = read_only
|
|
773
|
+
app.state.export_path = export_path
|
|
774
|
+
app.state.password_reset_token_expiry = password_reset_token_expiry
|
|
775
|
+
app.state.access_token_expiry = access_token_expiry
|
|
776
|
+
app.state.refresh_token_expiry = refresh_token_expiry
|
|
777
|
+
app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or [])
|
|
778
|
+
app.state.db = db
|
|
779
|
+
app.state.email_sender = email_sender
|
|
780
|
+
app = _add_get_secret_method(app=app, secret=secret)
|
|
781
|
+
app = _add_get_token_store_method(app=app, token_store=token_store)
|
|
723
782
|
if tracer_provider:
|
|
724
783
|
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
|
725
784
|
|
|
@@ -729,13 +788,10 @@ def create_app(
|
|
|
729
788
|
return app
|
|
730
789
|
|
|
731
790
|
|
|
732
|
-
def
|
|
791
|
+
def _add_get_secret_method(*, app: FastAPI, secret: Optional[str]) -> FastAPI:
|
|
733
792
|
"""
|
|
734
|
-
Dynamically
|
|
735
|
-
(at the time of this writing, FastAPI does not support setting this state
|
|
736
|
-
during the creation of the app).
|
|
793
|
+
Dynamically adds a `get_secret` method to the app's `state`.
|
|
737
794
|
"""
|
|
738
|
-
app.state.db = db
|
|
739
795
|
app.state._secret = secret
|
|
740
796
|
|
|
741
797
|
def get_secret(self: StarletteState) -> str:
|
|
@@ -746,3 +802,19 @@ def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional
|
|
|
746
802
|
|
|
747
803
|
app.state.get_secret = MethodType(get_secret, app.state)
|
|
748
804
|
return app
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def _add_get_token_store_method(*, app: FastAPI, token_store: Optional[JwtStore]) -> FastAPI:
|
|
808
|
+
"""
|
|
809
|
+
Dynamically adds a `get_token_store` method to the app's `state`.
|
|
810
|
+
"""
|
|
811
|
+
app.state._token_store = token_store
|
|
812
|
+
|
|
813
|
+
def get_token_store(self: StarletteState) -> JwtStore:
|
|
814
|
+
if (token_store := self._token_store) is None:
|
|
815
|
+
raise ValueError("token store is not set on the app")
|
|
816
|
+
assert isinstance(token_store, JwtStore)
|
|
817
|
+
return token_store
|
|
818
|
+
|
|
819
|
+
app.state.get_token_store = MethodType(get_token_store, app.state)
|
|
820
|
+
return app
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from datetime import datetime, timedelta, timezone
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Awaitable,
|
|
7
|
+
Callable,
|
|
8
|
+
Optional,
|
|
9
|
+
Tuple,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
import grpc
|
|
13
|
+
from fastapi import HTTPException, Request
|
|
14
|
+
from grpc_interceptor import AsyncServerInterceptor
|
|
15
|
+
from grpc_interceptor.exceptions import Unauthenticated
|
|
16
|
+
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
|
|
17
|
+
from starlette.requests import HTTPConnection
|
|
18
|
+
from starlette.status import HTTP_401_UNAUTHORIZED
|
|
19
|
+
|
|
20
|
+
from phoenix.auth import (
|
|
21
|
+
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
|
|
22
|
+
CanReadToken,
|
|
23
|
+
ClaimSetStatus,
|
|
24
|
+
Token,
|
|
25
|
+
)
|
|
26
|
+
from phoenix.db import enums
|
|
27
|
+
from phoenix.db.enums import UserRole
|
|
28
|
+
from phoenix.db.models import User as OrmUser
|
|
29
|
+
from phoenix.server.types import (
|
|
30
|
+
AccessToken,
|
|
31
|
+
AccessTokenAttributes,
|
|
32
|
+
AccessTokenClaims,
|
|
33
|
+
ApiKeyClaims,
|
|
34
|
+
RefreshToken,
|
|
35
|
+
RefreshTokenAttributes,
|
|
36
|
+
RefreshTokenClaims,
|
|
37
|
+
TokenStore,
|
|
38
|
+
UserClaimSet,
|
|
39
|
+
UserId,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class HasTokenStore(ABC):
|
|
44
|
+
def __init__(self, token_store: CanReadToken) -> None:
|
|
45
|
+
super().__init__()
|
|
46
|
+
self._token_store = token_store
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class BearerTokenAuthBackend(HasTokenStore, AuthenticationBackend):
|
|
50
|
+
async def authenticate(
|
|
51
|
+
self,
|
|
52
|
+
conn: HTTPConnection,
|
|
53
|
+
) -> Optional[Tuple[AuthCredentials, BaseUser]]:
|
|
54
|
+
if header := conn.headers.get("Authorization"):
|
|
55
|
+
scheme, _, token = header.partition(" ")
|
|
56
|
+
if scheme.lower() != "bearer" or not token:
|
|
57
|
+
return None
|
|
58
|
+
elif access_token := conn.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME):
|
|
59
|
+
token = access_token
|
|
60
|
+
else:
|
|
61
|
+
return None
|
|
62
|
+
claims = await self._token_store.read(Token(token))
|
|
63
|
+
if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
|
|
64
|
+
return None
|
|
65
|
+
if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
|
|
66
|
+
return None
|
|
67
|
+
return AuthCredentials(), PhoenixUser(claims.subject, claims)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class PhoenixUser(BaseUser):
|
|
71
|
+
def __init__(self, user_id: UserId, claims: UserClaimSet) -> None:
|
|
72
|
+
self._user_id = user_id
|
|
73
|
+
self.claims = claims
|
|
74
|
+
assert claims.attributes
|
|
75
|
+
self._is_admin = (
|
|
76
|
+
claims.status is ClaimSetStatus.VALID
|
|
77
|
+
and claims.attributes.user_role == enums.UserRole.ADMIN
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
@cached_property
|
|
81
|
+
def is_admin(self) -> bool:
|
|
82
|
+
return self._is_admin
|
|
83
|
+
|
|
84
|
+
@cached_property
|
|
85
|
+
def identity(self) -> UserId:
|
|
86
|
+
return self._user_id
|
|
87
|
+
|
|
88
|
+
@cached_property
|
|
89
|
+
def is_authenticated(self) -> bool:
|
|
90
|
+
return True
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
|
|
94
|
+
async def intercept(
|
|
95
|
+
self,
|
|
96
|
+
method: Callable[[Any, grpc.ServicerContext], Awaitable[Any]],
|
|
97
|
+
request_or_iterator: Any,
|
|
98
|
+
context: grpc.ServicerContext,
|
|
99
|
+
method_name: str,
|
|
100
|
+
) -> Any:
|
|
101
|
+
for datum in context.invocation_metadata():
|
|
102
|
+
if datum.key.lower() == "authorization":
|
|
103
|
+
scheme, _, token = datum.value.partition(" ")
|
|
104
|
+
if scheme.lower() != "bearer" or not token:
|
|
105
|
+
break
|
|
106
|
+
claims = await self._token_store.read(Token(token))
|
|
107
|
+
if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
|
|
108
|
+
break
|
|
109
|
+
if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
|
|
110
|
+
raise Unauthenticated(details="Invalid token")
|
|
111
|
+
if claims.status is ClaimSetStatus.EXPIRED:
|
|
112
|
+
raise Unauthenticated(details="Expired token")
|
|
113
|
+
if claims.status is ClaimSetStatus.VALID:
|
|
114
|
+
return await method(request_or_iterator, context)
|
|
115
|
+
raise Unauthenticated()
|
|
116
|
+
raise Unauthenticated()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
async def is_authenticated(request: Request) -> None:
|
|
120
|
+
"""
|
|
121
|
+
Raises a 401 if the request is not authenticated.
|
|
122
|
+
"""
|
|
123
|
+
if not isinstance((user := request.user), PhoenixUser):
|
|
124
|
+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
|
125
|
+
claims = user.claims
|
|
126
|
+
if claims.status is ClaimSetStatus.EXPIRED:
|
|
127
|
+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
|
|
128
|
+
if claims.status is not ClaimSetStatus.VALID:
|
|
129
|
+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
async def create_access_and_refresh_tokens(
|
|
133
|
+
*,
|
|
134
|
+
token_store: TokenStore,
|
|
135
|
+
user: OrmUser,
|
|
136
|
+
access_token_expiry: timedelta,
|
|
137
|
+
refresh_token_expiry: timedelta,
|
|
138
|
+
) -> Tuple[AccessToken, RefreshToken]:
|
|
139
|
+
issued_at = datetime.now(timezone.utc)
|
|
140
|
+
user_id = UserId(user.id)
|
|
141
|
+
user_role = UserRole(user.role.name)
|
|
142
|
+
refresh_token_claims = RefreshTokenClaims(
|
|
143
|
+
subject=user_id,
|
|
144
|
+
issued_at=issued_at,
|
|
145
|
+
expiration_time=issued_at + refresh_token_expiry,
|
|
146
|
+
attributes=RefreshTokenAttributes(
|
|
147
|
+
user_role=user_role,
|
|
148
|
+
),
|
|
149
|
+
)
|
|
150
|
+
refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims)
|
|
151
|
+
access_token_claims = AccessTokenClaims(
|
|
152
|
+
subject=user_id,
|
|
153
|
+
issued_at=issued_at,
|
|
154
|
+
expiration_time=issued_at + access_token_expiry,
|
|
155
|
+
attributes=AccessTokenAttributes(
|
|
156
|
+
user_role=user_role,
|
|
157
|
+
refresh_token_id=refresh_token_id,
|
|
158
|
+
),
|
|
159
|
+
)
|
|
160
|
+
access_token, _ = await token_store.create_access_token(access_token_claims)
|
|
161
|
+
return access_token, refresh_token
|
|
File without changes
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from fastapi_mail import ConnectionConfig, FastMail, MessageSchema
|
|
4
|
+
|
|
5
|
+
EMAIL_TEMPLATE_FOLDER = Path(__file__).parent / "templates"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FastMailSender:
|
|
9
|
+
def __init__(self, conf: ConnectionConfig) -> None:
|
|
10
|
+
self._fm = FastMail(conf)
|
|
11
|
+
|
|
12
|
+
async def send_password_reset_email(
|
|
13
|
+
self,
|
|
14
|
+
email: str,
|
|
15
|
+
reset_url: str,
|
|
16
|
+
) -> None:
|
|
17
|
+
message = MessageSchema(
|
|
18
|
+
subject="[Phoenix] Password Reset Request",
|
|
19
|
+
recipients=[email],
|
|
20
|
+
template_body=dict(reset_url=reset_url),
|
|
21
|
+
subtype="html",
|
|
22
|
+
)
|
|
23
|
+
await self._fm.send_message(
|
|
24
|
+
message,
|
|
25
|
+
template_name="password_reset.html",
|
|
26
|
+
)
|
|
File without changes
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
<!DOCTYPE html>
|
|
2
|
+
<html>
|
|
3
|
+
<head>
|
|
4
|
+
<meta charset="UTF-8" />
|
|
5
|
+
<title>Password Reset</title>
|
|
6
|
+
</head>
|
|
7
|
+
<body>
|
|
8
|
+
<p>Hello.</p>
|
|
9
|
+
<p>
|
|
10
|
+
You have requested a password reset. Please click on the link below to
|
|
11
|
+
reset your password:
|
|
12
|
+
</p>
|
|
13
|
+
<p>
|
|
14
|
+
<a id="reset-url" href="{{ reset_url }}">Reset Password</a
|
|
15
|
+
>
|
|
16
|
+
</p>
|
|
17
|
+
<p>If you did not make this request, please contact your administrator.</p>
|
|
18
|
+
</body>
|
|
19
|
+
</html>
|
phoenix/server/grpc_server.py
CHANGED
|
@@ -12,7 +12,9 @@ from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
|
|
|
12
12
|
)
|
|
13
13
|
from typing_extensions import TypeAlias
|
|
14
14
|
|
|
15
|
+
from phoenix.auth import CanReadToken
|
|
15
16
|
from phoenix.config import get_env_grpc_port
|
|
17
|
+
from phoenix.server.bearer_auth import ApiKeyInterceptor
|
|
16
18
|
from phoenix.trace.otel import decode_otlp_span
|
|
17
19
|
from phoenix.trace.schemas import Span
|
|
18
20
|
from phoenix.utilities.project import get_project_name
|
|
@@ -52,17 +54,21 @@ class GrpcServer:
|
|
|
52
54
|
tracer_provider: Optional["TracerProvider"] = None,
|
|
53
55
|
enable_prometheus: bool = False,
|
|
54
56
|
disabled: bool = False,
|
|
57
|
+
token_store: Optional[CanReadToken] = None,
|
|
55
58
|
) -> None:
|
|
56
59
|
self._callback = callback
|
|
57
60
|
self._server: Optional[Server] = None
|
|
58
61
|
self._tracer_provider = tracer_provider
|
|
59
62
|
self._enable_prometheus = enable_prometheus
|
|
60
63
|
self._disabled = disabled
|
|
64
|
+
self._token_store = token_store
|
|
61
65
|
|
|
62
66
|
async def __aenter__(self) -> None:
|
|
63
67
|
if self._disabled:
|
|
64
68
|
return
|
|
65
69
|
interceptors: List[ServerInterceptor] = []
|
|
70
|
+
if self._token_store:
|
|
71
|
+
interceptors.append(ApiKeyInterceptor(self._token_store))
|
|
66
72
|
if self._enable_prometheus:
|
|
67
73
|
...
|
|
68
74
|
# TODO: convert to async interceptor
|