arize-phoenix 4.24.0__py3-none-any.whl → 4.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/METADATA +12 -7
- {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/RECORD +46 -41
- phoenix/auth.py +45 -0
- phoenix/db/engines.py +15 -2
- phoenix/db/insertion/dataset.py +1 -0
- phoenix/db/migrate.py +21 -10
- phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +7 -6
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +4 -12
- phoenix/db/models.py +1 -1
- phoenix/inferences/fixtures.py +1 -0
- phoenix/inferences/inferences.py +1 -0
- phoenix/metrics/__init__.py +1 -0
- phoenix/server/api/context.py +14 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +4 -0
- phoenix/server/api/mutations/api_key_mutations.py +119 -0
- phoenix/server/api/mutations/auth.py +7 -0
- phoenix/server/api/mutations/user_mutations.py +89 -0
- phoenix/server/api/queries.py +7 -6
- phoenix/server/api/routers/auth.py +52 -0
- phoenix/server/api/routers/v1/datasets.py +1 -0
- phoenix/server/api/routers/v1/spans.py +1 -1
- phoenix/server/api/types/UserRole.py +1 -1
- phoenix/server/app.py +61 -9
- phoenix/server/main.py +24 -19
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-DzA9gIHT.js → components-1Ahruijo.js} +4 -4
- phoenix/server/static/assets/{index-BuTlV4Gk.js → index-BEE_RWJx.js} +2 -2
- phoenix/server/static/assets/{pages-DzkUGFGV.js → pages-CFS6mPnW.js} +263 -220
- phoenix/server/static/assets/{vendor-CIqy43_9.js → vendor-aSQri0vz.js} +58 -58
- phoenix/server/static/assets/{vendor-arizeai-B1YgcWL8.js → vendor-arizeai-CsdcB1NH.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-_bcwCA1C.js → vendor-codemirror-CYHkhs7D.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-C3pM_Wlg.js → vendor-recharts-B0sannek.js} +1 -1
- phoenix/server/types.py +12 -4
- phoenix/services.py +1 -0
- phoenix/session/client.py +1 -1
- phoenix/session/evaluation.py +1 -0
- phoenix/session/session.py +2 -1
- phoenix/trace/fixtures.py +37 -0
- phoenix/trace/langchain/instrumentor.py +1 -1
- phoenix/trace/llama_index/callback.py +1 -0
- phoenix/trace/openai/instrumentor.py +1 -0
- phoenix/version.py +1 -1
- {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.24.0.dist-info → arize_phoenix-4.26.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import jwt
|
|
5
|
+
import strawberry
|
|
6
|
+
from sqlalchemy import insert, select
|
|
7
|
+
from strawberry import UNSET
|
|
8
|
+
from strawberry.types import Info
|
|
9
|
+
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
from phoenix.server.api.context import Context
|
|
12
|
+
from phoenix.server.api.mutations.auth import HasSecret, IsAuthenticated
|
|
13
|
+
from phoenix.server.api.queries import Query
|
|
14
|
+
from phoenix.server.api.types.SystemApiKey import SystemApiKey
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@strawberry.type
class CreateSystemApiKeyMutationPayload:
    # Result object returned by the `create_system_api_key` mutation.

    # The signed token string handed back to the caller.
    jwt: str
    # GraphQL representation of the API key row that was just inserted.
    api_key: SystemApiKey
    # Root query object returned alongside the mutation result.
    query: Query
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@strawberry.input
class CreateApiKeyInput:
    # Caller-supplied fields for a new API key.

    # Name stored on the key.
    name: str
    # Optional free-form description; UNSET when the caller omits it.
    description: Optional[str] = UNSET
    # Optional expiry timestamp; UNSET when the caller omits it.
    expires_at: Optional[datetime] = UNSET
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@strawberry.type
class ApiKeyMutationMixin:
    @strawberry.mutation(permission_classes=[HasSecret, IsAuthenticated])  # type: ignore
    async def create_system_api_key(
        self, info: Info[Context, None], input: CreateApiKeyInput
    ) -> CreateSystemApiKeyMutationPayload:
        # Inserts an API key owned by the SYSTEM user and returns it together
        # with a JWT signed with the application secret.
        # TODO(auth): safe guard against auth being disabled and secret not being set

        # Look up the system user - note this could be pushed into a dataloader
        system_user_query = (
            select(models.User)
            .join(models.UserRole)
            .where(models.UserRole.name == "SYSTEM")
            .limit(1)
        )
        async with info.context.db() as session:
            system_user = await session.scalar(system_user_query)
            if system_user is None:
                raise ValueError("System user not found")

            # Falsy inputs (including UNSET) are stored as NULL.
            new_key_values = dict(
                user_id=system_user.id,
                name=input.name,
                description=input.description or None,
                expires_at=input.expires_at or None,
            )
            api_key = await session.scalar(
                insert(models.APIKey).values(**new_key_values).returning(models.APIKey)
            )
            assert api_key is not None

        # Sign a token that mirrors the stored key's attributes.
        signed_token = create_jwt(
            secret=info.context.get_secret(),
            name=api_key.name,
            id=api_key.id,
            description=api_key.description,
            iat=api_key.created_at,
            exp=api_key.expires_at,
        )
        gql_api_key = SystemApiKey(
            id_attr=api_key.id,
            name=api_key.name,
            description=api_key.description,
            created_at=api_key.created_at,
            expires_at=api_key.expires_at,
        )
        return CreateSystemApiKeyMutationPayload(
            jwt=signed_token,
            api_key=gql_api_key,
            query=Query(),
        )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def create_jwt(
    *,
    secret: str,
    algorithm: str = "HS256",
    name: str,
    description: Optional[str],
    iat: datetime,
    exp: Optional[datetime],
    id: int,
) -> str:
    """Create a signed JSON Web Token for authentication

    Args:
        secret (str): the secret to sign with
        name (str): name of the key / token
        description (Optional[str]): description of the token
        iat (datetime): the issued at time
        exp (Optional[datetime]): the expiry, if set
        id (int): the id of the key
        algorithm (str, optional): the algorithm to use. Defaults to "HS256".

    Returns:
        str: The encoded JWT
    """
    payload: Dict[str, Any] = {
        "name": name,
        "description": description,
        # Use the supplied timestamp as-is. `datetime.utcnow` is a classmethod,
        # so calling it on an instance (`iat.utcnow()`) returns the current
        # time and silently discards the value that was passed in.
        "iat": iat,
        "id": id,
    }
    if exp is not None:
        # Same reasoning: `exp.utcnow()` would stamp the token with "now",
        # i.e. issue it already expired.
        payload["exp"] = exp

    # Encode the payload to create the JWT
    return jwt.encode(payload, secret, algorithm=algorithm)
|
|
@@ -9,3 +9,10 @@ class IsAuthenticated(BasePermission):
|
|
|
9
9
|
|
|
10
10
|
async def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
|
|
11
11
|
return not info.context.read_only
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HasSecret(BasePermission):
    # Permission that passes only when the context carries an application secret.
    message = "Application secret is not set"

    async def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        configured_secret = info.context.secret
        return configured_secret is not None
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import insert, select
|
|
6
|
+
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
|
|
7
|
+
from strawberry.types import Info
|
|
8
|
+
|
|
9
|
+
from phoenix.auth import compute_password_hash, validate_email_format, validate_password_format
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
from phoenix.server.api.context import Context
|
|
12
|
+
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
|
|
13
|
+
from phoenix.server.api.types.User import User
|
|
14
|
+
from phoenix.server.api.types.UserRole import UserRole
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@strawberry.input
class CreateUserInput:
    # Caller-supplied fields for the `create_user` mutation.

    # Email address of the new user.
    email: str
    # Optional username.
    username: Optional[str]
    # Plain-text password; hashed by the mutation before it is stored.
    password: str
    # Role to assign; its `.value` is matched against the role name in the database.
    role: UserRoleInput
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@strawberry.type
class UserMutationPayload:
    # Result object returned by user mutations.

    # GraphQL representation of the affected user.
    user: User
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@strawberry.type
class UserMutationMixin:
    @strawberry.mutation
    async def create_user(
        self,
        info: Info[Context, None],
        input: CreateUserInput,
    ) -> UserMutationPayload:
        # Creates a locally-authenticated user with the requested role.
        #
        # Raises ValueError with a user-facing message when the insert violates
        # a database integrity constraint (e.g. duplicate username or email).

        # Imported locally so this block stays self-contained. SQLAlchemy wraps
        # driver-level constraint violations in its own IntegrityError, which
        # is unrelated to the sqlean class imported at module level.
        from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError

        validate_email_format(email := input.email)
        validate_password_format(password := input.password)
        role_name = input.role.value
        # Resolved by the database as part of the INSERT below.
        user_role_id = (
            select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
        )
        secret = info.context.get_secret()
        loop = asyncio.get_running_loop()
        # Hashing is run in the default executor rather than on the event loop.
        password_hash = await loop.run_in_executor(
            executor=None,
            func=lambda: compute_password_hash(password=password, salt=secret),
        )
        try:
            async with info.context.db() as session:
                user = await session.scalar(
                    insert(models.User)
                    .values(
                        user_role_id=user_role_id,
                        username=input.username,
                        email=email,
                        auth_method="LOCAL",
                        password_hash=password_hash,
                        reset_password=True,
                    )
                    .returning(models.User)
                )
                assert user is not None
        except (IntegrityError, SQLAlchemyIntegrityError) as error:
            # Catch both the raw driver error and SQLAlchemy's wrapper so a
            # constraint violation always surfaces as a friendly message, and
            # keep the original exception as the cause for debugging.
            raise ValueError(_get_user_create_error_message(error)) from error
        return UserMutationPayload(
            user=User(
                id_attr=user.id,
                email=user.email,
                username=user.username,
                created_at=user.created_at,
                role=UserRole(id_attr=user.user_role_id, name=role_name),
            )
        )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_user_create_error_message(error: IntegrityError) -> str:
|
|
79
|
+
"""
|
|
80
|
+
Gets a user-facing error message to explain why user creation failed.
|
|
81
|
+
"""
|
|
82
|
+
original_error_message = str(error)
|
|
83
|
+
username_already_exists = "users.username" in original_error_message
|
|
84
|
+
email_already_exists = "users.email" in original_error_message
|
|
85
|
+
if username_already_exists:
|
|
86
|
+
return "Username already exists"
|
|
87
|
+
elif email_already_exists:
|
|
88
|
+
return "Email already exists"
|
|
89
|
+
return "Failed to create user"
|
phoenix/server/api/queries.py
CHANGED
|
@@ -92,7 +92,8 @@ class Query:
|
|
|
92
92
|
)
|
|
93
93
|
stmt = (
|
|
94
94
|
select(models.User)
|
|
95
|
-
.
|
|
95
|
+
.join(models.UserRole)
|
|
96
|
+
.where(models.UserRole.name != "SYSTEM")
|
|
96
97
|
.order_by(models.User.email)
|
|
97
98
|
.options(joinedload(models.User.role))
|
|
98
99
|
)
|
|
@@ -106,7 +107,7 @@ class Query:
|
|
|
106
107
|
created_at=user.created_at,
|
|
107
108
|
role=UserRole(
|
|
108
109
|
id_attr=user.role.id,
|
|
109
|
-
|
|
110
|
+
name=user.role.name,
|
|
110
111
|
),
|
|
111
112
|
)
|
|
112
113
|
async for user in users
|
|
@@ -120,12 +121,12 @@ class Query:
|
|
|
120
121
|
) -> List[UserRole]:
|
|
121
122
|
async with info.context.db() as session:
|
|
122
123
|
roles = await session.scalars(
|
|
123
|
-
select(models.UserRole).where(models.UserRole.
|
|
124
|
+
select(models.UserRole).where(models.UserRole.name != "SYSTEM")
|
|
124
125
|
)
|
|
125
126
|
return [
|
|
126
127
|
UserRole(
|
|
127
128
|
id_attr=role.id,
|
|
128
|
-
|
|
129
|
+
name=role.name,
|
|
129
130
|
)
|
|
130
131
|
for role in roles
|
|
131
132
|
]
|
|
@@ -137,7 +138,7 @@ class Query:
|
|
|
137
138
|
select(models.APIKey)
|
|
138
139
|
.join(models.User)
|
|
139
140
|
.join(models.UserRole)
|
|
140
|
-
.where(models.UserRole.
|
|
141
|
+
.where(models.UserRole.name != "SYSTEM")
|
|
141
142
|
)
|
|
142
143
|
async with info.context.db() as session:
|
|
143
144
|
api_keys = await session.scalars(stmt)
|
|
@@ -160,7 +161,7 @@ class Query:
|
|
|
160
161
|
select(models.APIKey)
|
|
161
162
|
.join(models.User)
|
|
162
163
|
.join(models.UserRole)
|
|
163
|
-
.where(models.UserRole.
|
|
164
|
+
.where(models.UserRole.name == "SYSTEM")
|
|
164
165
|
)
|
|
165
166
|
async with info.context.db() as session:
|
|
166
167
|
api_keys = await session.scalars(stmt)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from datetime import timedelta
|
|
3
|
+
|
|
4
|
+
from fastapi import APIRouter, Form, Request, Response
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED
|
|
7
|
+
from typing_extensions import Annotated
|
|
8
|
+
|
|
9
|
+
from phoenix.auth import is_valid_password
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
|
|
12
|
+
router = APIRouter(include_in_schema=False)
|
|
13
|
+
|
|
14
|
+
PHOENIX_ACCESS_TOKEN_COOKIE_NAME = "phoenix-access-token"
|
|
15
|
+
PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS = int(timedelta(days=31).total_seconds())
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@router.post("/login")
async def login(
    request: Request,
    email: Annotated[str, Form()],
    password: Annotated[str, Form()],
) -> Response:
    # Verifies the submitted form credentials and, on success, answers 204 with
    # the access-token cookie set. Any failure answers 401 without a body.
    user_by_email = select(models.User).where(models.User.email == email)
    async with request.app.state.db() as session:
        user = await session.scalar(user_by_email)
        # Unknown email, or a user with no stored password hash.
        if user is None or user.password_hash is None:
            return Response(status_code=HTTP_401_UNAUTHORIZED)
        password_hash = user.password_hash

    secret = request.app.state.get_secret()
    loop = asyncio.get_running_loop()
    # Password verification is run in the default executor, off the event loop.
    password_matches = await loop.run_in_executor(
        None,
        lambda: is_valid_password(password=password, salt=secret, password_hash=password_hash),
    )
    if not password_matches:
        return Response(status_code=HTTP_401_UNAUTHORIZED)

    response = Response(status_code=HTTP_204_NO_CONTENT)
    response.set_cookie(
        key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
        value="token",  # todo: compute access token
        secure=True,
        httponly=True,
        samesite="strict",
        max_age=PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS,
    )
    return response
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@router.post("/logout")
async def logout() -> Response:
    # Answers 204 and instructs the client to drop the access-token cookie.
    empty_response = Response(status_code=HTTP_204_NO_CONTENT)
    empty_response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
    return empty_response
|
|
@@ -196,7 +196,7 @@ class AnnotateSpansResponseBody(ResponseBody[List[InsertedSpanAnnotation]]):
|
|
|
196
196
|
async def annotate_spans(
|
|
197
197
|
request: Request,
|
|
198
198
|
request_body: AnnotateSpansRequestBody,
|
|
199
|
-
sync: bool = Query(default=
|
|
199
|
+
sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
|
|
200
200
|
) -> AnnotateSpansResponseBody:
|
|
201
201
|
if not request_body.data:
|
|
202
202
|
return AnnotateSpansResponseBody(data=[])
|
phoenix/server/app.py
CHANGED
|
@@ -4,6 +4,7 @@ import json
|
|
|
4
4
|
import logging
|
|
5
5
|
from functools import cached_property
|
|
6
6
|
from pathlib import Path
|
|
7
|
+
from types import MethodType
|
|
7
8
|
from typing import (
|
|
8
9
|
TYPE_CHECKING,
|
|
9
10
|
Any,
|
|
@@ -30,6 +31,7 @@ from sqlalchemy.ext.asyncio import (
|
|
|
30
31
|
AsyncSession,
|
|
31
32
|
async_sessionmaker,
|
|
32
33
|
)
|
|
34
|
+
from starlette.datastructures import State as StarletteState
|
|
33
35
|
from starlette.exceptions import HTTPException
|
|
34
36
|
from starlette.middleware import Middleware
|
|
35
37
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
@@ -80,6 +82,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
80
82
|
TokenCountDataLoader,
|
|
81
83
|
TraceRowIdsDataLoader,
|
|
82
84
|
)
|
|
85
|
+
from phoenix.server.api.routers.auth import router as auth_router
|
|
83
86
|
from phoenix.server.api.routers.v1 import REST_API_VERSION
|
|
84
87
|
from phoenix.server.api.routers.v1 import router as v1_router
|
|
85
88
|
from phoenix.server.api.schema import schema
|
|
@@ -100,6 +103,7 @@ if TYPE_CHECKING:
|
|
|
100
103
|
from opentelemetry.trace import TracerProvider
|
|
101
104
|
|
|
102
105
|
logger = logging.getLogger(__name__)
|
|
106
|
+
logger.addHandler(logging.NullHandler())
|
|
103
107
|
|
|
104
108
|
router = APIRouter(include_in_schema=False)
|
|
105
109
|
|
|
@@ -229,7 +233,8 @@ def _lifespan(
|
|
|
229
233
|
dml_event_handler: DmlEventHandler,
|
|
230
234
|
tracer_provider: Optional["TracerProvider"] = None,
|
|
231
235
|
enable_prometheus: bool = False,
|
|
232
|
-
|
|
236
|
+
startup_callbacks: Iterable[Callable[[], None]] = (),
|
|
237
|
+
shutdown_callbacks: Iterable[Callable[[], None]] = (),
|
|
233
238
|
read_only: bool = False,
|
|
234
239
|
) -> StatefulLifespan[FastAPI]:
|
|
235
240
|
@contextlib.asynccontextmanager
|
|
@@ -247,6 +252,8 @@ def _lifespan(
|
|
|
247
252
|
tracer_provider=tracer_provider,
|
|
248
253
|
enable_prometheus=enable_prometheus,
|
|
249
254
|
), dml_event_handler:
|
|
255
|
+
for callback in startup_callbacks:
|
|
256
|
+
callback()
|
|
250
257
|
yield {
|
|
251
258
|
"event_queue": dml_event_handler,
|
|
252
259
|
"enqueue": enqueue,
|
|
@@ -254,8 +261,8 @@ def _lifespan(
|
|
|
254
261
|
"queue_evaluation_for_bulk_insert": queue_evaluation,
|
|
255
262
|
"enqueue_operation": enqueue_operation,
|
|
256
263
|
}
|
|
257
|
-
for
|
|
258
|
-
|
|
264
|
+
for callback in shutdown_callbacks:
|
|
265
|
+
callback()
|
|
259
266
|
|
|
260
267
|
return lifespan
|
|
261
268
|
|
|
@@ -276,7 +283,26 @@ def create_graphql_router(
|
|
|
276
283
|
cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
|
|
277
284
|
event_queue: CanPutItem[DmlEvent],
|
|
278
285
|
read_only: bool = False,
|
|
286
|
+
secret: Optional[str] = None,
|
|
279
287
|
) -> GraphQLRouter: # type: ignore[type-arg]
|
|
288
|
+
"""Creates the GraphQL router.
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
schema (BaseSchema): The GraphQL schema.
|
|
292
|
+
db (DbSessionFactory): The database session factory pointing to a SQL database.
|
|
293
|
+
model (Model): The Model representing inferences (legacy)
|
|
294
|
+
export_path (Path): the file path to export data to for download (legacy)
|
|
295
|
+
last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
|
|
296
|
+
event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
|
|
297
|
+
corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
|
|
298
|
+
cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
|
|
299
|
+
read_only (bool, optional): Marks the app as read-only. Defaults to False.
|
|
300
|
+
secret (Optional[str], optional): The application secret for auth. Defaults to None.
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
GraphQLRouter: The router mounted at /graphql
|
|
304
|
+
"""
|
|
305
|
+
|
|
280
306
|
def get_context() -> Context:
|
|
281
307
|
return Context(
|
|
282
308
|
db=db,
|
|
@@ -336,6 +362,7 @@ def create_graphql_router(
|
|
|
336
362
|
),
|
|
337
363
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
338
364
|
read_only=read_only,
|
|
365
|
+
secret=secret,
|
|
339
366
|
)
|
|
340
367
|
|
|
341
368
|
return GraphQLRouter(
|
|
@@ -408,9 +435,12 @@ def create_app(
|
|
|
408
435
|
initial_spans: Optional[Iterable[Union[Span, Tuple[Span, str]]]] = None,
|
|
409
436
|
initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
|
|
410
437
|
serve_ui: bool = True,
|
|
411
|
-
|
|
438
|
+
startup_callbacks: Iterable[Callable[[], None]] = (),
|
|
439
|
+
shutdown_callbacks: Iterable[Callable[[], None]] = (),
|
|
440
|
+
secret: Optional[str] = None,
|
|
412
441
|
) -> FastAPI:
|
|
413
|
-
|
|
442
|
+
startup_callbacks_list: List[Callable[[], None]] = list(startup_callbacks)
|
|
443
|
+
shutdown_callbacks_list: List[Callable[[], None]] = list(shutdown_callbacks)
|
|
414
444
|
initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
|
|
415
445
|
()
|
|
416
446
|
if initial_spans is None
|
|
@@ -472,6 +502,7 @@ def create_app(
|
|
|
472
502
|
event_queue=dml_event_handler,
|
|
473
503
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
474
504
|
read_only=read_only,
|
|
505
|
+
secret=secret,
|
|
475
506
|
)
|
|
476
507
|
if enable_prometheus:
|
|
477
508
|
from phoenix.server.prometheus import PrometheusMiddleware
|
|
@@ -489,7 +520,8 @@ def create_app(
|
|
|
489
520
|
dml_event_handler=dml_event_handler,
|
|
490
521
|
tracer_provider=tracer_provider,
|
|
491
522
|
enable_prometheus=enable_prometheus,
|
|
492
|
-
|
|
523
|
+
shutdown_callbacks=shutdown_callbacks_list,
|
|
524
|
+
startup_callbacks=startup_callbacks_list,
|
|
493
525
|
),
|
|
494
526
|
middleware=[
|
|
495
527
|
Middleware(HeadersMiddleware),
|
|
@@ -507,6 +539,8 @@ def create_app(
|
|
|
507
539
|
app.include_router(router)
|
|
508
540
|
app.include_router(graphql_router)
|
|
509
541
|
app.add_middleware(GZipMiddleware)
|
|
542
|
+
if authentication_enabled:
|
|
543
|
+
app.include_router(auth_router)
|
|
510
544
|
if serve_ui:
|
|
511
545
|
app.mount(
|
|
512
546
|
"/",
|
|
@@ -525,12 +559,30 @@ def create_app(
|
|
|
525
559
|
),
|
|
526
560
|
name="static",
|
|
527
561
|
)
|
|
528
|
-
|
|
529
|
-
app.state.db = db
|
|
562
|
+
app = _update_app_state(app, db=db, secret=secret)
|
|
530
563
|
if tracer_provider:
|
|
531
564
|
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
|
532
565
|
|
|
533
566
|
FastAPIInstrumentor().instrument(tracer_provider=tracer_provider)
|
|
534
567
|
FastAPIInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
|
|
535
|
-
|
|
568
|
+
shutdown_callbacks_list.append(FastAPIInstrumentor().uninstrument)
|
|
569
|
+
return app
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional[str]) -> FastAPI:
    """
    Attaches the database session factory, the (possibly unset) application
    secret, and a bound `get_secret` accessor to the app's `state`, then
    returns the same app (at the time of this writing, FastAPI does not
    support setting this state during the creation of the app).
    """

    def get_secret(self: StarletteState) -> str:
        # Fail loudly instead of handing `None` to code that needs a secret.
        configured = self._secret
        if configured is None:
            raise ValueError("app secret is not set")
        assert isinstance(configured, str)
        return configured

    state = app.state
    state.db = db
    state._secret = secret
    # Bind as a method so callers can invoke `app.state.get_secret()`.
    state.get_secret = MethodType(get_secret, state)
    return app
|
phoenix/server/main.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
import atexit
|
|
2
|
+
import codecs
|
|
2
3
|
import logging
|
|
3
4
|
import os
|
|
5
|
+
import sys
|
|
4
6
|
from argparse import ArgumentParser
|
|
5
|
-
from
|
|
7
|
+
from importlib.metadata import version
|
|
8
|
+
from pathlib import Path
|
|
6
9
|
from threading import Thread
|
|
7
10
|
from time import sleep, time
|
|
8
11
|
from typing import List, Optional
|
|
12
|
+
from urllib.parse import urljoin
|
|
9
13
|
|
|
10
|
-
import pkg_resources
|
|
11
14
|
from uvicorn import Config, Server
|
|
12
15
|
|
|
13
16
|
import phoenix.trace.v1 as pb
|
|
@@ -53,6 +56,7 @@ from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
|
|
|
53
56
|
from phoenix.trace.schemas import Span
|
|
54
57
|
|
|
55
58
|
logger = logging.getLogger(__name__)
|
|
59
|
+
logger.addHandler(logging.NullHandler())
|
|
56
60
|
|
|
57
61
|
_WELCOME_MESSAGE = """
|
|
58
62
|
|
|
@@ -137,6 +141,7 @@ if __name__ == "__main__":
|
|
|
137
141
|
parser.add_argument("--debug", action="store_true")
|
|
138
142
|
# Whether the app is running in a development environment
|
|
139
143
|
parser.add_argument("--dev", action="store_true")
|
|
144
|
+
parser.add_argument("--no-ui", action="store_true")
|
|
140
145
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
141
146
|
serve_parser = subparsers.add_parser("serve")
|
|
142
147
|
datasets_parser = subparsers.add_parser("datasets")
|
|
@@ -218,7 +223,7 @@ if __name__ == "__main__":
|
|
|
218
223
|
reference_inferences,
|
|
219
224
|
)
|
|
220
225
|
|
|
221
|
-
authentication_enabled,
|
|
226
|
+
authentication_enabled, secret = get_auth_settings()
|
|
222
227
|
|
|
223
228
|
fixture_spans: List[Span] = []
|
|
224
229
|
fixture_evals: List[pb.Evaluation] = []
|
|
@@ -255,6 +260,18 @@ if __name__ == "__main__":
|
|
|
255
260
|
engine = create_engine_and_run_migrations(db_connection_str)
|
|
256
261
|
instrumentation_cleanups = instrument_engine_if_enabled(engine)
|
|
257
262
|
factory = DbSessionFactory(db=_db(engine), dialect=engine.dialect.name)
|
|
263
|
+
# Print information about the server
|
|
264
|
+
msg = _WELCOME_MESSAGE.format(
|
|
265
|
+
version=version("arize-phoenix"),
|
|
266
|
+
ui_path=urljoin(f"http://{host}:{port}", host_root_path),
|
|
267
|
+
grpc_path=f"http://{host}:{get_env_grpc_port()}",
|
|
268
|
+
http_path=urljoin(urljoin(f"http://{host}:{port}", host_root_path), "v1/traces"),
|
|
269
|
+
storage=get_printable_db_url(db_connection_str),
|
|
270
|
+
)
|
|
271
|
+
if authentication_enabled:
|
|
272
|
+
msg += _EXPERIMENTAL_WARNING.format(auth_enabled=True)
|
|
273
|
+
if sys.platform.startswith("win"):
|
|
274
|
+
msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip()
|
|
258
275
|
app = create_app(
|
|
259
276
|
db=factory,
|
|
260
277
|
export_path=export_path,
|
|
@@ -266,29 +283,17 @@ if __name__ == "__main__":
|
|
|
266
283
|
else create_model_from_inferences(corpus_inferences),
|
|
267
284
|
debug=args.debug,
|
|
268
285
|
dev=args.dev,
|
|
286
|
+
serve_ui=not args.no_ui,
|
|
269
287
|
read_only=read_only,
|
|
270
288
|
enable_prometheus=enable_prometheus,
|
|
271
289
|
initial_spans=fixture_spans,
|
|
272
290
|
initial_evaluations=fixture_evals,
|
|
273
|
-
|
|
291
|
+
startup_callbacks=[lambda: print(msg)],
|
|
292
|
+
shutdown_callbacks=instrumentation_cleanups,
|
|
293
|
+
secret=secret,
|
|
274
294
|
)
|
|
275
295
|
server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore
|
|
276
296
|
Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
|
|
277
297
|
|
|
278
|
-
# Print information about the server
|
|
279
|
-
phoenix_version = pkg_resources.get_distribution("arize-phoenix").version
|
|
280
|
-
print(
|
|
281
|
-
_WELCOME_MESSAGE.format(
|
|
282
|
-
version=phoenix_version,
|
|
283
|
-
ui_path=PosixPath(f"http://{host}:{port}", host_root_path),
|
|
284
|
-
grpc_path=f"http://{host}:{get_env_grpc_port()}",
|
|
285
|
-
http_path=PosixPath(f"http://{host}:{port}", host_root_path, "v1/traces"),
|
|
286
|
-
storage=get_printable_db_url(db_connection_str),
|
|
287
|
-
)
|
|
288
|
-
)
|
|
289
|
-
|
|
290
|
-
if authentication_enabled:
|
|
291
|
-
print(_EXPERIMENTAL_WARNING.format(auth_enabled=authentication_enabled))
|
|
292
|
-
|
|
293
298
|
# Start the server
|
|
294
299
|
server.run()
|