arize-phoenix 11.38.0__py3-none-any.whl → 12.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +71 -50
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/types.py +30 -0
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +285 -46
- phoenix/server/api/context.py +13 -2
- phoenix/server/api/dataloaders/__init__.py +6 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +65 -35
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +8 -3
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +32 -0
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +5 -2
- phoenix/server/api/types/DatasetExample.py +11 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +6 -2
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-BQPHTBfv.js → components-Dl9SUw1U.js} +371 -327
- phoenix/server/static/assets/{index-BL5BMgJU.js → index-CqQS0dTo.js} +2 -2
- phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DKSjVA_E.js} +762 -514
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -14,6 +14,9 @@ from strawberry.relay import GlobalID
|
|
|
14
14
|
from phoenix.db import models
|
|
15
15
|
from phoenix.db.insertion.types import Precursors
|
|
16
16
|
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
17
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import (
|
|
18
|
+
ProjectSessionAnnotation as SessionAnnotationNodeType,
|
|
19
|
+
)
|
|
17
20
|
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation as SpanAnnotationNodeType
|
|
18
21
|
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation as TraceAnnotationNodeType
|
|
19
22
|
from phoenix.server.api.types.User import User as UserNodeType
|
|
@@ -24,6 +27,7 @@ logger = logging.getLogger(__name__)
|
|
|
24
27
|
|
|
25
28
|
SPAN_ANNOTATION_NODE_NAME = SpanAnnotationNodeType.__name__
|
|
26
29
|
TRACE_ANNOTATION_NODE_NAME = TraceAnnotationNodeType.__name__
|
|
30
|
+
SESSION_ANNOTATION_NODE_NAME = SessionAnnotationNodeType.__name__
|
|
27
31
|
MAX_TRACE_IDS = 1_000
|
|
28
32
|
USER_NODE_NAME = UserNodeType.__name__
|
|
29
33
|
MAX_SPAN_IDS = 1_000
|
|
@@ -161,6 +165,35 @@ class TraceAnnotationsResponseBody(PaginatedResponseBody[TraceAnnotation]):
|
|
|
161
165
|
pass
|
|
162
166
|
|
|
163
167
|
|
|
168
|
+
class SessionAnnotationData(AnnotationData):
    # Payload describing one session-level annotation: everything inherited from
    # AnnotationData plus the ID of the session being annotated.
    # NOTE(review): documented with comments rather than a class docstring in case
    # class docstrings feed the generated API schema -- confirm before converting.
    session_id: str = Field(description="Session ID")

    def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.SessionAnnotation:
        """Convert this payload into a ``Precursors.SessionAnnotation``.

        The precursor is built from the current UTC time, the session ID, and an
        unsaved ``models.ProjectSessionAnnotation`` whose ``source`` is ``"API"``.
        Score, label and explanation are ``None`` when no result was supplied, and
        missing metadata is stored as an empty dict.

        Args:
            user_id: Optional ID recorded on the annotation row as ``user_id``.
        """
        timestamp = datetime.now(timezone.utc)

        outcome = self.result
        if outcome:
            score, label, explanation = outcome.score, outcome.label, outcome.explanation
        else:
            score = label = explanation = None

        annotation_row = models.ProjectSessionAnnotation(
            name=self.name,
            annotator_kind=self.annotator_kind,
            score=score,
            label=label,
            explanation=explanation,
            metadata_=self.metadata or {},
            identifier=self.identifier,
            source="API",
            user_id=user_id,
        )
        return Precursors.SessionAnnotation(timestamp, self.session_id, annotation_row)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class SessionAnnotation(SessionAnnotationData, Annotation):
    # A session annotation as returned by the API: combines the fields of
    # SessionAnnotationData and Annotation without adding any of its own.
    pass
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class SessionAnnotationsResponseBody(PaginatedResponseBody[SessionAnnotation]):
    # Paginated response whose items are SessionAnnotation objects; adds no
    # fields beyond those of PaginatedResponseBody.
    pass
