arize-phoenix 4.35.2__py3-none-any.whl → 5.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/METADATA +10 -12
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/RECORD +92 -79
- phoenix/__init__.py +86 -0
- phoenix/auth.py +275 -14
- phoenix/config.py +369 -27
- phoenix/db/alembic.ini +0 -34
- phoenix/db/engines.py +27 -10
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +112 -0
- phoenix/db/insertion/dataset.py +0 -1
- phoenix/db/insertion/types.py +1 -1
- phoenix/db/migrate.py +3 -3
- phoenix/db/migrations/env.py +0 -7
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/models.py +145 -60
- phoenix/experiments/evaluators/code_evaluators.py +9 -3
- phoenix/experiments/functions.py +1 -4
- phoenix/inferences/fixtures.py +0 -1
- phoenix/inferences/inferences.py +0 -1
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +0 -1
- phoenix/otel/settings.py +4 -4
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +32 -0
- phoenix/server/api/context.py +50 -2
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +7 -0
- phoenix/server/api/mutations/__init__.py +0 -2
- phoenix/server/api/mutations/api_key_mutations.py +104 -86
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/project_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +282 -42
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +48 -39
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +456 -0
- phoenix/server/api/routers/v1/__init__.py +38 -16
- phoenix/server/api/routers/v1/datasets.py +0 -1
- phoenix/server/api/types/ApiKey.py +11 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/User.py +48 -4
- phoenix/server/api/types/UserApiKey.py +35 -1
- phoenix/server/api/types/UserRole.py +7 -0
- phoenix/server/app.py +105 -34
- phoenix/server/bearer_auth.py +161 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +26 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +6 -0
- phoenix/server/jwt_store.py +504 -0
- phoenix/server/main.py +61 -30
- phoenix/server/oauth2.py +51 -0
- phoenix/server/prometheus.py +20 -0
- phoenix/server/rate_limiters.py +191 -0
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
- phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
- phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
- phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
- phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
- phoenix/server/telemetry.py +2 -2
- phoenix/server/templates/index.html +1 -0
- phoenix/server/types.py +157 -1
- phoenix/services.py +0 -1
- phoenix/session/client.py +7 -3
- phoenix/session/evaluation.py +0 -1
- phoenix/session/session.py +0 -1
- phoenix/settings.py +9 -0
- phoenix/trace/exporter.py +0 -1
- phoenix/trace/fixtures.py +0 -2
- phoenix/utilities/client.py +16 -0
- phoenix/utilities/logging.py +9 -1
- phoenix/utilities/re.py +3 -3
- phoenix/version.py +1 -1
- phoenix/db/migrations/future_versions/README.md +0 -4
- phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
- phoenix/db/migrations/versions/.gitignore +0 -1
- phoenix/server/api/mutations/auth.py +0 -18
- phoenix/server/api/mutations/auth_mutations.py +0 -65
- phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -103
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -31
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py
CHANGED
|
@@ -20,21 +20,23 @@ from typing import (
|
|
|
20
20
|
List,
|
|
21
21
|
NamedTuple,
|
|
22
22
|
Optional,
|
|
23
|
+
Sequence,
|
|
23
24
|
Tuple,
|
|
25
|
+
TypedDict,
|
|
24
26
|
Union,
|
|
25
27
|
cast,
|
|
26
28
|
)
|
|
27
29
|
|
|
28
30
|
import strawberry
|
|
29
|
-
from fastapi import APIRouter, FastAPI
|
|
31
|
+
from fastapi import APIRouter, Depends, FastAPI
|
|
30
32
|
from fastapi.middleware.gzip import GZipMiddleware
|
|
31
|
-
from fastapi.responses import FileResponse
|
|
32
33
|
from fastapi.utils import is_body_allowed_for_status_code
|
|
33
34
|
from sqlalchemy import select
|
|
34
35
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
35
36
|
from starlette.datastructures import State as StarletteState
|
|
36
37
|
from starlette.exceptions import HTTPException
|
|
37
38
|
from starlette.middleware import Middleware
|
|
39
|
+
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
38
40
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
39
41
|
from starlette.requests import Request
|
|
40
42
|
from starlette.responses import PlainTextResponse, Response
|
|
@@ -50,6 +52,7 @@ import phoenix.trace.v1 as pb
|
|
|
50
52
|
from phoenix.config import (
|
|
51
53
|
DEFAULT_PROJECT_NAME,
|
|
52
54
|
SERVER_DIR,
|
|
55
|
+
OAuth2ClientConfig,
|
|
53
56
|
get_env_host,
|
|
54
57
|
get_env_port,
|
|
55
58
|
server_instrumentation_is_enabled,
|
|
@@ -58,6 +61,7 @@ from phoenix.core.model_schema import Model
|
|
|
58
61
|
from phoenix.db import models
|
|
59
62
|
from phoenix.db.bulk_inserter import BulkInserter
|
|
60
63
|
from phoenix.db.engines import create_engine
|
|
64
|
+
from phoenix.db.facilitator import Facilitator
|
|
61
65
|
from phoenix.db.helpers import SupportedSQLDialect
|
|
62
66
|
from phoenix.exceptions import PhoenixMigrationError
|
|
63
67
|
from phoenix.pointcloud.umap_parameters import UMAPParameters
|
|
@@ -86,13 +90,24 @@ from phoenix.server.api.dataloaders import (
|
|
|
86
90
|
SpanProjectsDataLoader,
|
|
87
91
|
TokenCountDataLoader,
|
|
88
92
|
TraceRowIdsDataLoader,
|
|
93
|
+
UserRolesDataLoader,
|
|
94
|
+
UsersDataLoader,
|
|
95
|
+
)
|
|
96
|
+
from phoenix.server.api.routers import (
|
|
97
|
+
auth_router,
|
|
98
|
+
create_embeddings_router,
|
|
99
|
+
create_v1_router,
|
|
100
|
+
oauth2_router,
|
|
89
101
|
)
|
|
90
102
|
from phoenix.server.api.routers.v1 import REST_API_VERSION
|
|
91
|
-
from phoenix.server.api.routers.v1 import router as v1_router
|
|
92
103
|
from phoenix.server.api.schema import schema
|
|
104
|
+
from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
|
|
93
105
|
from phoenix.server.dml_event import DmlEvent
|
|
94
106
|
from phoenix.server.dml_event_handler import DmlEventHandler
|
|
107
|
+
from phoenix.server.email.types import EmailSender
|
|
95
108
|
from phoenix.server.grpc_server import GrpcServer
|
|
109
|
+
from phoenix.server.jwt_store import JwtStore
|
|
110
|
+
from phoenix.server.oauth2 import OAuth2Clients
|
|
96
111
|
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
|
|
97
112
|
from phoenix.server.types import (
|
|
98
113
|
CanGetLastUpdatedAt,
|
|
@@ -100,6 +115,7 @@ from phoenix.server.types import (
|
|
|
100
115
|
DaemonTask,
|
|
101
116
|
DbSessionFactory,
|
|
102
117
|
LastUpdatedAt,
|
|
118
|
+
TokenStore,
|
|
103
119
|
)
|
|
104
120
|
from phoenix.trace.fixtures import (
|
|
105
121
|
TracesFixture,
|
|
@@ -118,8 +134,6 @@ if TYPE_CHECKING:
|
|
|
118
134
|
from opentelemetry.trace import TracerProvider
|
|
119
135
|
|
|
120
136
|
logger = logging.getLogger(__name__)
|
|
121
|
-
logger.setLevel(logging.INFO)
|
|
122
|
-
logger.addHandler(logging.NullHandler())
|
|
123
137
|
|
|
124
138
|
router = APIRouter(include_in_schema=False)
|
|
125
139
|
|
|
@@ -137,6 +151,11 @@ ProjectName: TypeAlias = str
|
|
|
137
151
|
_Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]]
|
|
138
152
|
|
|
139
153
|
|
|
154
|
+
class OAuth2Idp(TypedDict):
|
|
155
|
+
name: str
|
|
156
|
+
displayName: str
|
|
157
|
+
|
|
158
|
+
|
|
140
159
|
class AppConfig(NamedTuple):
|
|
141
160
|
has_inferences: bool
|
|
142
161
|
""" Whether the model has inferences (e.g. a primary dataset) """
|
|
@@ -148,6 +167,7 @@ class AppConfig(NamedTuple):
|
|
|
148
167
|
web_manifest_path: Path
|
|
149
168
|
authentication_enabled: bool
|
|
150
169
|
""" Whether authentication is enabled """
|
|
170
|
+
oauth2_idps: Sequence[OAuth2Idp]
|
|
151
171
|
|
|
152
172
|
|
|
153
173
|
class Static(StaticFiles):
|
|
@@ -196,6 +216,7 @@ class Static(StaticFiles):
|
|
|
196
216
|
"is_development": self._app_config.is_development,
|
|
197
217
|
"manifest": self._web_manifest,
|
|
198
218
|
"authentication_enabled": self._app_config.authentication_enabled,
|
|
219
|
+
"oauth2_idps": self._app_config.oauth2_idps,
|
|
199
220
|
},
|
|
200
221
|
)
|
|
201
222
|
except Exception as e:
|
|
@@ -220,18 +241,6 @@ class HeadersMiddleware(BaseHTTPMiddleware):
|
|
|
220
241
|
ProjectRowId: TypeAlias = int
|
|
221
242
|
|
|
222
243
|
|
|
223
|
-
@router.get("/exports")
|
|
224
|
-
async def download_exported_file(request: Request, filename: str) -> FileResponse:
|
|
225
|
-
file = request.app.state.export_path / (filename + ".parquet")
|
|
226
|
-
if not file.is_file():
|
|
227
|
-
raise HTTPException(status_code=404)
|
|
228
|
-
return FileResponse(
|
|
229
|
-
path=file,
|
|
230
|
-
filename=file.name,
|
|
231
|
-
media_type="application/x-octet-stream",
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
|
|
235
244
|
@router.get("/arize_phoenix_version")
|
|
236
245
|
async def version() -> PlainTextResponse:
|
|
237
246
|
return PlainTextResponse(f"{phoenix.__version__}")
|
|
@@ -240,13 +249,15 @@ async def version() -> PlainTextResponse:
|
|
|
240
249
|
DB_MUTEX: Optional[asyncio.Lock] = None
|
|
241
250
|
|
|
242
251
|
|
|
243
|
-
def _db(
|
|
252
|
+
def _db(
|
|
253
|
+
engine: AsyncEngine, bypass_lock: bool = False
|
|
254
|
+
) -> Callable[[], AsyncContextManager[AsyncSession]]:
|
|
244
255
|
Session = async_sessionmaker(engine, expire_on_commit=False)
|
|
245
256
|
|
|
246
257
|
@contextlib.asynccontextmanager
|
|
247
258
|
async def factory() -> AsyncIterator[AsyncSession]:
|
|
248
259
|
async with contextlib.AsyncExitStack() as stack:
|
|
249
|
-
if DB_MUTEX:
|
|
260
|
+
if not bypass_lock and DB_MUTEX:
|
|
250
261
|
await stack.enter_async_context(DB_MUTEX)
|
|
251
262
|
yield await stack.enter_async_context(Session.begin())
|
|
252
263
|
|
|
@@ -285,9 +296,6 @@ class Scaffolder(DaemonTask):
|
|
|
285
296
|
return
|
|
286
297
|
await self.start()
|
|
287
298
|
|
|
288
|
-
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
|
|
289
|
-
await self.stop()
|
|
290
|
-
|
|
291
299
|
async def _run(self) -> None:
|
|
292
300
|
"""
|
|
293
301
|
Main entry point for Scaffolder.
|
|
@@ -376,6 +384,7 @@ def _lifespan(
|
|
|
376
384
|
db: DbSessionFactory,
|
|
377
385
|
bulk_inserter: BulkInserter,
|
|
378
386
|
dml_event_handler: DmlEventHandler,
|
|
387
|
+
token_store: Optional[TokenStore] = None,
|
|
379
388
|
tracer_provider: Optional["TracerProvider"] = None,
|
|
380
389
|
enable_prometheus: bool = False,
|
|
381
390
|
startup_callbacks: Iterable[_Callback] = (),
|
|
@@ -402,6 +411,7 @@ def _lifespan(
|
|
|
402
411
|
disabled=read_only,
|
|
403
412
|
tracer_provider=tracer_provider,
|
|
404
413
|
enable_prometheus=enable_prometheus,
|
|
414
|
+
token_store=token_store,
|
|
405
415
|
)
|
|
406
416
|
await stack.enter_async_context(grpc_server)
|
|
407
417
|
await stack.enter_async_context(dml_event_handler)
|
|
@@ -412,6 +422,8 @@ def _lifespan(
|
|
|
412
422
|
queue_evaluation=queue_evaluation,
|
|
413
423
|
)
|
|
414
424
|
await stack.enter_async_context(scaffolder)
|
|
425
|
+
if isinstance(token_store, AsyncContextManager):
|
|
426
|
+
await stack.enter_async_context(token_store)
|
|
415
427
|
yield {
|
|
416
428
|
"event_queue": dml_event_handler,
|
|
417
429
|
"enqueue": enqueue,
|
|
@@ -438,11 +450,13 @@ def create_graphql_router(
|
|
|
438
450
|
model: Model,
|
|
439
451
|
export_path: Path,
|
|
440
452
|
last_updated_at: CanGetLastUpdatedAt,
|
|
453
|
+
authentication_enabled: bool,
|
|
441
454
|
corpus: Optional[Model] = None,
|
|
442
455
|
cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
|
|
443
456
|
event_queue: CanPutItem[DmlEvent],
|
|
444
457
|
read_only: bool = False,
|
|
445
458
|
secret: Optional[str] = None,
|
|
459
|
+
token_store: Optional[TokenStore] = None,
|
|
446
460
|
) -> GraphQLRouter: # type: ignore[type-arg]
|
|
447
461
|
"""Creates the GraphQL router.
|
|
448
462
|
|
|
@@ -452,6 +466,7 @@ def create_graphql_router(
|
|
|
452
466
|
model (Model): The Model representing inferences (legacy)
|
|
453
467
|
export_path (Path): the file path to export data to for download (legacy)
|
|
454
468
|
last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
|
|
469
|
+
authentication_enabled (bool): Whether authentication is enabled.
|
|
455
470
|
event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
|
|
456
471
|
corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
|
|
457
472
|
cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
|
|
@@ -523,18 +538,23 @@ def create_graphql_router(
|
|
|
523
538
|
),
|
|
524
539
|
trace_row_ids=TraceRowIdsDataLoader(db),
|
|
525
540
|
project_by_name=ProjectByNameDataLoader(db),
|
|
541
|
+
users=UsersDataLoader(db),
|
|
542
|
+
user_roles=UserRolesDataLoader(db),
|
|
526
543
|
),
|
|
527
544
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
528
545
|
read_only=read_only,
|
|
546
|
+
auth_enabled=authentication_enabled,
|
|
529
547
|
secret=secret,
|
|
548
|
+
token_store=token_store,
|
|
530
549
|
)
|
|
531
550
|
|
|
532
551
|
return GraphQLRouter(
|
|
533
552
|
schema,
|
|
534
|
-
graphiql
|
|
553
|
+
graphql_ide="graphiql",
|
|
535
554
|
context_getter=get_context,
|
|
536
555
|
include_in_schema=False,
|
|
537
556
|
prefix="/graphql",
|
|
557
|
+
dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
|
|
538
558
|
)
|
|
539
559
|
|
|
540
560
|
|
|
@@ -542,7 +562,7 @@ def create_engine_and_run_migrations(
|
|
|
542
562
|
database_url: str,
|
|
543
563
|
) -> AsyncEngine:
|
|
544
564
|
try:
|
|
545
|
-
return create_engine(database_url)
|
|
565
|
+
return create_engine(connection_str=database_url, migrate=True, log_to_stdout=False)
|
|
546
566
|
except PhoenixMigrationError as e:
|
|
547
567
|
msg = (
|
|
548
568
|
"\n\n⚠️⚠️ Phoenix failed to migrate the database to the latest version. ⚠️⚠️\n\n"
|
|
@@ -602,10 +622,19 @@ def create_app(
|
|
|
602
622
|
startup_callbacks: Iterable[_Callback] = (),
|
|
603
623
|
shutdown_callbacks: Iterable[_Callback] = (),
|
|
604
624
|
secret: Optional[str] = None,
|
|
625
|
+
password_reset_token_expiry: Optional[timedelta] = None,
|
|
626
|
+
access_token_expiry: Optional[timedelta] = None,
|
|
627
|
+
refresh_token_expiry: Optional[timedelta] = None,
|
|
605
628
|
scaffolder_config: Optional[ScaffolderConfig] = None,
|
|
629
|
+
email_sender: Optional[EmailSender] = None,
|
|
630
|
+
oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None,
|
|
631
|
+
bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
|
|
606
632
|
) -> FastAPI:
|
|
633
|
+
logger.info(f"Server umap params: {umap_params}")
|
|
634
|
+
bulk_inserter_factory = bulk_inserter_factory or BulkInserter
|
|
607
635
|
startup_callbacks_list: List[_Callback] = list(startup_callbacks)
|
|
608
636
|
shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
|
|
637
|
+
startup_callbacks_list.append(Facilitator(db=db))
|
|
609
638
|
initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
|
|
610
639
|
()
|
|
611
640
|
if initial_spans is None
|
|
@@ -620,12 +649,22 @@ def create_app(
|
|
|
620
649
|
)
|
|
621
650
|
last_updated_at = LastUpdatedAt()
|
|
622
651
|
middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
|
|
652
|
+
if authentication_enabled and secret:
|
|
653
|
+
token_store = JwtStore(db, secret)
|
|
654
|
+
middlewares.append(
|
|
655
|
+
Middleware(
|
|
656
|
+
AuthenticationMiddleware,
|
|
657
|
+
backend=BearerTokenAuthBackend(token_store),
|
|
658
|
+
)
|
|
659
|
+
)
|
|
660
|
+
else:
|
|
661
|
+
token_store = None
|
|
623
662
|
dml_event_handler = DmlEventHandler(
|
|
624
663
|
db=db,
|
|
625
664
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
626
665
|
last_updated_at=last_updated_at,
|
|
627
666
|
)
|
|
628
|
-
bulk_inserter =
|
|
667
|
+
bulk_inserter = bulk_inserter_factory(
|
|
629
668
|
db,
|
|
630
669
|
enable_prometheus=enable_prometheus,
|
|
631
670
|
event_queue=dml_event_handler,
|
|
@@ -663,12 +702,14 @@ def create_app(
|
|
|
663
702
|
),
|
|
664
703
|
model=model,
|
|
665
704
|
corpus=corpus,
|
|
705
|
+
authentication_enabled=authentication_enabled,
|
|
666
706
|
export_path=export_path,
|
|
667
707
|
last_updated_at=last_updated_at,
|
|
668
708
|
event_queue=dml_event_handler,
|
|
669
709
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
670
710
|
read_only=read_only,
|
|
671
711
|
secret=secret,
|
|
712
|
+
token_store=token_store,
|
|
672
713
|
)
|
|
673
714
|
if enable_prometheus:
|
|
674
715
|
from phoenix.server.prometheus import PrometheusMiddleware
|
|
@@ -682,6 +723,7 @@ def create_app(
|
|
|
682
723
|
read_only=read_only,
|
|
683
724
|
bulk_inserter=bulk_inserter,
|
|
684
725
|
dml_event_handler=dml_event_handler,
|
|
726
|
+
token_store=token_store,
|
|
685
727
|
tracer_provider=tracer_provider,
|
|
686
728
|
enable_prometheus=enable_prometheus,
|
|
687
729
|
shutdown_callbacks=shutdown_callbacks_list,
|
|
@@ -695,14 +737,20 @@ def create_app(
|
|
|
695
737
|
"defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
|
|
696
738
|
},
|
|
697
739
|
)
|
|
698
|
-
app.
|
|
699
|
-
app.
|
|
700
|
-
app.include_router(v1_router)
|
|
740
|
+
app.include_router(create_v1_router(authentication_enabled))
|
|
741
|
+
app.include_router(create_embeddings_router(authentication_enabled))
|
|
701
742
|
app.include_router(router)
|
|
702
743
|
app.include_router(graphql_router)
|
|
744
|
+
if authentication_enabled:
|
|
745
|
+
app.include_router(auth_router)
|
|
746
|
+
app.include_router(oauth2_router)
|
|
703
747
|
app.add_middleware(GZipMiddleware)
|
|
704
748
|
web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json"
|
|
705
749
|
if serve_ui and web_manifest_path.is_file():
|
|
750
|
+
oauth2_idps = [
|
|
751
|
+
OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name)
|
|
752
|
+
for config in oauth2_client_configs or []
|
|
753
|
+
]
|
|
706
754
|
app.mount(
|
|
707
755
|
"/",
|
|
708
756
|
app=Static(
|
|
@@ -716,11 +764,21 @@ def create_app(
|
|
|
716
764
|
is_development=dev,
|
|
717
765
|
authentication_enabled=authentication_enabled,
|
|
718
766
|
web_manifest_path=web_manifest_path,
|
|
767
|
+
oauth2_idps=oauth2_idps,
|
|
719
768
|
),
|
|
720
769
|
),
|
|
721
770
|
name="static",
|
|
722
771
|
)
|
|
723
|
-
app =
|
|
772
|
+
app.state.read_only = read_only
|
|
773
|
+
app.state.export_path = export_path
|
|
774
|
+
app.state.password_reset_token_expiry = password_reset_token_expiry
|
|
775
|
+
app.state.access_token_expiry = access_token_expiry
|
|
776
|
+
app.state.refresh_token_expiry = refresh_token_expiry
|
|
777
|
+
app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or [])
|
|
778
|
+
app.state.db = db
|
|
779
|
+
app.state.email_sender = email_sender
|
|
780
|
+
app = _add_get_secret_method(app=app, secret=secret)
|
|
781
|
+
app = _add_get_token_store_method(app=app, token_store=token_store)
|
|
724
782
|
if tracer_provider:
|
|
725
783
|
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
|
726
784
|
|
|
@@ -730,13 +788,10 @@ def create_app(
|
|
|
730
788
|
return app
|
|
731
789
|
|
|
732
790
|
|
|
733
|
-
def
|
|
791
|
+
def _add_get_secret_method(*, app: FastAPI, secret: Optional[str]) -> FastAPI:
|
|
734
792
|
"""
|
|
735
|
-
Dynamically
|
|
736
|
-
(at the time of this writing, FastAPI does not support setting this state
|
|
737
|
-
during the creation of the app).
|
|
793
|
+
Dynamically adds a `get_secret` method to the app's `state`.
|
|
738
794
|
"""
|
|
739
|
-
app.state.db = db
|
|
740
795
|
app.state._secret = secret
|
|
741
796
|
|
|
742
797
|
def get_secret(self: StarletteState) -> str:
|
|
@@ -747,3 +802,19 @@ def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional
|
|
|
747
802
|
|
|
748
803
|
app.state.get_secret = MethodType(get_secret, app.state)
|
|
749
804
|
return app
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def _add_get_token_store_method(*, app: FastAPI, token_store: Optional[JwtStore]) -> FastAPI:
|
|
808
|
+
"""
|
|
809
|
+
Dynamically adds a `get_token_store` method to the app's `state`.
|
|
810
|
+
"""
|
|
811
|
+
app.state._token_store = token_store
|
|
812
|
+
|
|
813
|
+
def get_token_store(self: StarletteState) -> JwtStore:
|
|
814
|
+
if (token_store := self._token_store) is None:
|
|
815
|
+
raise ValueError("token store is not set on the app")
|
|
816
|
+
assert isinstance(token_store, JwtStore)
|
|
817
|
+
return token_store
|
|
818
|
+
|
|
819
|
+
app.state.get_token_store = MethodType(get_token_store, app.state)
|
|
820
|
+
return app
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from datetime import datetime, timedelta, timezone
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Awaitable,
|
|
7
|
+
Callable,
|
|
8
|
+
Optional,
|
|
9
|
+
Tuple,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
import grpc
|
|
13
|
+
from fastapi import HTTPException, Request
|
|
14
|
+
from grpc_interceptor import AsyncServerInterceptor
|
|
15
|
+
from grpc_interceptor.exceptions import Unauthenticated
|
|
16
|
+
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
|
|
17
|
+
from starlette.requests import HTTPConnection
|
|
18
|
+
from starlette.status import HTTP_401_UNAUTHORIZED
|
|
19
|
+
|
|
20
|
+
from phoenix.auth import (
|
|
21
|
+
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
|
|
22
|
+
CanReadToken,
|
|
23
|
+
ClaimSetStatus,
|
|
24
|
+
Token,
|
|
25
|
+
)
|
|
26
|
+
from phoenix.db import enums
|
|
27
|
+
from phoenix.db.enums import UserRole
|
|
28
|
+
from phoenix.db.models import User as OrmUser
|
|
29
|
+
from phoenix.server.types import (
|
|
30
|
+
AccessToken,
|
|
31
|
+
AccessTokenAttributes,
|
|
32
|
+
AccessTokenClaims,
|
|
33
|
+
ApiKeyClaims,
|
|
34
|
+
RefreshToken,
|
|
35
|
+
RefreshTokenAttributes,
|
|
36
|
+
RefreshTokenClaims,
|
|
37
|
+
TokenStore,
|
|
38
|
+
UserClaimSet,
|
|
39
|
+
UserId,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class HasTokenStore(ABC):
|
|
44
|
+
def __init__(self, token_store: CanReadToken) -> None:
|
|
45
|
+
super().__init__()
|
|
46
|
+
self._token_store = token_store
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class BearerTokenAuthBackend(HasTokenStore, AuthenticationBackend):
|
|
50
|
+
async def authenticate(
|
|
51
|
+
self,
|
|
52
|
+
conn: HTTPConnection,
|
|
53
|
+
) -> Optional[Tuple[AuthCredentials, BaseUser]]:
|
|
54
|
+
if header := conn.headers.get("Authorization"):
|
|
55
|
+
scheme, _, token = header.partition(" ")
|
|
56
|
+
if scheme.lower() != "bearer" or not token:
|
|
57
|
+
return None
|
|
58
|
+
elif access_token := conn.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME):
|
|
59
|
+
token = access_token
|
|
60
|
+
else:
|
|
61
|
+
return None
|
|
62
|
+
claims = await self._token_store.read(Token(token))
|
|
63
|
+
if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
|
|
64
|
+
return None
|
|
65
|
+
if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
|
|
66
|
+
return None
|
|
67
|
+
return AuthCredentials(), PhoenixUser(claims.subject, claims)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class PhoenixUser(BaseUser):
|
|
71
|
+
def __init__(self, user_id: UserId, claims: UserClaimSet) -> None:
|
|
72
|
+
self._user_id = user_id
|
|
73
|
+
self.claims = claims
|
|
74
|
+
assert claims.attributes
|
|
75
|
+
self._is_admin = (
|
|
76
|
+
claims.status is ClaimSetStatus.VALID
|
|
77
|
+
and claims.attributes.user_role == enums.UserRole.ADMIN
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
@cached_property
|
|
81
|
+
def is_admin(self) -> bool:
|
|
82
|
+
return self._is_admin
|
|
83
|
+
|
|
84
|
+
@cached_property
|
|
85
|
+
def identity(self) -> UserId:
|
|
86
|
+
return self._user_id
|
|
87
|
+
|
|
88
|
+
@cached_property
|
|
89
|
+
def is_authenticated(self) -> bool:
|
|
90
|
+
return True
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
|
|
94
|
+
async def intercept(
|
|
95
|
+
self,
|
|
96
|
+
method: Callable[[Any, grpc.ServicerContext], Awaitable[Any]],
|
|
97
|
+
request_or_iterator: Any,
|
|
98
|
+
context: grpc.ServicerContext,
|
|
99
|
+
method_name: str,
|
|
100
|
+
) -> Any:
|
|
101
|
+
for datum in context.invocation_metadata():
|
|
102
|
+
if datum.key.lower() == "authorization":
|
|
103
|
+
scheme, _, token = datum.value.partition(" ")
|
|
104
|
+
if scheme.lower() != "bearer" or not token:
|
|
105
|
+
break
|
|
106
|
+
claims = await self._token_store.read(Token(token))
|
|
107
|
+
if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
|
|
108
|
+
break
|
|
109
|
+
if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
|
|
110
|
+
raise Unauthenticated(details="Invalid token")
|
|
111
|
+
if claims.status is ClaimSetStatus.EXPIRED:
|
|
112
|
+
raise Unauthenticated(details="Expired token")
|
|
113
|
+
if claims.status is ClaimSetStatus.VALID:
|
|
114
|
+
return await method(request_or_iterator, context)
|
|
115
|
+
raise Unauthenticated()
|
|
116
|
+
raise Unauthenticated()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
async def is_authenticated(request: Request) -> None:
|
|
120
|
+
"""
|
|
121
|
+
Raises a 401 if the request is not authenticated.
|
|
122
|
+
"""
|
|
123
|
+
if not isinstance((user := request.user), PhoenixUser):
|
|
124
|
+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
|
125
|
+
claims = user.claims
|
|
126
|
+
if claims.status is ClaimSetStatus.EXPIRED:
|
|
127
|
+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
|
|
128
|
+
if claims.status is not ClaimSetStatus.VALID:
|
|
129
|
+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
async def create_access_and_refresh_tokens(
|
|
133
|
+
*,
|
|
134
|
+
token_store: TokenStore,
|
|
135
|
+
user: OrmUser,
|
|
136
|
+
access_token_expiry: timedelta,
|
|
137
|
+
refresh_token_expiry: timedelta,
|
|
138
|
+
) -> Tuple[AccessToken, RefreshToken]:
|
|
139
|
+
issued_at = datetime.now(timezone.utc)
|
|
140
|
+
user_id = UserId(user.id)
|
|
141
|
+
user_role = UserRole(user.role.name)
|
|
142
|
+
refresh_token_claims = RefreshTokenClaims(
|
|
143
|
+
subject=user_id,
|
|
144
|
+
issued_at=issued_at,
|
|
145
|
+
expiration_time=issued_at + refresh_token_expiry,
|
|
146
|
+
attributes=RefreshTokenAttributes(
|
|
147
|
+
user_role=user_role,
|
|
148
|
+
),
|
|
149
|
+
)
|
|
150
|
+
refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims)
|
|
151
|
+
access_token_claims = AccessTokenClaims(
|
|
152
|
+
subject=user_id,
|
|
153
|
+
issued_at=issued_at,
|
|
154
|
+
expiration_time=issued_at + access_token_expiry,
|
|
155
|
+
attributes=AccessTokenAttributes(
|
|
156
|
+
user_role=user_role,
|
|
157
|
+
refresh_token_id=refresh_token_id,
|
|
158
|
+
),
|
|
159
|
+
)
|
|
160
|
+
access_token, _ = await token_store.create_access_token(access_token_claims)
|
|
161
|
+
return access_token, refresh_token
|
|
File without changes
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from fastapi_mail import ConnectionConfig, FastMail, MessageSchema
|
|
4
|
+
|
|
5
|
+
EMAIL_TEMPLATE_FOLDER = Path(__file__).parent / "templates"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FastMailSender:
|
|
9
|
+
def __init__(self, conf: ConnectionConfig) -> None:
|
|
10
|
+
self._fm = FastMail(conf)
|
|
11
|
+
|
|
12
|
+
async def send_password_reset_email(
|
|
13
|
+
self,
|
|
14
|
+
email: str,
|
|
15
|
+
reset_url: str,
|
|
16
|
+
) -> None:
|
|
17
|
+
message = MessageSchema(
|
|
18
|
+
subject="[Phoenix] Password Reset Request",
|
|
19
|
+
recipients=[email],
|
|
20
|
+
template_body=dict(reset_url=reset_url),
|
|
21
|
+
subtype="html",
|
|
22
|
+
)
|
|
23
|
+
await self._fm.send_message(
|
|
24
|
+
message,
|
|
25
|
+
template_name="password_reset.html",
|
|
26
|
+
)
|
|
File without changes
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
<!DOCTYPE html>
|
|
2
|
+
<html>
|
|
3
|
+
<head>
|
|
4
|
+
<meta charset="UTF-8" />
|
|
5
|
+
<title>Password Reset</title>
|
|
6
|
+
</head>
|
|
7
|
+
<body>
|
|
8
|
+
<p>Hello.</p>
|
|
9
|
+
<p>
|
|
10
|
+
You have requested a password reset. Please click on the link below to
|
|
11
|
+
reset your password:
|
|
12
|
+
</p>
|
|
13
|
+
<p>
|
|
14
|
+
<a id="reset-url" href="{{ reset_url }}">Reset Password</a
|
|
15
|
+
>
|
|
16
|
+
</p>
|
|
17
|
+
<p>If you did not make this request, please contact your administrator.</p>
|
|
18
|
+
</body>
|
|
19
|
+
</html>
|
phoenix/server/grpc_server.py
CHANGED
|
@@ -12,7 +12,9 @@ from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
|
|
|
12
12
|
)
|
|
13
13
|
from typing_extensions import TypeAlias
|
|
14
14
|
|
|
15
|
+
from phoenix.auth import CanReadToken
|
|
15
16
|
from phoenix.config import get_env_grpc_port
|
|
17
|
+
from phoenix.server.bearer_auth import ApiKeyInterceptor
|
|
16
18
|
from phoenix.trace.otel import decode_otlp_span
|
|
17
19
|
from phoenix.trace.schemas import Span
|
|
18
20
|
from phoenix.utilities.project import get_project_name
|
|
@@ -52,17 +54,21 @@ class GrpcServer:
|
|
|
52
54
|
tracer_provider: Optional["TracerProvider"] = None,
|
|
53
55
|
enable_prometheus: bool = False,
|
|
54
56
|
disabled: bool = False,
|
|
57
|
+
token_store: Optional[CanReadToken] = None,
|
|
55
58
|
) -> None:
|
|
56
59
|
self._callback = callback
|
|
57
60
|
self._server: Optional[Server] = None
|
|
58
61
|
self._tracer_provider = tracer_provider
|
|
59
62
|
self._enable_prometheus = enable_prometheus
|
|
60
63
|
self._disabled = disabled
|
|
64
|
+
self._token_store = token_store
|
|
61
65
|
|
|
62
66
|
async def __aenter__(self) -> None:
|
|
63
67
|
if self._disabled:
|
|
64
68
|
return
|
|
65
69
|
interceptors: List[ServerInterceptor] = []
|
|
70
|
+
if self._token_store:
|
|
71
|
+
interceptors.append(ApiKeyInterceptor(self._token_store))
|
|
66
72
|
if self._enable_prometheus:
|
|
67
73
|
...
|
|
68
74
|
# TODO: convert to async interceptor
|