arize-phoenix 4.24.0__py3-none-any.whl → 4.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (46)
  1. {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/METADATA +12 -7
  2. {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/RECORD +46 -41
  3. phoenix/auth.py +45 -0
  4. phoenix/db/engines.py +15 -2
  5. phoenix/db/insertion/dataset.py +1 -0
  6. phoenix/db/migrate.py +21 -10
  7. phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +7 -6
  8. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +4 -12
  9. phoenix/db/models.py +1 -1
  10. phoenix/inferences/fixtures.py +1 -0
  11. phoenix/inferences/inferences.py +1 -0
  12. phoenix/metrics/__init__.py +1 -0
  13. phoenix/server/api/context.py +14 -0
  14. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  15. phoenix/server/api/mutations/__init__.py +4 -0
  16. phoenix/server/api/mutations/api_key_mutations.py +119 -0
  17. phoenix/server/api/mutations/auth.py +7 -0
  18. phoenix/server/api/mutations/user_mutations.py +89 -0
  19. phoenix/server/api/queries.py +7 -6
  20. phoenix/server/api/routers/auth.py +52 -0
  21. phoenix/server/api/routers/v1/datasets.py +1 -0
  22. phoenix/server/api/routers/v1/spans.py +1 -1
  23. phoenix/server/api/types/UserRole.py +1 -1
  24. phoenix/server/app.py +61 -9
  25. phoenix/server/main.py +24 -19
  26. phoenix/server/static/.vite/manifest.json +31 -31
  27. phoenix/server/static/assets/{components-DzA9gIHT.js → components-1Ahruijo.js} +4 -4
  28. phoenix/server/static/assets/{index-BuTlV4Gk.js → index-BEE_RWJx.js} +2 -2
  29. phoenix/server/static/assets/{pages-DzkUGFGV.js → pages-CFS6mPnW.js} +263 -220
  30. phoenix/server/static/assets/{vendor-CIqy43_9.js → vendor-aSQri0vz.js} +58 -58
  31. phoenix/server/static/assets/{vendor-arizeai-B1YgcWL8.js → vendor-arizeai-CsdcB1NH.js} +1 -1
  32. phoenix/server/static/assets/{vendor-codemirror-_bcwCA1C.js → vendor-codemirror-CYHkhs7D.js} +1 -1
  33. phoenix/server/static/assets/{vendor-recharts-C3pM_Wlg.js → vendor-recharts-B0sannek.js} +1 -1
  34. phoenix/server/types.py +12 -4
  35. phoenix/services.py +1 -0
  36. phoenix/session/client.py +1 -1
  37. phoenix/session/evaluation.py +1 -0
  38. phoenix/session/session.py +2 -1
  39. phoenix/trace/fixtures.py +37 -0
  40. phoenix/trace/langchain/instrumentor.py +1 -1
  41. phoenix/trace/llama_index/callback.py +1 -0
  42. phoenix/trace/openai/instrumentor.py +1 -0
  43. phoenix/version.py +1 -1
  44. {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/WHEEL +0 -0
  45. {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/licenses/IP_NOTICE +0 -0
  46. {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,119 @@
1
+ from datetime import datetime
2
+ from typing import Any, Dict, Optional
3
+
4
+ import jwt
5
+ import strawberry
6
+ from sqlalchemy import insert, select
7
+ from strawberry import UNSET
8
+ from strawberry.types import Info
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.context import Context
12
+ from phoenix.server.api.mutations.auth import HasSecret, IsAuthenticated
13
+ from phoenix.server.api.queries import Query
14
+ from phoenix.server.api.types.SystemApiKey import SystemApiKey
15
+
16
+
17
# Result of `ApiKeyMutationMixin.create_system_api_key`.
@strawberry.type
class CreateSystemApiKeyMutationPayload:
    jwt: str  # the signed token produced by `create_jwt`
    api_key: SystemApiKey  # metadata of the API key row that was inserted
    query: Query  # root query object; the mutation returns a fresh `Query()`
22
+
23
+
24
# Input for `ApiKeyMutationMixin.create_system_api_key`.
@strawberry.input
class CreateApiKeyInput:
    name: str
    # Optional fields default to UNSET when omitted. The mutation stores
    # `value or None`, so falsy values (e.g. an empty description) become None.
    description: Optional[str] = UNSET
    expires_at: Optional[datetime] = UNSET
29
+
30
+
31
@strawberry.type
class ApiKeyMutationMixin:
    @strawberry.mutation(permission_classes=[HasSecret, IsAuthenticated])  # type: ignore
    async def create_system_api_key(
        self, info: Info[Context, None], input: CreateApiKeyInput
    ) -> CreateSystemApiKeyMutationPayload:
        """Create an API key owned by the SYSTEM user.

        The key row is inserted into the database, then a JWT describing it is
        signed with the application secret and returned alongside the key.

        Raises:
            ValueError: if no user with the SYSTEM role exists.
        """
        # TODO(auth): safe guard against auth being disabled and secret not being set
        # The system-user lookup could be pushed into a dataloader.
        find_system_user = (
            select(models.User)
            .join(models.UserRole)
            .where(models.UserRole.name == "SYSTEM")
            .limit(1)
        )
        async with info.context.db() as session:
            owner = await session.scalar(find_system_user)
            if owner is None:
                raise ValueError("System user not found")
            row = await session.scalar(
                insert(models.APIKey)
                .values(
                    user_id=owner.id,
                    name=input.name,
                    # `or None` stores omitted / falsy input as NULL
                    description=input.description or None,
                    expires_at=input.expires_at or None,
                )
                .returning(models.APIKey)
            )
            assert row is not None

        token = create_jwt(
            secret=info.context.get_secret(),
            name=row.name,
            id=row.id,
            description=row.description,
            iat=row.created_at,
            exp=row.expires_at,
        )
        key = SystemApiKey(
            id_attr=row.id,
            name=row.name,
            description=row.description,
            created_at=row.created_at,
            expires_at=row.expires_at,
        )
        return CreateSystemApiKeyMutationPayload(jwt=token, api_key=key, query=Query())
81
+
82
+
83
def create_jwt(
    *,
    secret: str,
    algorithm: str = "HS256",
    name: str,
    description: Optional[str],
    iat: datetime,
    exp: Optional[datetime],
    id: int,
) -> str:
    """Create a signed JSON Web Token describing an API key.

    Args:
        secret (str): the secret to sign with
        algorithm (str, optional): the algorithm to use. Defaults to "HS256".
        name (str): name of the key / token
        description (Optional[str]): description of the token
        iat (datetime): when the key was issued; emitted as the `iat` claim
        exp (Optional[datetime]): when the key expires; emitted as the `exp`
            claim only when set
        id (int): the id of the key

    Returns:
        str: The encoded JWT
    """
    # Pass the datetimes through unchanged. `datetime.utcnow()` is a
    # classmethod, so calling it on an instance (`exp.utcnow()`) would discard
    # the value and stamp the current time, making an expiring token already
    # expired when issued. PyJWT converts datetime `iat`/`exp` claims to
    # NumericDate itself.
    payload: Dict[str, Any] = {
        "name": name,
        "description": description,
        "iat": iat,
        "id": id,
    }
    if exp is not None:
        payload["exp"] = exp

    # Encode the payload to create the JWT
    return jwt.encode(payload, secret, algorithm=algorithm)
@@ -9,3 +9,10 @@ class IsAuthenticated(BasePermission):
9
9
 
10
10
  async def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
11
11
  return not info.context.read_only
12
+
13
+
14
class HasSecret(BasePermission):
    """Permission that passes only when the app was configured with a secret."""

    message = "Application secret is not set"

    async def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        configured_secret = info.context.secret
        return configured_secret is not None
@@ -0,0 +1,89 @@
1
+ import asyncio
2
+ from typing import Optional
3
+
4
+ import strawberry
5
+ from sqlalchemy import insert, select
6
+ from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
7
+ from strawberry.types import Info
8
+
9
+ from phoenix.auth import compute_password_hash, validate_email_format, validate_password_format
10
+ from phoenix.db import models
11
+ from phoenix.server.api.context import Context
12
+ from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
13
+ from phoenix.server.api.types.User import User
14
+ from phoenix.server.api.types.UserRole import UserRole
15
+
16
+
17
# Input for `UserMutationMixin.create_user`.
@strawberry.input
class CreateUserInput:
    email: str  # checked with `validate_email_format` by the mutation
    username: Optional[str]
    password: str  # checked with `validate_password_format`, then hashed
    role: UserRoleInput  # its `.value` is matched against `UserRole.name`
23
+
24
+
25
# Result of `UserMutationMixin.create_user`, wrapping the newly created user.
@strawberry.type
class UserMutationPayload:
    user: User
28
+
29
+
30
@strawberry.type
class UserMutationMixin:
    @strawberry.mutation
    async def create_user(
        self,
        info: Info[Context, None],
        input: CreateUserInput,
    ) -> UserMutationPayload:
        """Create a user that logs in with a locally stored password.

        The email and password formats are validated, the password is hashed
        with the application secret as salt, and a `models.User` row is
        inserted with `auth_method="LOCAL"` and `reset_password=True`.

        Raises:
            ValueError: if the insert violates a database integrity constraint
                (e.g. a duplicate username or email); the message comes from
                `_get_user_create_error_message`. NOTE(review): the validators
                presumably raise on bad input too — confirm in `phoenix.auth`.
        """
        # Local import keeps this fix self-contained. Depending on the driver, a
        # constraint violation can surface either as the raw DB-API error
        # (imported at module level) or wrapped by SQLAlchemy (e.g. asyncpg on
        # PostgreSQL); catch both so callers get a friendly message.
        from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError

        validate_email_format(email := input.email)
        validate_password_format(password := input.password)
        role_name = input.role.value
        # Resolved by the database as part of the INSERT below.
        user_role_id = (
            select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
        )
        secret = info.context.get_secret()
        loop = asyncio.get_running_loop()
        # Hash in the default executor so the event loop is not blocked.
        password_hash = await loop.run_in_executor(
            executor=None,
            func=lambda: compute_password_hash(password=password, salt=secret),
        )
        try:
            async with info.context.db() as session:
                user = await session.scalar(
                    insert(models.User)
                    .values(
                        user_role_id=user_role_id,
                        username=input.username,
                        email=email,
                        auth_method="LOCAL",
                        password_hash=password_hash,
                        reset_password=True,
                    )
                    .returning(models.User)
                )
                assert user is not None
        except (IntegrityError, SQLAlchemyIntegrityError) as error:
            # SQLAlchemy's wrapper appends the SQL text to its message. Inspect
            # the underlying driver error instead, so column names that appear
            # in the statement are not mistaken for the violated constraint.
            cause = getattr(error, "orig", None) or error
            raise ValueError(_get_user_create_error_message(cause)) from error
        return UserMutationPayload(
            user=User(
                id_attr=user.id,
                email=user.email,
                username=user.username,
                created_at=user.created_at,
                role=UserRole(id_attr=user.user_role_id, name=role_name),
            )
        )
76
+
77
+
78
def _get_user_create_error_message(error: IntegrityError) -> str:
    """Translate a failed user insert into a message that is safe to show users.

    The offending column is found by substring search in the error text.
    SQLite words these as e.g. ``UNIQUE constraint failed: users.email``.
    NOTE(review): other backends word constraint errors differently and will
    fall through to the generic message.
    """
    text = str(error)
    known_conflicts = (
        ("users.username", "Username already exists"),
        ("users.email", "Email already exists"),
    )
    for column, message in known_conflicts:
        if column in text:
            return message
    return "Failed to create user"
@@ -92,7 +92,8 @@ class Query:
92
92
  )
93
93
  stmt = (
94
94
  select(models.User)
95
- .where(models.UserRole.role != "SYSTEM")
95
+ .join(models.UserRole)
96
+ .where(models.UserRole.name != "SYSTEM")
96
97
  .order_by(models.User.email)
97
98
  .options(joinedload(models.User.role))
98
99
  )
@@ -106,7 +107,7 @@ class Query:
106
107
  created_at=user.created_at,
107
108
  role=UserRole(
108
109
  id_attr=user.role.id,
109
- role=user.role.role,
110
+ name=user.role.name,
110
111
  ),
111
112
  )
112
113
  async for user in users
@@ -120,12 +121,12 @@ class Query:
120
121
  ) -> List[UserRole]:
121
122
  async with info.context.db() as session:
122
123
  roles = await session.scalars(
123
- select(models.UserRole).where(models.UserRole.role != "SYSTEM")
124
+ select(models.UserRole).where(models.UserRole.name != "SYSTEM")
124
125
  )
125
126
  return [
126
127
  UserRole(
127
128
  id_attr=role.id,
128
- role=role.role,
129
+ name=role.name,
129
130
  )
130
131
  for role in roles
131
132
  ]
@@ -137,7 +138,7 @@ class Query:
137
138
  select(models.APIKey)
138
139
  .join(models.User)
139
140
  .join(models.UserRole)
140
- .where(models.UserRole.role != "SYSTEM")
141
+ .where(models.UserRole.name != "SYSTEM")
141
142
  )
142
143
  async with info.context.db() as session:
143
144
  api_keys = await session.scalars(stmt)
@@ -160,7 +161,7 @@ class Query:
160
161
  select(models.APIKey)
161
162
  .join(models.User)
162
163
  .join(models.UserRole)
163
- .where(models.UserRole.role == "SYSTEM")
164
+ .where(models.UserRole.name == "SYSTEM")
164
165
  )