|
|
195
|
+
|
|
196
|
+
|
|
164
197
|
@router.get(
|
|
165
198
|
"/projects/{project_identifier}/span_annotations",
|
|
166
199
|
operation_id="listSpanAnnotationsBySpanIds",
|
|
@@ -304,3 +337,290 @@ async def list_span_annotations(
|
|
|
304
337
|
]
|
|
305
338
|
|
|
306
339
|
return SpanAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@router.get(
    "/projects/{project_identifier}/trace_annotations",
    operation_id="listTraceAnnotationsByTraceIds",
    summary="Get trace annotations for a list of trace_ids.",
    status_code=HTTP_200_OK,
    responses=add_errors_to_responses(
        [
            {"status_code": HTTP_404_NOT_FOUND, "description": "Project or traces not found"},
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
        ]
    ),
)
async def list_trace_annotations(
    request: Request,
    project_identifier: str = Path(
        description=(
            "The project identifier: either project ID or project name. If using a project name as "
            "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
            "characters."
        )
    ),
    trace_ids: list[str] = Query(
        ..., min_length=1, description="One or more trace id to fetch annotations for"
    ),
    include_annotation_names: Optional[list[str]] = Query(
        default=None,
        description=(
            "Optional list of annotation names to include. If provided, only annotations with "
            "these names will be returned. 'note' annotations are excluded by default unless "
            "explicitly included in this list."
        ),
    ),
    exclude_annotation_names: Optional[list[str]] = Query(
        default=None, description="Optional list of annotation names to exclude from results."
    ),
    cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
    limit: int = Query(
        default=10,
        gt=0,
        le=10000,
        description="The maximum number of annotations to return in a single request",
    ),
) -> TraceAnnotationsResponseBody:
    # Return one page of trace annotations for the given trace IDs in a project,
    # ordered by annotation row ID descending.
    #
    # Raises (as HTTPException):
    #   422 - more than MAX_TRACE_IDS distinct trace IDs, or an unparsable cursor.
    #   404 - the project is not found, or no annotations matched and none of the
    #         supplied trace IDs exist in the project.
    #
    # Documented with comments instead of a docstring because FastAPI publishes a
    # path-operation docstring as the operation description in the OpenAPI schema.
    #
    # NOTE(review): the include_annotation_names description states that 'note'
    # annotations are excluded by default, but no such filter is applied below --
    # confirm whether callers are expected to pass exclude_annotation_names.

    # De-duplicate first so repeated IDs do not count against the cap.
    trace_ids = list({*trace_ids})
    if len(trace_ids) > MAX_TRACE_IDS:
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Too many trace_ids supplied: {len(trace_ids)} (max {MAX_TRACE_IDS})",
        )

    async with request.app.state.db() as session:
        project = await _get_project_by_identifier(session, project_identifier)
        if not project:
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Project with identifier {project_identifier} not found",
            )

        # Build the base query
        where_conditions = [
            models.Project.id == project.id,
            models.Trace.trace_id.in_(trace_ids),
        ]

        # Add annotation name filtering
        if include_annotation_names:
            where_conditions.append(models.TraceAnnotation.name.in_(include_annotation_names))

        if exclude_annotation_names:
            where_conditions.append(models.TraceAnnotation.name.not_in(exclude_annotation_names))

        # Fetch one row beyond the page size; its presence signals another page.
        stmt = (
            select(models.Trace.trace_id, models.TraceAnnotation)
            .join(models.Project, models.Trace.project_rowid == models.Project.id)
            .join(models.TraceAnnotation, models.TraceAnnotation.trace_rowid == models.Trace.id)
            .where(*where_conditions)
            .order_by(models.TraceAnnotation.id.desc())
            .limit(limit + 1)
        )

        if cursor:
            try:
                cursor_id = int(GlobalID.from_id(cursor).node_id)
            except ValueError as err:
                # Chain the cause so the parsing failure stays visible in tracebacks.
                raise HTTPException(
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Invalid cursor value",
                ) from err
            # The cursor is the ID of the first row of the requested page (inclusive).
            stmt = stmt.where(models.TraceAnnotation.id <= cursor_id)

        rows: list[tuple[str, models.TraceAnnotation]] = [
            r async for r in (await session.stream(stmt))
        ]

        next_cursor: Optional[str] = None
        if len(rows) == limit + 1:
            # Drop the look-ahead row and expose its ID as the next cursor.
            *rows, extra = rows
            next_cursor = str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(extra[1].id)))

        if not rows:
            # Distinguish "traces exist but nothing matched" from "unknown traces".
            traces_exist = await session.scalar(
                select(
                    exists().where(
                        models.Trace.trace_id.in_(trace_ids),
                        models.Trace.project_rowid == project.id,
                    )
                )
            )
            if not traces_exist:
                raise HTTPException(
                    detail="None of the supplied trace_ids exist in this project",
                    status_code=HTTP_404_NOT_FOUND,
                )

            return TraceAnnotationsResponseBody(data=[], next_cursor=None)

        data = [
            TraceAnnotation(
                id=str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(anno.id))),
                trace_id=trace_id,
                name=anno.name,
                result=AnnotationResult(
                    label=anno.label,
                    score=anno.score,
                    explanation=anno.explanation,
                ),
                metadata=anno.metadata_,
                annotator_kind=anno.annotator_kind,
                created_at=anno.created_at,
                updated_at=anno.updated_at,
                identifier=anno.identifier,
                source=anno.source,
                # Use the shared node-name constant rather than a hard-coded "User".
                user_id=str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None,
            )
            for trace_id, anno in rows
        ]

        return TraceAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@router.get(
    "/projects/{project_identifier}/session_annotations",
    operation_id="listSessionAnnotationsBySessionIds",
    summary="Get session annotations for a list of session_ids.",
    status_code=HTTP_200_OK,
    responses=add_errors_to_responses(
        [
            {"status_code": HTTP_404_NOT_FOUND, "description": "Project or sessions not found"},
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
        ]
    ),
)
async def list_session_annotations(
    request: Request,
    project_identifier: str = Path(
        description=(
            "The project identifier: either project ID or project name. If using a project name as "
            "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
            "characters."
        )
    ),
    session_ids: list[str] = Query(
        ..., min_length=1, description="One or more session id to fetch annotations for"
    ),
    include_annotation_names: Optional[list[str]] = Query(
        default=None,
        description=(
            "Optional list of annotation names to include. If provided, only annotations with "
            "these names will be returned. 'note' annotations are excluded by default unless "
            "explicitly included in this list."
        ),
    ),
    exclude_annotation_names: Optional[list[str]] = Query(
        default=None, description="Optional list of annotation names to exclude from results."
    ),
    cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
    limit: int = Query(
        default=10,
        gt=0,
        le=10000,
        description="The maximum number of annotations to return in a single request",
    ),
) -> SessionAnnotationsResponseBody:
    # Return one page of session annotations for the given session IDs in a
    # project, ordered by annotation row ID descending.
    #
    # Raises (as HTTPException):
    #   422 - more than MAX_SESSION_IDS distinct session IDs, or an unparsable cursor.
    #   404 - the project is not found, or no annotations matched and none of the
    #         supplied session IDs exist in the project.
    #
    # Documented with comments instead of a docstring because FastAPI publishes a
    # path-operation docstring as the operation description in the OpenAPI schema.
    #
    # NOTE(review): the include_annotation_names description states that 'note'
    # annotations are excluded by default, but no such filter is applied below --
    # confirm whether callers are expected to pass exclude_annotation_names.
    # NOTE(review): MAX_SESSION_IDS is assumed to be a module-level constant; its
    # definition is not visible here -- confirm it exists.

    # De-duplicate first so repeated IDs do not count against the cap.
    session_ids = list({*session_ids})
    if len(session_ids) > MAX_SESSION_IDS:
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Too many session_ids supplied: {len(session_ids)} (max {MAX_SESSION_IDS})",
        )

    async with request.app.state.db() as session:
        project = await _get_project_by_identifier(session, project_identifier)
        if not project:
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Project with identifier {project_identifier} not found",
            )

        # Build the base query
        where_conditions = [
            models.Project.id == project.id,
            models.ProjectSession.session_id.in_(session_ids),
        ]

        # Add annotation name filtering
        if include_annotation_names:
            where_conditions.append(
                models.ProjectSessionAnnotation.name.in_(include_annotation_names)
            )

        if exclude_annotation_names:
            where_conditions.append(
                models.ProjectSessionAnnotation.name.not_in(exclude_annotation_names)
            )

        # Fetch one row beyond the page size; its presence signals another page.
        stmt = (
            select(models.ProjectSession.session_id, models.ProjectSessionAnnotation)
            .join(models.Project, models.ProjectSession.project_id == models.Project.id)
            .join(
                models.ProjectSessionAnnotation,
                models.ProjectSessionAnnotation.project_session_id == models.ProjectSession.id,
            )
            .where(*where_conditions)
            .order_by(models.ProjectSessionAnnotation.id.desc())
            .limit(limit + 1)
        )

        if cursor:
            try:
                cursor_id = int(GlobalID.from_id(cursor).node_id)
            except ValueError as err:
                # Chain the cause so the parsing failure stays visible in tracebacks.
                raise HTTPException(
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Invalid cursor value",
                ) from err
            # The cursor is the ID of the first row of the requested page (inclusive).
            stmt = stmt.where(models.ProjectSessionAnnotation.id <= cursor_id)

        rows: list[tuple[str, models.ProjectSessionAnnotation]] = [
            r async for r in (await session.stream(stmt))
        ]

        next_cursor: Optional[str] = None
        if len(rows) == limit + 1:
            # Drop the look-ahead row and expose its ID as the next cursor.
            *rows, extra = rows
            next_cursor = str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(extra[1].id)))

        if not rows:
            # Distinguish "sessions exist but nothing matched" from "unknown sessions".
            sessions_exist = await session.scalar(
                select(
                    exists().where(
                        models.ProjectSession.session_id.in_(session_ids),
                        models.ProjectSession.project_id == project.id,
                    )
                )
            )
            if not sessions_exist:
                raise HTTPException(
                    detail="None of the supplied session_ids exist in this project",
                    status_code=HTTP_404_NOT_FOUND,
                )

            return SessionAnnotationsResponseBody(data=[], next_cursor=None)

        data = [
            SessionAnnotation(
                id=str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(anno.id))),
                session_id=session_id,
                name=anno.name,
                result=AnnotationResult(
                    label=anno.label,
                    score=anno.score,
                    explanation=anno.explanation,
                ),
                metadata=anno.metadata_,
                annotator_kind=anno.annotator_kind,
                created_at=anno.created_at,
                updated_at=anno.updated_at,
                identifier=anno.identifier,
                source=anno.source,
                user_id=str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None,
            )
            for session_id, anno in rows
        ]

        return SessionAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
