arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/document_annotation.py +1 -1
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +1 -1
- phoenix/db/insertion/trace_annotation.py +1 -1
- phoenix/db/insertion/types.py +29 -3
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +306 -46
- phoenix/server/api/context.py +15 -2
- phoenix/server/api/dataloaders/__init__.py +8 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +66 -35
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/chat_mutations.py +8 -3
- phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +53 -0
- phoenix/server/api/routers/auth.py +5 -5
- phoenix/server/api/routers/oauth2.py +5 -23
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +5 -2
- phoenix/server/api/types/Dataset.py +8 -0
- phoenix/server/api/types/DatasetExample.py +18 -0
- phoenix/server/api/types/DatasetLabel.py +23 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Prompt.py +18 -1
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +13 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
- phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
- phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
- phoenix/server/utils.py +74 -0
- phoenix/session/session.py +25 -5
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,6 @@ import asyncio
|
|
|
2
2
|
import secrets
|
|
3
3
|
from datetime import datetime, timedelta, timezone
|
|
4
4
|
from functools import partial
|
|
5
|
-
from pathlib import Path
|
|
6
5
|
from urllib.parse import urlencode, urlparse, urlunparse
|
|
7
6
|
|
|
8
7
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
|
@@ -38,7 +37,6 @@ from phoenix.config import (
|
|
|
38
37
|
get_base_url,
|
|
39
38
|
get_env_disable_basic_auth,
|
|
40
39
|
get_env_disable_rate_limit,
|
|
41
|
-
get_env_host_root_path,
|
|
42
40
|
)
|
|
43
41
|
from phoenix.db import models
|
|
44
42
|
from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
|
|
@@ -52,6 +50,7 @@ from phoenix.server.types import (
|
|
|
52
50
|
TokenStore,
|
|
53
51
|
UserId,
|
|
54
52
|
)
|
|
53
|
+
from phoenix.server.utils import prepend_root_path
|
|
55
54
|
|
|
56
55
|
rate_limiter = ServerRateLimiter(
|
|
57
56
|
per_second_rate_limit=0.2,
|
|
@@ -145,7 +144,8 @@ async def logout(
|
|
|
145
144
|
user_id = subject
|
|
146
145
|
if user_id:
|
|
147
146
|
await token_store.log_out(user_id)
|
|
148
|
-
|
|
147
|
+
redirect_path = "/logout" if get_env_disable_basic_auth() else "/login"
|
|
148
|
+
redirect_url = prepend_root_path(request.scope, redirect_path)
|
|
149
149
|
response = Response(status_code=HTTP_302_FOUND, headers={"Location": redirect_url})
|
|
150
150
|
response = delete_access_token_cookie(response)
|
|
151
151
|
response = delete_refresh_token_cookie(response)
|
|
@@ -242,9 +242,9 @@ async def initiate_password_reset(request: Request) -> Response:
|
|
|
242
242
|
)
|
|
243
243
|
token, _ = await token_store.create_password_reset_token(password_reset_token_claims)
|
|
244
244
|
url = urlparse(request.headers.get("referer") or get_base_url())
|
|
245
|
-
path =
|
|
245
|
+
path = prepend_root_path(request.scope, "/reset-password-with-token")
|
|
246
246
|
query_string = urlencode(dict(token=token))
|
|
247
|
-
components = (url.scheme, url.netloc, path
|
|
247
|
+
components = (url.scheme, url.netloc, path, "", query_string, "")
|
|
248
248
|
reset_url = urlunparse(components)
|
|
249
249
|
await sender.send_password_reset_email(email, reset_url)
|
|
250
250
|
return Response(status_code=HTTP_204_NO_CONTENT)
|
|
@@ -46,6 +46,7 @@ from phoenix.server.rate_limiters import (
|
|
|
46
46
|
fastapi_route_rate_limiter,
|
|
47
47
|
)
|
|
48
48
|
from phoenix.server.types import TokenStore
|
|
49
|
+
from phoenix.server.utils import get_root_path, prepend_root_path
|
|
49
50
|
|
|
50
51
|
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
|
|
51
52
|
|
|
@@ -191,7 +192,7 @@ async def create_tokens(
|
|
|
191
192
|
access_token_expiry=access_token_expiry,
|
|
192
193
|
refresh_token_expiry=refresh_token_expiry,
|
|
193
194
|
)
|
|
194
|
-
redirect_path =
|
|
195
|
+
redirect_path = prepend_root_path(request.scope, return_url or "/")
|
|
195
196
|
response = RedirectResponse(
|
|
196
197
|
url=redirect_path,
|
|
197
198
|
status_code=HTTP_302_FOUND,
|
|
@@ -564,8 +565,8 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
|
|
|
564
565
|
Creates a RedirectResponse to the login page to display an error message.
|
|
565
566
|
"""
|
|
566
567
|
# TODO: this needs some cleanup
|
|
567
|
-
login_path =
|
|
568
|
-
request
|
|
568
|
+
login_path = prepend_root_path(
|
|
569
|
+
request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
|
|
569
570
|
)
|
|
570
571
|
url = URL(login_path).include_query_params(error=error)
|
|
571
572
|
response = RedirectResponse(url=url)
|
|
@@ -574,34 +575,15 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
|
|
|
574
575
|
return response
|
|
575
576
|
|
|
576
577
|
|
|
577
|
-
def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
|
|
578
|
-
"""
|
|
579
|
-
If a root path is configured, prepends it to the input path.
|
|
580
|
-
"""
|
|
581
|
-
if not path.startswith("/"):
|
|
582
|
-
raise ValueError("path must start with a forward slash")
|
|
583
|
-
root_path = _get_root_path(request=request)
|
|
584
|
-
if root_path.endswith("/"):
|
|
585
|
-
root_path = root_path.rstrip("/")
|
|
586
|
-
return root_path + path
|
|
587
|
-
|
|
588
|
-
|
|
589
578
|
def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
|
|
590
579
|
"""
|
|
591
580
|
If a root path is configured, appends it to the input base url.
|
|
592
581
|
"""
|
|
593
|
-
if not (root_path :=
|
|
582
|
+
if not (root_path := get_root_path(request.scope)):
|
|
594
583
|
return base_url
|
|
595
584
|
return str(URLPath(root_path).make_absolute_url(base_url=base_url))
|
|
596
585
|
|
|
597
586
|
|
|
598
|
-
def _get_root_path(*, request: Request) -> str:
|
|
599
|
-
"""
|
|
600
|
-
Gets the root path from the request.
|
|
601
|
-
"""
|
|
602
|
-
return str(request.scope.get("root_path", ""))
|
|
603
|
-
|
|
604
|
-
|
|
605
587
|
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
|
|
606
588
|
"""
|
|
607
589
|
Gets the endpoint for create tokens route.
|
|
@@ -14,6 +14,7 @@ from .experiment_runs import router as experiment_runs_router
|
|
|
14
14
|
from .experiments import router as experiments_router
|
|
15
15
|
from .projects import router as projects_router
|
|
16
16
|
from .prompts import router as prompts_router
|
|
17
|
+
from .sessions import router as sessions_router
|
|
17
18
|
from .spans import router as spans_router
|
|
18
19
|
from .traces import router as traces_router
|
|
19
20
|
from .users import router as users_router
|
|
@@ -71,6 +72,7 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
|
|
|
71
72
|
router.include_router(evaluations_router)
|
|
72
73
|
router.include_router(prompts_router)
|
|
73
74
|
router.include_router(projects_router)
|
|
75
|
+
router.include_router(sessions_router)
|
|
74
76
|
router.include_router(documents_router)
|
|
75
77
|
router.include_router(users_router)
|
|
76
78
|
return router
|
|
@@ -14,6 +14,9 @@ from strawberry.relay import GlobalID
|
|
|
14
14
|
from phoenix.db import models
|
|
15
15
|
from phoenix.db.insertion.types import Precursors
|
|
16
16
|
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
17
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import (
|
|
18
|
+
ProjectSessionAnnotation as SessionAnnotationNodeType,
|
|
19
|
+
)
|
|
17
20
|
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation as SpanAnnotationNodeType
|
|
18
21
|
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation as TraceAnnotationNodeType
|
|
19
22
|
from phoenix.server.api.types.User import User as UserNodeType
|
|
@@ -24,6 +27,7 @@ logger = logging.getLogger(__name__)
|
|
|
24
27
|
|
|
25
28
|
SPAN_ANNOTATION_NODE_NAME = SpanAnnotationNodeType.__name__
|
|
26
29
|
TRACE_ANNOTATION_NODE_NAME = TraceAnnotationNodeType.__name__
|
|
30
|
+
SESSION_ANNOTATION_NODE_NAME = SessionAnnotationNodeType.__name__
|
|
27
31
|
MAX_TRACE_IDS = 1_000
|
|
28
32
|
USER_NODE_NAME = UserNodeType.__name__
|
|
29
33
|
MAX_SPAN_IDS = 1_000
|
|
@@ -161,6 +165,35 @@ class TraceAnnotationsResponseBody(PaginatedResponseBody[TraceAnnotation]):
|
|
|
161
165
|
pass
|
|
162
166
|
|
|
163
167
|
|
|
168
|
+
class SessionAnnotationData(AnnotationData):
|
|
169
|
+
session_id: str = Field(description="Session ID")
|
|
170
|
+
|
|
171
|
+
def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.SessionAnnotation:
|
|
172
|
+
return Precursors.SessionAnnotation(
|
|
173
|
+
datetime.now(timezone.utc),
|
|
174
|
+
self.session_id,
|
|
175
|
+
models.ProjectSessionAnnotation(
|
|
176
|
+
name=self.name,
|
|
177
|
+
annotator_kind=self.annotator_kind,
|
|
178
|
+
score=self.result.score if self.result else None,
|
|
179
|
+
label=self.result.label if self.result else None,
|
|
180
|
+
explanation=self.result.explanation if self.result else None,
|
|
181
|
+
metadata_=self.metadata or {},
|
|
182
|
+
identifier=self.identifier,
|
|
183
|
+
source="API",
|
|
184
|
+
user_id=user_id,
|
|
185
|
+
),
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class SessionAnnotation(SessionAnnotationData, Annotation):
|
|
190
|
+
pass
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class SessionAnnotationsResponseBody(PaginatedResponseBody[SessionAnnotation]):
|
|
194
|
+
pass
|
|
195
|
+
|
|
196
|
+
|
|
164
197
|
@router.get(
|
|
165
198
|
"/projects/{project_identifier}/span_annotations",
|
|
166
199
|
operation_id="listSpanAnnotationsBySpanIds",
|
|
@@ -304,3 +337,290 @@ async def list_span_annotations(
|
|
|
304
337
|
]
|
|
305
338
|
|
|
306
339
|
return SpanAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@router.get(
|
|
343
|
+
"/projects/{project_identifier}/trace_annotations",
|
|
344
|
+
operation_id="listTraceAnnotationsByTraceIds",
|
|
345
|
+
summary="Get trace annotations for a list of trace_ids.",
|
|
346
|
+
status_code=HTTP_200_OK,
|
|
347
|
+
responses=add_errors_to_responses(
|
|
348
|
+
[
|
|
349
|
+
{"status_code": HTTP_404_NOT_FOUND, "description": "Project or traces not found"},
|
|
350
|
+
{"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
|
|
351
|
+
]
|
|
352
|
+
),
|
|
353
|
+
)
|
|
354
|
+
async def list_trace_annotations(
|
|
355
|
+
request: Request,
|
|
356
|
+
project_identifier: str = Path(
|
|
357
|
+
description=(
|
|
358
|
+
"The project identifier: either project ID or project name. If using a project name as "
|
|
359
|
+
"the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
|
|
360
|
+
"characters."
|
|
361
|
+
)
|
|
362
|
+
),
|
|
363
|
+
trace_ids: list[str] = Query(
|
|
364
|
+
..., min_length=1, description="One or more trace id to fetch annotations for"
|
|
365
|
+
),
|
|
366
|
+
include_annotation_names: Optional[list[str]] = Query(
|
|
367
|
+
default=None,
|
|
368
|
+
description=(
|
|
369
|
+
"Optional list of annotation names to include. If provided, only annotations with "
|
|
370
|
+
"these names will be returned. 'note' annotations are excluded by default unless "
|
|
371
|
+
"explicitly included in this list."
|
|
372
|
+
),
|
|
373
|
+
),
|
|
374
|
+
exclude_annotation_names: Optional[list[str]] = Query(
|
|
375
|
+
default=None, description="Optional list of annotation names to exclude from results."
|
|
376
|
+
),
|
|
377
|
+
cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
|
|
378
|
+
limit: int = Query(
|
|
379
|
+
default=10,
|
|
380
|
+
gt=0,
|
|
381
|
+
le=10000,
|
|
382
|
+
description="The maximum number of annotations to return in a single request",
|
|
383
|
+
),
|
|
384
|
+
) -> TraceAnnotationsResponseBody:
|
|
385
|
+
trace_ids = list({*trace_ids})
|
|
386
|
+
if len(trace_ids) > MAX_TRACE_IDS:
|
|
387
|
+
raise HTTPException(
|
|
388
|
+
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
389
|
+
detail=f"Too many trace_ids supplied: {len(trace_ids)} (max {MAX_TRACE_IDS})",
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
async with request.app.state.db() as session:
|
|
393
|
+
project = await _get_project_by_identifier(session, project_identifier)
|
|
394
|
+
if not project:
|
|
395
|
+
raise HTTPException(
|
|
396
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
397
|
+
detail=f"Project with identifier {project_identifier} not found",
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
# Build the base query
|
|
401
|
+
where_conditions = [
|
|
402
|
+
models.Project.id == project.id,
|
|
403
|
+
models.Trace.trace_id.in_(trace_ids),
|
|
404
|
+
]
|
|
405
|
+
|
|
406
|
+
# Add annotation name filtering
|
|
407
|
+
if include_annotation_names:
|
|
408
|
+
where_conditions.append(models.TraceAnnotation.name.in_(include_annotation_names))
|
|
409
|
+
|
|
410
|
+
if exclude_annotation_names:
|
|
411
|
+
where_conditions.append(models.TraceAnnotation.name.not_in(exclude_annotation_names))
|
|
412
|
+
|
|
413
|
+
stmt = (
|
|
414
|
+
select(models.Trace.trace_id, models.TraceAnnotation)
|
|
415
|
+
.join(models.Project, models.Trace.project_rowid == models.Project.id)
|
|
416
|
+
.join(models.TraceAnnotation, models.TraceAnnotation.trace_rowid == models.Trace.id)
|
|
417
|
+
.where(*where_conditions)
|
|
418
|
+
.order_by(models.TraceAnnotation.id.desc())
|
|
419
|
+
.limit(limit + 1)
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
if cursor:
|
|
423
|
+
try:
|
|
424
|
+
cursor_id = int(GlobalID.from_id(cursor).node_id)
|
|
425
|
+
except ValueError:
|
|
426
|
+
raise HTTPException(
|
|
427
|
+
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
428
|
+
detail="Invalid cursor value",
|
|
429
|
+
)
|
|
430
|
+
stmt = stmt.where(models.TraceAnnotation.id <= cursor_id)
|
|
431
|
+
|
|
432
|
+
rows: list[tuple[str, models.TraceAnnotation]] = [
|
|
433
|
+
r async for r in (await session.stream(stmt))
|
|
434
|
+
]
|
|
435
|
+
|
|
436
|
+
next_cursor: Optional[str] = None
|
|
437
|
+
if len(rows) == limit + 1:
|
|
438
|
+
*rows, extra = rows
|
|
439
|
+
next_cursor = str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(extra[1].id)))
|
|
440
|
+
|
|
441
|
+
if not rows:
|
|
442
|
+
traces_exist = await session.scalar(
|
|
443
|
+
select(
|
|
444
|
+
exists().where(
|
|
445
|
+
models.Trace.trace_id.in_(trace_ids),
|
|
446
|
+
models.Trace.project_rowid == project.id,
|
|
447
|
+
)
|
|
448
|
+
)
|
|
449
|
+
)
|
|
450
|
+
if not traces_exist:
|
|
451
|
+
raise HTTPException(
|
|
452
|
+
detail="None of the supplied trace_ids exist in this project",
|
|
453
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
454
|
+
)
|
|
455
|
+
|
|
456
|
+
return TraceAnnotationsResponseBody(data=[], next_cursor=None)
|
|
457
|
+
|
|
458
|
+
data = [
|
|
459
|
+
TraceAnnotation(
|
|
460
|
+
id=str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(anno.id))),
|
|
461
|
+
trace_id=trace_id,
|
|
462
|
+
name=anno.name,
|
|
463
|
+
result=AnnotationResult(
|
|
464
|
+
label=anno.label,
|
|
465
|
+
score=anno.score,
|
|
466
|
+
explanation=anno.explanation,
|
|
467
|
+
),
|
|
468
|
+
metadata=anno.metadata_,
|
|
469
|
+
annotator_kind=anno.annotator_kind,
|
|
470
|
+
created_at=anno.created_at,
|
|
471
|
+
updated_at=anno.updated_at,
|
|
472
|
+
identifier=anno.identifier,
|
|
473
|
+
source=anno.source,
|
|
474
|
+
user_id=str(GlobalID("User", str(anno.user_id))) if anno.user_id else None,
|
|
475
|
+
)
|
|
476
|
+
for trace_id, anno in rows
|
|
477
|
+
]
|
|
478
|
+
|
|
479
|
+
return TraceAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@router.get(
|
|
483
|
+
"/projects/{project_identifier}/session_annotations",
|
|
484
|
+
operation_id="listSessionAnnotationsBySessionIds",
|
|
485
|
+
summary="Get session annotations for a list of session_ids.",
|
|
486
|
+
status_code=HTTP_200_OK,
|
|
487
|
+
responses=add_errors_to_responses(
|
|
488
|
+
[
|
|
489
|
+
{"status_code": HTTP_404_NOT_FOUND, "description": "Project or sessions not found"},
|
|
490
|
+
{"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
|
|
491
|
+
]
|
|
492
|
+
),
|
|
493
|
+
)
|
|
494
|
+
async def list_session_annotations(
|
|
495
|
+
request: Request,
|
|
496
|
+
project_identifier: str = Path(
|
|
497
|
+
description=(
|
|
498
|
+
"The project identifier: either project ID or project name. If using a project name as "
|
|
499
|
+
"the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
|
|
500
|
+
"characters."
|
|
501
|
+
)
|
|
502
|
+
),
|
|
503
|
+
session_ids: list[str] = Query(
|
|
504
|
+
..., min_length=1, description="One or more session id to fetch annotations for"
|
|
505
|
+
),
|
|
506
|
+
include_annotation_names: Optional[list[str]] = Query(
|
|
507
|
+
default=None,
|
|
508
|
+
description=(
|
|
509
|
+
"Optional list of annotation names to include. If provided, only annotations with "
|
|
510
|
+
"these names will be returned. 'note' annotations are excluded by default unless "
|
|
511
|
+
"explicitly included in this list."
|
|
512
|
+
),
|
|
513
|
+
),
|
|
514
|
+
exclude_annotation_names: Optional[list[str]] = Query(
|
|
515
|
+
default=None, description="Optional list of annotation names to exclude from results."
|
|
516
|
+
),
|
|
517
|
+
cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
|
|
518
|
+
limit: int = Query(
|
|
519
|
+
default=10,
|
|
520
|
+
gt=0,
|
|
521
|
+
le=10000,
|
|
522
|
+
description="The maximum number of annotations to return in a single request",
|
|
523
|
+
),
|
|
524
|
+
) -> SessionAnnotationsResponseBody:
|
|
525
|
+
session_ids = list({*session_ids})
|
|
526
|
+
if len(session_ids) > MAX_SESSION_IDS:
|
|
527
|
+
raise HTTPException(
|
|
528
|
+
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
529
|
+
detail=f"Too many session_ids supplied: {len(session_ids)} (max {MAX_SESSION_IDS})",
|
|
530
|
+
)
|
|
531
|
+
|
|
532
|
+
async with request.app.state.db() as session:
|
|
533
|
+
project = await _get_project_by_identifier(session, project_identifier)
|
|
534
|
+
if not project:
|
|
535
|
+
raise HTTPException(
|
|
536
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
537
|
+
detail=f"Project with identifier {project_identifier} not found",
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
# Build the base query
|
|
541
|
+
where_conditions = [
|
|
542
|
+
models.Project.id == project.id,
|
|
543
|
+
models.ProjectSession.session_id.in_(session_ids),
|
|
544
|
+
]
|
|
545
|
+
|
|
546
|
+
# Add annotation name filtering
|
|
547
|
+
if include_annotation_names:
|
|
548
|
+
where_conditions.append(
|
|
549
|
+
models.ProjectSessionAnnotation.name.in_(include_annotation_names)
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
if exclude_annotation_names:
|
|
553
|
+
where_conditions.append(
|
|
554
|
+
models.ProjectSessionAnnotation.name.not_in(exclude_annotation_names)
|
|
555
|
+
)
|
|
556
|
+
|
|
557
|
+
stmt = (
|
|
558
|
+
select(models.ProjectSession.session_id, models.ProjectSessionAnnotation)
|
|
559
|
+
.join(models.Project, models.ProjectSession.project_id == models.Project.id)
|
|
560
|
+
.join(
|
|
561
|
+
models.ProjectSessionAnnotation,
|
|
562
|
+
models.ProjectSessionAnnotation.project_session_id == models.ProjectSession.id,
|
|
563
|
+
)
|
|
564
|
+
.where(*where_conditions)
|
|
565
|
+
.order_by(models.ProjectSessionAnnotation.id.desc())
|
|
566
|
+
.limit(limit + 1)
|
|
567
|
+
)
|
|
568
|
+
|
|
569
|
+
if cursor:
|
|
570
|
+
try:
|
|
571
|
+
cursor_id = int(GlobalID.from_id(cursor).node_id)
|
|
572
|
+
except ValueError:
|
|
573
|
+
raise HTTPException(
|
|
574
|
+
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
575
|
+
detail="Invalid cursor value",
|
|
576
|
+
)
|
|
577
|
+
stmt = stmt.where(models.ProjectSessionAnnotation.id <= cursor_id)
|
|
578
|
+
|
|
579
|
+
rows: list[tuple[str, models.ProjectSessionAnnotation]] = [
|
|
580
|
+
r async for r in (await session.stream(stmt))
|
|
581
|
+
]
|
|
582
|
+
|
|
583
|
+
next_cursor: Optional[str] = None
|
|
584
|
+
if len(rows) == limit + 1:
|
|
585
|
+
*rows, extra = rows
|
|
586
|
+
next_cursor = str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(extra[1].id)))
|
|
587
|
+
|
|
588
|
+
if not rows:
|
|
589
|
+
sessions_exist = await session.scalar(
|
|
590
|
+
select(
|
|
591
|
+
exists().where(
|
|
592
|
+
models.ProjectSession.session_id.in_(session_ids),
|
|
593
|
+
models.ProjectSession.project_id == project.id,
|
|
594
|
+
)
|
|
595
|
+
)
|
|
596
|
+
)
|
|
597
|
+
if not sessions_exist:
|
|
598
|
+
raise HTTPException(
|
|
599
|
+
detail="None of the supplied session_ids exist in this project",
|
|
600
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
601
|
+
)
|
|
602
|
+
|
|
603
|
+
return SessionAnnotationsResponseBody(data=[], next_cursor=None)
|
|
604
|
+
|
|
605
|
+
data = [
|
|
606
|
+
SessionAnnotation(
|
|
607
|
+
id=str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(anno.id))),
|
|
608
|
+
session_id=session_id,
|
|
609
|
+
name=anno.name,
|
|
610
|
+
result=AnnotationResult(
|
|
611
|
+
label=anno.label,
|
|
612
|
+
score=anno.score,
|
|
613
|
+
explanation=anno.explanation,
|
|
614
|
+
),
|
|
615
|
+
metadata=anno.metadata_,
|
|
616
|
+
annotator_kind=anno.annotator_kind,
|
|
617
|
+
created_at=anno.created_at,
|
|
618
|
+
updated_at=anno.updated_at,
|
|
619
|
+
identifier=anno.identifier,
|
|
620
|
+
source=anno.source,
|
|
621
|
+
user_id=str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None,
|
|
622
|
+
)
|
|
623
|
+
for session_id, anno in rows
|
|
624
|
+
]
|
|
625
|
+
|
|
626
|
+
return SessionAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
@@ -48,6 +48,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVer
|
|
|
48
48
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
49
49
|
from phoenix.server.api.utils import delete_projects, delete_traces
|
|
50
50
|
from phoenix.server.authorization import is_not_locked
|
|
51
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
51
52
|
from phoenix.server.dml_event import DatasetInsertEvent
|
|
52
53
|
|
|
53
54
|
from .models import V1RoutesBaseModel
|
|
@@ -478,6 +479,9 @@ async def upload_dataset(
|
|
|
478
479
|
detail="Invalid request Content-Type",
|
|
479
480
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
480
481
|
)
|
|
482
|
+
user_id: Optional[int] = None
|
|
483
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
484
|
+
user_id = int(request.user.identity)
|
|
481
485
|
operation = cast(
|
|
482
486
|
Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
|
|
483
487
|
partial(
|
|
@@ -486,6 +490,7 @@ async def upload_dataset(
|
|
|
486
490
|
action=action,
|
|
487
491
|
name=name,
|
|
488
492
|
description=description,
|
|
493
|
+
user_id=user_id,
|
|
489
494
|
),
|
|
490
495
|
)
|
|
491
496
|
if sync:
|
|
@@ -15,10 +15,14 @@ from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESS
|
|
|
15
15
|
from strawberry.relay import GlobalID
|
|
16
16
|
|
|
17
17
|
from phoenix.db import models
|
|
18
|
-
from phoenix.db.helpers import
|
|
18
|
+
from phoenix.db.helpers import (
|
|
19
|
+
SupportedSQLDialect,
|
|
20
|
+
insert_experiment_with_examples_snapshot,
|
|
21
|
+
)
|
|
19
22
|
from phoenix.db.insertion.helpers import insert_on_conflict
|
|
20
23
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
21
24
|
from phoenix.server.authorization import is_not_locked
|
|
25
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
22
26
|
from phoenix.server.dml_event import ExperimentInsertEvent
|
|
23
27
|
from phoenix.server.experiments.utils import generate_experiment_project_name
|
|
24
28
|
|
|
@@ -157,6 +161,9 @@ async def create_experiment(
|
|
|
157
161
|
detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
|
|
158
162
|
status_code=HTTP_404_NOT_FOUND,
|
|
159
163
|
)
|
|
164
|
+
user_id: Optional[int] = None
|
|
165
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
166
|
+
user_id = int(request.user.identity)
|
|
160
167
|
|
|
161
168
|
# generate a semi-unique name for the experiment
|
|
162
169
|
experiment_name = request_body.name or _generate_experiment_name(dataset_name)
|
|
@@ -172,9 +179,9 @@ async def create_experiment(
|
|
|
172
179
|
repetitions=request_body.repetitions,
|
|
173
180
|
metadata_=request_body.metadata or {},
|
|
174
181
|
project_name=project_name,
|
|
182
|
+
user_id=user_id,
|
|
175
183
|
)
|
|
176
|
-
session
|
|
177
|
-
await session.flush()
|
|
184
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
178
185
|
|
|
179
186
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
180
187
|
project_rowid = await session.scalar(
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
from sqlalchemy import select
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
from starlette.status import HTTP_404_NOT_FOUND
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
14
|
+
from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
|
|
15
|
+
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
16
|
+
from phoenix.server.authorization import is_not_locked
|
|
17
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
18
|
+
|
|
19
|
+
from .annotations import SessionAnnotationData
|
|
20
|
+
from .utils import RequestBody, ResponseBody, add_errors_to_responses
|
|
21
|
+
|
|
22
|
+
router = APIRouter(tags=["sessions"])
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InsertedSessionAnnotation(V1RoutesBaseModel):
|
|
26
|
+
id: str = Field(description="The ID of the inserted session annotation")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
|
|
30
|
+
pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
|
|
34
|
+
pass
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@router.post(
|
|
38
|
+
"/session_annotations",
|
|
39
|
+
dependencies=[Depends(is_not_locked)],
|
|
40
|
+
operation_id="annotateSessions",
|
|
41
|
+
summary="Create session annotations",
|
|
42
|
+
responses=add_errors_to_responses(
|
|
43
|
+
[{"status_code": HTTP_404_NOT_FOUND, "description": "Session not found"}]
|
|
44
|
+
),
|
|
45
|
+
response_description="Session annotations inserted successfully",
|
|
46
|
+
include_in_schema=True,
|
|
47
|
+
)
|
|
48
|
+
async def annotate_sessions(
|
|
49
|
+
request: Request,
|
|
50
|
+
request_body: AnnotateSessionsRequestBody,
|
|
51
|
+
sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
|
|
52
|
+
) -> AnnotateSessionsResponseBody:
|
|
53
|
+
if not request_body.data:
|
|
54
|
+
return AnnotateSessionsResponseBody(data=[])
|
|
55
|
+
|
|
56
|
+
user_id: Optional[int] = None
|
|
57
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
58
|
+
user_id = int(request.user.identity)
|
|
59
|
+
|
|
60
|
+
session_annotations = request_body.data
|
|
61
|
+
filtered_session_annotations = list(filter(lambda d: d.name != "note", session_annotations))
|
|
62
|
+
if len(filtered_session_annotations) != len(session_annotations):
|
|
63
|
+
warnings.warn(
|
|
64
|
+
(
|
|
65
|
+
"Session annotations with the name 'note' are not supported in this endpoint. "
|
|
66
|
+
"They will be ignored."
|
|
67
|
+
),
|
|
68
|
+
UserWarning,
|
|
69
|
+
)
|
|
70
|
+
precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
|
|
71
|
+
if not sync:
|
|
72
|
+
await request.state.enqueue_annotations(*precursors)
|
|
73
|
+
return AnnotateSessionsResponseBody(data=[])
|
|
74
|
+
|
|
75
|
+
session_ids = {p.session_id for p in precursors}
|
|
76
|
+
async with request.app.state.db() as session:
|
|
77
|
+
existing_sessions = {
|
|
78
|
+
session_id: rowid
|
|
79
|
+
async for session_id, rowid in await session.stream(
|
|
80
|
+
select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
|
|
81
|
+
models.ProjectSession.session_id.in_(session_ids)
|
|
82
|
+
)
|
|
83
|
+
)
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
missing_session_ids = session_ids - set(existing_sessions.keys())
|
|
87
|
+
# We prefer to fail the entire operation if there are missing sessions in sync mode
|
|
88
|
+
if missing_session_ids:
|
|
89
|
+
raise HTTPException(
|
|
90
|
+
detail=f"Sessions with IDs {', '.join(missing_session_ids)} do not exist.",
|
|
91
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
async with request.app.state.db() as session:
|
|
95
|
+
inserted_ids = []
|
|
96
|
+
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
97
|
+
for p in precursors:
|
|
98
|
+
values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
|
|
99
|
+
session_annotation_id = await session.scalar(
|
|
100
|
+
insert_on_conflict(
|
|
101
|
+
values,
|
|
102
|
+
dialect=dialect,
|
|
103
|
+
table=models.ProjectSessionAnnotation,
|
|
104
|
+
unique_by=("name", "project_session_id", "identifier"),
|
|
105
|
+
).returning(models.ProjectSessionAnnotation.id)
|
|
106
|
+
)
|
|
107
|
+
inserted_ids.append(session_annotation_id)
|
|
108
|
+
|
|
109
|
+
return AnnotateSessionsResponseBody(
|
|
110
|
+
data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
|
|
111
|
+
)
|
|
@@ -144,12 +144,11 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
|
|
|
144
144
|
responses=add_errors_to_responses(
|
|
145
145
|
[{"status_code": HTTP_404_NOT_FOUND, "description": "Trace not found"}]
|
|
146
146
|
),
|
|
147
|
-
include_in_schema=False,
|
|
148
147
|
)
|
|
149
148
|
async def annotate_traces(
|
|
150
149
|
request: Request,
|
|
151
150
|
request_body: AnnotateTracesRequestBody,
|
|
152
|
-
sync: bool = Query(default=
|
|
151
|
+
sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
|
|
153
152
|
) -> AnnotateTracesResponseBody:
|
|
154
153
|
if not request_body.data:
|
|
155
154
|
return AnnotateTracesResponseBody(data=[])
|
|
@@ -217,6 +217,13 @@ async def create_user(
|
|
|
217
217
|
detail="Cannot create users with SYSTEM role",
|
|
218
218
|
)
|
|
219
219
|
|
|
220
|
+
# TODO: Implement VIEWER role
|
|
221
|
+
if role == "VIEWER":
|
|
222
|
+
raise HTTPException(
|
|
223
|
+
status_code=HTTP_400_BAD_REQUEST,
|
|
224
|
+
detail="VIEWER role not yet implemented",
|
|
225
|
+
)
|
|
226
|
+
|
|
220
227
|
user: models.User
|
|
221
228
|
if isinstance(user_data, LocalUserData):
|
|
222
229
|
password = (user_data.password or secrets.token_hex()).strip()
|