165
166
  async with info.context.db() as session:
166
167
  api_keys = await session.scalars(stmt)
@@ -0,0 +1,52 @@
1
+ import asyncio
2
+ from datetime import timedelta
3
+
4
+ from fastapi import APIRouter, Form, Request, Response
5
+ from sqlalchemy import select
6
+ from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED
7
+ from typing_extensions import Annotated
8
+
9
+ from phoenix.auth import is_valid_password
10
+ from phoenix.db import models
11
+
12
# Router for the login/logout endpoints; `include_in_schema=False` keeps these
# routes out of the generated OpenAPI schema.
router = APIRouter(include_in_schema=False)
13
+
14
# Cookie that carries the access token: set by `login`, cleared by `logout`.
PHOENIX_ACCESS_TOKEN_COOKIE_NAME = "phoenix-access-token"
# Cookie lifetime of 31 days, expressed in whole seconds.
PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS = int(timedelta(days=31) / timedelta(seconds=1))
16
+
17
+
18
@router.post("/login")
async def login(
    request: Request,
    email: Annotated[str, Form()],
    password: Annotated[str, Form()],
) -> Response:
    """Check email/password form credentials and set the access-token cookie.

    Returns 204 No Content with the cookie on success. Returns 401 when the
    email is unknown, the user has no password hash, or the password does not
    match.
    """
    async with request.app.state.db() as session:
        user = await session.scalar(select(models.User).where(models.User.email == email))
        # Copy the hash into a local while the session is open, so the session
        # is released before the password check below runs.
        password_hash = None if user is None else user.password_hash
    # NOTE(review): an unknown email returns here, before any password check
    # runs, so response latency may reveal whether an account exists.
    if password_hash is None:
        return Response(status_code=HTTP_401_UNAUTHORIZED)
    secret = request.app.state.get_secret()
    loop = asyncio.get_running_loop()
    # Verify in the default executor so the event loop is not blocked.
    is_valid = await loop.run_in_executor(
        executor=None,
        func=lambda: is_valid_password(password=password, salt=secret, password_hash=password_hash),
    )
    if not is_valid:
        return Response(status_code=HTTP_401_UNAUTHORIZED)
    response = Response(status_code=HTTP_204_NO_CONTENT)
    response.set_cookie(
        key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
        value="token",  # todo: compute access token
        secure=True,
        httponly=True,
        samesite="strict",
        max_age=PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS,
    )
    return response