@@ -48,6 +48,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVer
|
|
|
48
48
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
49
49
|
from phoenix.server.api.utils import delete_projects, delete_traces
|
|
50
50
|
from phoenix.server.authorization import is_not_locked
|
|
51
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
51
52
|
from phoenix.server.dml_event import DatasetInsertEvent
|
|
52
53
|
|
|
53
54
|
from .models import V1RoutesBaseModel
|
|
@@ -478,6 +479,9 @@ async def upload_dataset(
|
|
|
478
479
|
detail="Invalid request Content-Type",
|
|
479
480
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
480
481
|
)
|
|
482
|
+
user_id: Optional[int] = None
|
|
483
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
484
|
+
user_id = int(request.user.identity)
|
|
481
485
|
operation = cast(
|
|
482
486
|
Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
|
|
483
487
|
partial(
|
|
@@ -486,6 +490,7 @@ async def upload_dataset(
|
|
|
486
490
|
action=action,
|
|
487
491
|
name=name,
|
|
488
492
|
description=description,
|
|
493
|
+
user_id=user_id,
|
|
489
494
|
),
|
|
490
495
|
)
|
|
491
496
|
if sync:
|
|
@@ -15,10 +15,14 @@ from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESS
|
|
|
15
15
|
from strawberry.relay import GlobalID
|
|
16
16
|
|
|
17
17
|
from phoenix.db import models
|
|
18
|
-
from phoenix.db.helpers import
|
|
18
|
+
from phoenix.db.helpers import (
|
|
19
|
+
SupportedSQLDialect,
|
|
20
|
+
insert_experiment_with_examples_snapshot,
|
|
21
|
+
)
|
|
19
22
|
from phoenix.db.insertion.helpers import insert_on_conflict
|
|
20
23
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
21
24
|
from phoenix.server.authorization import is_not_locked
|
|
25
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
22
26
|
from phoenix.server.dml_event import ExperimentInsertEvent
|
|
23
27
|
from phoenix.server.experiments.utils import generate_experiment_project_name
|
|
24
28
|
|
|
@@ -157,6 +161,9 @@ async def create_experiment(
|
|
|
157
161
|
detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
|
|
158
162
|
status_code=HTTP_404_NOT_FOUND,
|
|
159
163
|
)
|
|
164
|
+
user_id: Optional[int] = None
|
|
165
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
166
|
+
user_id = int(request.user.identity)
|
|
160
167
|
|
|
161
168
|
# generate a semi-unique name for the experiment
|
|
162
169
|
experiment_name = request_body.name or _generate_experiment_name(dataset_name)
|
|
@@ -172,9 +179,9 @@ async def create_experiment(
|
|
|
172
179
|
repetitions=request_body.repetitions,
|
|
173
180
|
metadata_=request_body.metadata or {},
|
|
174
181
|
project_name=project_name,
|
|
182
|
+
user_id=user_id,
|
|
175
183
|
)
|
|
176
|
-
session
|
|
177
|
-
await session.flush()
|
|
184
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
178
185
|
|
|
179
186
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
180
187
|
project_rowid = await session.scalar(
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
from sqlalchemy import select
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
from starlette.status import HTTP_404_NOT_FOUND
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
14
|
+
from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
|
|
15
|
+
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
16
|
+
from phoenix.server.authorization import is_not_locked
|
|
17
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
18
|
+
|
|
19
|
+
from .annotations import SessionAnnotationData
|
|
20
|
+
from .utils import RequestBody, ResponseBody, add_errors_to_responses
|
|
21
|
+
|
|
22
|
+
router = APIRouter(tags=["sessions"])
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InsertedSessionAnnotation(V1RoutesBaseModel):
    # Response item identifying one session annotation written by the endpoint.
    id: str = Field(description="The ID of the inserted session annotation")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
    # Request envelope carrying a list of SessionAnnotationData; adds no fields
    # beyond those of RequestBody.
    pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
    # Response envelope carrying a list of InsertedSessionAnnotation; adds no
    # fields beyond those of ResponseBody.
    pass
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@router.post(
    "/session_annotations",
    dependencies=[Depends(is_not_locked)],
    operation_id="annotateSessions",
    summary="Create session annotations",
    responses=add_errors_to_responses(
        [{"status_code": HTTP_404_NOT_FOUND, "description": "Session not found"}]
    ),
    response_description="Session annotations inserted successfully",
    include_in_schema=True,
)
async def annotate_sessions(
    request: Request,
    request_body: AnnotateSessionsRequestBody,
    sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
) -> AnnotateSessionsResponseBody:
    # Create (or upsert) session annotations from the request body.
    #
    # Annotations named "note" are dropped with a UserWarning. When ``sync`` is
    # false the remaining annotations are enqueued and an empty list is returned;
    # when true they are written immediately and their IDs are returned.
    #
    # Raises (as HTTPException):
    #   404 - in sync mode, when any referenced session ID does not exist.
    #
    # Documented with comments instead of a docstring because FastAPI publishes a
    # path-operation docstring as the operation description in the OpenAPI schema.
    if not request_body.data:
        return AnnotateSessionsResponseBody(data=[])

    # Attribute the annotations to the caller only when auth is on and the
    # request carries a PhoenixUser.
    user_id: Optional[int] = None
    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
        user_id = int(request.user.identity)

    session_annotations = request_body.data
    filtered_session_annotations = [d for d in session_annotations if d.name != "note"]
    if len(filtered_session_annotations) != len(session_annotations):
        warnings.warn(
            (
                "Session annotations with the name 'note' are not supported in this endpoint. "
                "They will be ignored."
            ),
            UserWarning,
        )
    precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
    if not sync:
        await request.state.enqueue_annotations(*precursors)
        return AnnotateSessionsResponseBody(data=[])

    # Resolve every referenced session ID to its row ID in a single query.
    session_ids = {p.session_id for p in precursors}
    async with request.app.state.db() as session:
        existing_sessions = {
            session_id: rowid
            async for session_id, rowid in await session.stream(
                select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
                    models.ProjectSession.session_id.in_(session_ids)
                )
            )
        }

    missing_session_ids = session_ids - existing_sessions.keys()
    # We prefer to fail the entire operation if there are missing sessions in sync mode
    if missing_session_ids:
        # Sort so the error message is deterministic (set order is arbitrary).
        raise HTTPException(
            detail=f"Sessions with IDs {', '.join(sorted(missing_session_ids))} do not exist.",
            status_code=HTTP_404_NOT_FOUND,
        )

    async with request.app.state.db() as session:
        inserted_ids = []
        dialect = SupportedSQLDialect(session.bind.dialect.name)
        for p in precursors:
            values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
            # Upsert keyed on (name, project_session_id, identifier).
            session_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.ProjectSessionAnnotation,
                    unique_by=("name", "project_session_id", "identifier"),
                ).returning(models.ProjectSessionAnnotation.id)
            )
            inserted_ids.append(session_annotation_id)

    # NOTE(review): the raw row ID is returned here, whereas the session annotation
    # list endpoint returns GlobalID-encoded IDs -- confirm which form clients expect.
    return AnnotateSessionsResponseBody(
        data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
    )
|
|
@@ -144,12 +144,11 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
|
|
|
144
144
|
responses=add_errors_to_responses(
|
|
145
145
|
[{"status_code": HTTP_404_NOT_FOUND, "description": "Trace not found"}]
|
|
146
146
|
),
|
|
147
|
-
include_in_schema=False,
|
|
148
147
|
)
|
|
149
148
|
async def annotate_traces(
|
|
150
149
|
request: Request,
|
|
151
150
|
request_body: AnnotateTracesRequestBody,
|
|
152
|
-
sync: bool = Query(default=
|
|
151
|
+
sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
|
|
153
152
|
) -> AnnotateTracesResponseBody:
|
|
154
153
|
if not request_body.data:
|
|
155
154
|
return AnnotateTracesResponseBody(data=[])
|
|
@@ -217,6 +217,13 @@ async def create_user(
|
|
|
217
217
|
detail="Cannot create users with SYSTEM role",
|
|
218
218
|
)
|
|
219
219
|
|
|
220
|
+
# TODO: Implement VIEWER role
|
|
221
|
+
if role == "VIEWER":
|
|
222
|
+
raise HTTPException(
|
|
223
|
+
status_code=HTTP_400_BAD_REQUEST,
|
|
224
|
+
detail="VIEWER role not yet implemented",
|
|
225
|
+
)
|
|
226
|
+
|
|
220
227
|
user: models.User
|
|
221
228
|
if isinstance(user_data, LocalUserData):
|
|
222
229
|
password = (user_data.password or secrets.token_hex()).strip()
|
|
@@ -26,6 +26,7 @@ from typing_extensions import TypeAlias, assert_never
|
|
|
26
26
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
27
27
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
28
|
from phoenix.db import models
|
|
29
|
+
from phoenix.db.helpers import insert_experiment_with_examples_snapshot
|
|
29
30
|
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
30
31
|
from phoenix.server.api.context import Context
|
|
31
32
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
@@ -43,6 +44,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
43
44
|
get_db_trace,
|
|
44
45
|
streaming_llm_span,
|
|
45
46
|
)
|
|
47
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
46
48
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
47
49
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
48
50
|
ChatCompletionInput,
|
|
@@ -302,6 +304,7 @@ class Subscription:
|
|
|
302
304
|
description="Traces from prompt playground",
|
|
303
305
|
)
|
|
304
306
|
)
|
|
307
|
+
user_id = get_user(info)
|
|
305
308
|
experiment = models.Experiment(
|
|
306
309
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
307
310
|
dataset_version_id=resolved_version_id,
|
|
@@ -311,9 +314,9 @@ class Subscription:
|
|
|
311
314
|
repetitions=input.repetitions,
|
|
312
315
|
metadata_=input.experiment_metadata or dict(),
|
|
313
316
|
project_name=project_name,
|
|
317
|
+
user_id=user_id,
|
|
314
318
|
)
|
|
315
|
-
session
|
|
316
|
-
await session.flush()
|
|
319
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
317
320
|
yield ChatCompletionSubscriptionExperiment(
|
|
318
321
|
experiment=to_gql_experiment(experiment)
|
|
319
322
|
) # eagerly yields experiment so it can be linked by consumers of the subscription
|
|
@@ -12,6 +12,7 @@ from phoenix.db import models
|
|
|
12
12
|
from phoenix.server.api.context import Context
|
|
13
13
|
from phoenix.server.api.exceptions import BadRequest
|
|
14
14
|
from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
|
|
15
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
15
16
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
16
17
|
from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
|
|
17
18
|
ExperimentRepeatedRunGroup,
|
|
@@ -131,3 +132,13 @@ class DatasetExample(Node):
|
|
|
131
132
|
)
|
|
132
133
|
for group in repeated_run_groups
|
|
133
134
|
]
|
|
135
|
+
|
|
136
|
+
@strawberry.field
async def dataset_splits(
    self,
    info: Info[Context, None],
) -> list[DatasetSplit]:
    # Resolve the splits for this example: load the split records through the
    # dataset_example_splits data loader (keyed by this example's id_attr) and
    # convert each one to its GraphQL type.
    loader = info.context.data_loaders.dataset_example_splits
    split_records = await loader.load(self.id_attr)
    return [to_gql_dataset_split(record) for record in split_records]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import ClassVar, Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry.relay import Node, NodeID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@strawberry.type