46
+
47
+
48
@router.post("/logout")
async def logout() -> Response:
    """Clear the access-token cookie; always responds with 204 No Content."""
    reply = Response(status_code=HTTP_204_NO_CONTENT)
    reply.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
    return reply
@@ -71,6 +71,7 @@ from .utils import (
71
71
  )
72
72
 
73
73
  logger = logging.getLogger(__name__)
74
+ logger.addHandler(logging.NullHandler())
74
75
 
75
76
  DATASET_NODE_NAME = DatasetNodeType.__name__
76
77
  DATASET_VERSION_NODE_NAME = DatasetVersionNodeType.__name__
@@ -196,7 +196,7 @@ class AnnotateSpansResponseBody(ResponseBody[List[InsertedSpanAnnotation]]):
196
196
  async def annotate_spans(
197
197
  request: Request,
198
198
  request_body: AnnotateSpansRequestBody,
199
- sync: bool = Query(default=True, description="If true, fulfill request synchronously."),
199
+ sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
200
200
  ) -> AnnotateSpansResponseBody:
201
201
  if not request_body.data:
202
202
  return AnnotateSpansResponseBody(data=[])
@@ -5,4 +5,4 @@ from strawberry.relay import Node, NodeID
5
5
  @strawberry.type
6
6
  class UserRole(Node):