class DatasetSplit(Node):
    # GraphQL node type for a dataset split; `_table` names the ORM model
    # (models.DatasetSplit) associated with this node type.
    _table: ClassVar[type[models.Base]] = models.DatasetSplit

    # Integer value used as the relay node ID.
    id_attr: NodeID[int]

    # Descriptive fields.
    name: str
    description: Optional[str]
    metadata: JSON
    color: str

    # Timestamps.
    created_at: datetime
    updated_at: datetime
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def to_gql_dataset_split(dataset_split: models.DatasetSplit) -> DatasetSplit:
    """Build the GraphQL ``DatasetSplit`` for an ORM ``models.DatasetSplit`` row.

    Fields are copied across one-to-one (``metadata_`` becomes ``metadata``); a
    falsy ``color`` is replaced by ``"#ffffff"``.
    """
    display_color = dataset_split.color or "#ffffff"
    field_values = dict(
        id_attr=dataset_split.id,
        name=dataset_split.name,
        description=dataset_split.description,
        color=display_color,
        metadata=dataset_split.metadata_,
        created_at=dataset_split.created_at,
        updated_at=dataset_split.updated_at,
    )
    return DatasetSplit(**field_values)
|
|
@@ -193,10 +193,6 @@ class Experiment(Node):
|
|
|
193
193
|
async for token_type, is_prompt, cost, tokens in data
|
|
194
194
|
]
|
|
195
195
|
|
|
196
|
-
@strawberry.field
|
|
197
|
-
async def repetition_count(self, info: Info[Context, None]) -> int:
|
|
198
|
-
return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
|
|
199
|
-
|
|
200
196
|
|
|
201
197
|
def to_gql_experiment(
|
|
202
198
|
experiment: models.Experiment,
|
|
@@ -588,6 +588,22 @@ class Project(Node):
|
|
|
588
588
|
async with info.context.db() as session:
|
|
589
589
|
return list(await session.scalars(stmt))
|
|
590
590
|
|
|
591
|
+
@strawberry.field(
    description="Names of all available annotations for sessions. "
    "(The list contains no duplicates.)"
)  # type: ignore
async def session_annotation_names(
    self,
    info: Info[Context, None],
) -> list[str]:
    # Select the distinct annotation names attached to sessions whose project_id
    # equals this project's row ID.
    belongs_to_project = models.ProjectSession.project_id == self.project_rowid
    names_query = (
        select(distinct(models.ProjectSessionAnnotation.name))
        .join(models.ProjectSession)
        .where(belongs_to_project)
    )
    async with info.context.db() as session:
        names = await session.scalars(names_query)
        return list(names)
|
|
606
|
+
|
|
591
607
|
@strawberry.field(
|
|
592
608
|
description="Names of available document evaluations.",
|
|
593
609
|
) # type: ignore
|