7
7
  id_attr: NodeID[int]
8
- role: str
8
+ name: str
phoenix/server/app.py CHANGED
@@ -4,6 +4,7 @@ import json
4
4
  import logging
5
5
  from functools import cached_property
6
6
  from pathlib import Path
7
+ from types import MethodType
7
8
  from typing import (
8
9
  TYPE_CHECKING,
9
10
  Any,
@@ -30,6 +31,7 @@ from sqlalchemy.ext.asyncio import (
30
31
  AsyncSession,
31
32
  async_sessionmaker,
32
33
  )
34
+ from starlette.datastructures import State as StarletteState
33
35
  from starlette.exceptions import HTTPException
34
36
  from starlette.middleware import Middleware
35
37
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@@ -80,6 +82,7 @@ from phoenix.server.api.dataloaders import (
80
82
  TokenCountDataLoader,
81
83
  TraceRowIdsDataLoader,
82
84
  )
85
+ from phoenix.server.api.routers.auth import router as auth_router
83
86
  from phoenix.server.api.routers.v1 import REST_API_VERSION
84
87
  from phoenix.server.api.routers.v1 import router as v1_router
85
88
  from phoenix.server.api.schema import schema
@@ -100,6 +103,7 @@ if TYPE_CHECKING:
100
103
  from opentelemetry.trace import TracerProvider
101
104
 
102
105
  logger = logging.getLogger(__name__)
106
+ logger.addHandler(logging.NullHandler())
103
107
 
104
108
  router = APIRouter(include_in_schema=False)
105
109
 
@@ -229,7 +233,8 @@ def _lifespan(
229
233
  dml_event_handler: DmlEventHandler,
230
234
  tracer_provider: Optional["TracerProvider"] = None,
231
235
  enable_prometheus: bool = False,
232
- clean_ups: Iterable[Callable[[], None]] = (),
236
+ startup_callbacks: Iterable[Callable[[], None]] = (),
237
+ shutdown_callbacks: Iterable[Callable[[], None]] = (),
233
238
  read_only: bool = False,
234
239
  ) -> StatefulLifespan[FastAPI]:
235
240
  @contextlib.asynccontextmanager
@@ -247,6 +252,8 @@ def _lifespan(
247
252
  tracer_provider=tracer_provider,
248
253
  enable_prometheus=enable_prometheus,
249
254
  ), dml_event_handler:
255
+ for callback in startup_callbacks:
256
+ callback()
250
257
  yield {
251
258
  "event_queue": dml_event_handler,
252
259
  "enqueue": enqueue,
@@ -254,8 +261,8 @@ def _lifespan(
254
261
  "queue_evaluation_for_bulk_insert": queue_evaluation,
255
262
  "enqueue_operation": enqueue_operation,
256
263
  }
257
- for clean_up in clean_ups:
258
- clean_up()
264
+ for callback in shutdown_callbacks:
265
+ callback()
259
266
 
260
267
  return lifespan
261
268
 
@@ -276,7 +283,26 @@ def create_graphql_router(
276
283
  cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
277
284
  event_queue: CanPutItem[DmlEvent],
278
285
  read_only: bool = False,
286
+ secret: Optional[str] = None,
279
287
  ) -> GraphQLRouter: # type: ignore[type-arg]
288
+ """Creates the GraphQL router.
289
+
290
+ Args:
291
+ schema (BaseSchema): The GraphQL schema.
292
+ db (DbSessionFactory): The database session factory pointing to a SQL database.
293
+ model (Model): The Model representing inferences (legacy)
294
+ export_path (Path): the file path to export data to for download (legacy)
295
+ last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
296
+ event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
297
+ corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
298
+ cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
299
+ read_only (bool, optional): Marks the app as read-only. Defaults to False.
300
+ secret (Optional[str], optional): The application secret for auth. Defaults to None.
301
+
302
+ Returns:
303
+ GraphQLRouter: The router mounted at /graphql
304
+ """
305
+
280
306
  def get_context() -> Context:
281
307
  return Context(
282
308
  db=db,
@@ -336,6 +362,7 @@ def create_graphql_router(
336
362
  ),
337
363
  cache_for_dataloaders=cache_for_dataloaders,
338
364
  read_only=read_only,
365
+ secret=secret,
339
366
  )
340
367
 
341
368
  return GraphQLRouter(
@@ -408,9 +435,12 @@ def create_app(
408
435
  initial_spans: Optional[Iterable[Union[Span, Tuple[Span, str]]]] = None,
409
436
  initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
410
437
  serve_ui: bool = True,
411
- clean_up_callbacks: List[Callable[[], None]] = [],
438
+ startup_callbacks: Iterable[Callable[[], None]] = (),
439
+ shutdown_callbacks: Iterable[Callable[[], None]] = (),
440
+ secret: Optional[str] = None,
412
441
  ) -> FastAPI:
413
- clean_ups: List[Callable[[], None]] = clean_up_callbacks # To be called at app shutdown.
442
+ startup_callbacks_list: List[Callable[[], None]] = list(startup_callbacks)
443
+ shutdown_callbacks_list: List[Callable[[], None]] = list(shutdown_callbacks)
414
444
  initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
415
445
  ()
416
446
  if initial_spans is None
@@ -472,6 +502,7 @@ def create_app(
472
502
  event_queue=dml_event_handler,
473
503
  cache_for_dataloaders=cache_for_dataloaders,
474
504
  read_only=read_only,
505
+ secret=secret,
475
506
  )
476
507
  if enable_prometheus:
477
508
  from phoenix.server.prometheus import PrometheusMiddleware
@@ -489,7 +520,8 @@ def create_app(
489
520
  dml_event_handler=dml_event_handler,
490
521
  tracer_provider=tracer_provider,
491
522
  enable_prometheus=enable_prometheus,
492
- clean_ups=clean_ups,
523
+ shutdown_callbacks=shutdown_callbacks_list,
524
+ startup_callbacks=startup_callbacks_list,
493
525
  ),
494
526
  middleware=[
495
527
  Middleware(HeadersMiddleware),
@@ -507,6 +539,8 @@ def create_app(
507
539
  app.include_router(router)
508
540
  app.include_router(graphql_router)
509
541
  app.add_middleware(GZipMiddleware)
542
+ if authentication_enabled:
543
+ app.include_router(auth_router)
510
544
  if serve_ui:
511
545
  app.mount(
512
546
  "/",
@@ -525,12 +559,30 @@ def create_app(
525
559
  ),
526
560
  name="static",
527
561
  )
528
-
529
- app.state.db = db
562
+ app = _update_app_state(app, db=db, secret=secret)
530
563
  if tracer_provider:
531
564
  from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
532
565
 
533
566
  FastAPIInstrumentor().instrument(tracer_provider=tracer_provider)
534
567
  FastAPIInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
535
- clean_ups.append(FastAPIInstrumentor().uninstrument)
568
+ shutdown_callbacks_list.append(FastAPIInstrumentor().uninstrument)
569
+ return app
570
+
571
+
572
def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional[str]) -> FastAPI:
    """
    Attach the DB session factory, the app secret, and a `get_secret()` accessor
    to `app.state`, then return the same app. FastAPI does not accept this state
    at app construction time, so it is set on the instance afterwards.
    """
    state = app.state
    state.db = db
    state._secret = secret

    def get_secret(self: StarletteState) -> str:
        stored = self._secret
        if stored is None:
            raise ValueError("app secret is not set")
        assert isinstance(stored, str)
        return stored

    # Bind to the state object so callers can use `app.state.get_secret()`.
    state.get_secret = MethodType(get_secret, state)
    return app
phoenix/server/main.py CHANGED
@@ -1,13 +1,16 @@
1
1
  import atexit
2
+ import codecs
2
3
  import logging
3
4
  import os
5
+ import sys
4
6
  from argparse import ArgumentParser
5
- from pathlib import Path, PosixPath
7
+ from importlib.metadata import version
8
+ from pathlib import Path
6
9
  from threading import Thread
7
10
  from time import sleep, time
8
11
  from typing import List, Optional
12
+ from urllib.parse import urljoin
9
13
 
10
- import pkg_resources
11
14
  from uvicorn import Config, Server
12
15
 
13
16
  import phoenix.trace.v1 as pb
@@ -53,6 +56,7 @@ from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
53
56
  from phoenix.trace.schemas import Span
54
57
 
55
58
  logger = logging.getLogger(__name__)
59
+ logger.addHandler(logging.NullHandler())
56
60
 
57
61
  _WELCOME_MESSAGE = """
58
62
 
@@ -137,6 +141,7 @@ if __name__ == "__main__":
137
141
  parser.add_argument("--debug", action="store_true")
138
142
  # Whether the app is running in a development environment
139
143
  parser.add_argument("--dev", action="store_true")
144
+ parser.add_argument("--no-ui", action="store_true")
140
145
  subparsers = parser.add_subparsers(dest="command", required=True)
141
146
  serve_parser = subparsers.add_parser("serve")
142
147
  datasets_parser = subparsers.add_parser("datasets")
@@ -218,7 +223,7 @@ if __name__ == "__main__":
218
223
  reference_inferences,
219
224
  )
220
225
 
221
- authentication_enabled, auth_secret = get_auth_settings()
226
+ authentication_enabled, secret = get_auth_settings()
222
227
 
223
228
  fixture_spans: List[Span] = []
224
229
  fixture_evals: List[pb.Evaluation] = []
@@ -255,6 +260,18 @@ if __name__ == "__main__":
255
260
  engine = create_engine_and_run_migrations(db_connection_str)
256
261
  instrumentation_cleanups = instrument_engine_if_enabled(engine)
257
262
  factory = DbSessionFactory(db=_db(engine), dialect=engine.dialect.name)
263
+ # Print information about the server
264
+ msg = _WELCOME_MESSAGE.format(
265
+ version=version("arize-phoenix"),
266
+ ui_path=urljoin(f"http://{host}:{port}", host_root_path),
267
+ grpc_path=f"http://{host}:{get_env_grpc_port()}",
268
+ http_path=urljoin(urljoin(f"http://{host}:{port}", host_root_path), "v1/traces"),
269
+ storage=get_printable_db_url(db_connection_str),
270
+ )
271
+ if authentication_enabled:
272
+ msg += _EXPERIMENTAL_WARNING.format(auth_enabled=True)
273
+ if sys.platform.startswith("win"):
274
+ msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip()
258
275
  app = create_app(
259
276
  db=factory,
260
277
  export_path=export_path,
@@ -266,29 +283,17 @@ if __name__ == "__main__":
266
283
  else create_model_from_inferences(corpus_inferences),
267
284
  debug=args.debug,
268
285
  dev=args.dev,
286
+ serve_ui=not args.no_ui,
269
287
  read_only=read_only,
270
288
  enable_prometheus=enable_prometheus,
271
289
  initial_spans=fixture_spans,
272
290
  initial_evaluations=fixture_evals,
273
- clean_up_callbacks=instrumentation_cleanups,
291
+ startup_callbacks=[lambda: print(msg)],
292
+ shutdown_callbacks=instrumentation_cleanups,
293
+ secret=secret,
274
294
  )
275
295
  server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore
276
296
  Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
277
297
 
278
- # Print information about the server
279
- phoenix_version = pkg_resources.get_distribution("arize-phoenix").version
280
- print(
281
- _WELCOME_MESSAGE.format(
282
- version=phoenix_version,
283
- ui_path=PosixPath(f"http://{host}:{port}", host_root_path),
284
- grpc_path=f"http://{host}:{get_env_grpc_port()}",
285
- http_path=PosixPath(f"http://{host}:{port}", host_root_path, "v1/traces"),
286
- storage=get_printable_db_url(db_connection_str),
287
- )
288
- )
289
-
290
- if authentication_enabled:
291
- print(_EXPERIMENTAL_WARNING.format(auth_enabled=authentication_enabled))
292
-
293
298
  # Start the server
294
299
  server.run()