arize-phoenix 11.37.0__py3-none-any.whl → 12.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +74 -53
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/types.py +30 -0
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +285 -46
- phoenix/server/api/context.py +13 -2
- phoenix/server/api/dataloaders/__init__.py +6 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +65 -35
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +2 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +24 -9
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +32 -0
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +25 -7
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/DatasetExample.py +11 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +6 -2
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-CFzdBkk_.js → components-Dl9SUw1U.js} +371 -327
- phoenix/server/static/assets/{index-DayUA9lQ.js → index-CqQS0dTo.js} +2 -2
- phoenix/server/static/assets/{pages-CvUhOO9h.js → pages-DKSjVA_E.js} +771 -518
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
5
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
6
|
+
from starlette.requests import Request
|
|
7
|
+
from strawberry import Info
|
|
8
|
+
from strawberry.relay import GlobalID
|
|
9
|
+
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
12
|
+
from phoenix.server.api.context import Context
|
|
13
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauthorized
|
|
14
|
+
from phoenix.server.api.helpers.annotations import get_user_identifier
|
|
15
|
+
from phoenix.server.api.input_types.CreateProjectSessionAnnotationInput import (
|
|
16
|
+
CreateProjectSessionAnnotationInput,
|
|
17
|
+
)
|
|
18
|
+
from phoenix.server.api.input_types.UpdateAnnotationInput import UpdateAnnotationInput
|
|
19
|
+
from phoenix.server.api.queries import Query
|
|
20
|
+
from phoenix.server.api.types.AnnotationSource import AnnotationSource
|
|
21
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
22
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import (
|
|
23
|
+
ProjectSessionAnnotation,
|
|
24
|
+
to_gql_project_session_annotation,
|
|
25
|
+
)
|
|
26
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
27
|
+
from phoenix.server.dml_event import (
|
|
28
|
+
ProjectSessionAnnotationDeleteEvent,
|
|
29
|
+
ProjectSessionAnnotationInsertEvent,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# GraphQL payload returned by the project-session-annotation mutations in this
# module: the affected annotation plus the root ``Query`` type (the mutations
# return a fresh ``Query()``), so a client can select further data in the same
# response. Documented with comments rather than a docstring so the generated
# GraphQL schema text is unaffected.
@strawberry.type
class ProjectSessionAnnotationMutationPayload:
    # The annotation that was created, updated, or (for delete) just removed.
    project_session_annotation: ProjectSessionAnnotation
    # Root query object exposed alongside the mutation result.
    query: Query
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_phoenix_user(info: Info[Context, None]) -> Optional[PhoenixUser]:
    """Return the requesting :class:`PhoenixUser`, or ``None`` if there is none.

    ``None`` is returned when the ASGI scope carries no ``"user"`` entry or when
    ``info.context.user`` is not a ``PhoenixUser``.
    """
    assert isinstance(request := info.context.request, Request)
    if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
        return user
    return None


@strawberry.type
class ProjectSessionAnnotationMutationMixin:
    # GraphQL mutations that create, update, and delete project-session
    # annotations. Each returns a ProjectSessionAnnotationMutationPayload and
    # publishes a DML event on info.context.event_queue. Resolvers are
    # documented with comments (not docstrings) so the GraphQL schema text is
    # unaffected.

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def create_project_session_annotations(
        self, info: Info[Context, None], input: CreateProjectSessionAnnotationInput
    ) -> ProjectSessionAnnotationMutationPayload:
        # Insert one annotation on the session named by input.project_session_id.
        # Raises BadRequest for a malformed session GlobalID and Conflict when
        # the insert violates a database constraint.
        user = _get_phoenix_user(info)
        user_id: Optional[int] = int(user.identity) if user is not None else None

        try:
            project_session_id = from_global_id_with_expected_type(
                input.project_session_id, "ProjectSession"
            )
        except ValueError:
            raise BadRequest(f"Invalid session ID: {input.project_session_id}")

        # An explicit identifier wins; app-sourced annotations by a logged-in
        # user otherwise get a per-user identifier; everything else uses "".
        identifier = ""
        if isinstance(input.identifier, str):
            identifier = input.identifier  # Already trimmed in __post_init__
        elif input.source == AnnotationSource.APP and user_id is not None:
            identifier = get_user_identifier(user_id)

        try:
            async with info.context.db() as session:
                anno = models.ProjectSessionAnnotation(
                    project_session_id=project_session_id,
                    name=input.name,
                    label=input.label,
                    score=input.score,
                    explanation=input.explanation,
                    annotator_kind=input.annotator_kind.value,
                    metadata_=input.metadata,
                    identifier=identifier,
                    source=input.source.value,
                    user_id=user_id,
                )
                session.add(anno)
                # Flush while the session is open so constraint violations are
                # raised here (and mapped to Conflict) and anno.id is assigned.
                await session.flush()
        except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
            raise Conflict(f"Error creating annotation: {e}") from e

        info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))

        return ProjectSessionAnnotationMutationPayload(
            project_session_annotation=to_gql_project_session_annotation(anno),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def update_project_session_annotations(
        self, info: Info[Context, None], input: UpdateAnnotationInput
    ) -> ProjectSessionAnnotationMutationPayload:
        # Overwrite the mutable fields of an existing session annotation.
        # Only the annotation's own user (matched by user_id, both None when
        # unauthenticated) may update it. Raises BadRequest, NotFound,
        # Unauthorized, or Conflict (constraint violation).
        user = _get_phoenix_user(info)
        user_id: Optional[int] = int(user.identity) if user is not None else None

        try:
            id_ = from_global_id_with_expected_type(input.id, "ProjectSessionAnnotation")
        except ValueError:
            raise BadRequest(f"Invalid session annotation ID: {input.id}")

        # The whole session block is guarded so an IntegrityError raised on
        # flush *or* on commit at context exit is reported as Conflict.
        try:
            async with info.context.db() as session:
                anno = await session.get(models.ProjectSessionAnnotation, id_)
                if anno is None:
                    raise NotFound(f"Could not find session annotation with ID: {input.id}")
                if anno.user_id != user_id:
                    raise Unauthorized("Session annotation is not associated with the current user.")

                # `anno` is already attached to the session, so attribute
                # assignment is enough; no session.add() is needed.
                anno.name = input.name
                anno.label = input.label
                anno.score = input.score
                anno.explanation = input.explanation
                anno.annotator_kind = input.annotator_kind.value
                anno.metadata_ = input.metadata
                anno.source = input.source.value

                await session.flush()
        except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
            raise Conflict(f"Error updating annotation: {e}") from e

        # Updates are published with the insert event type; this module
        # imports no separate update event.
        info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
        return ProjectSessionAnnotationMutationPayload(
            project_session_annotation=to_gql_project_session_annotation(anno),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def delete_project_session_annotation(
        self, info: Info[Context, None], id: GlobalID
    ) -> ProjectSessionAnnotationMutationPayload:
        # Delete a session annotation. Admins may delete any annotation;
        # other users only their own. Returns the deleted annotation's data.
        try:
            id_ = from_global_id_with_expected_type(id, "ProjectSessionAnnotation")
        except ValueError:
            raise BadRequest(f"Invalid session annotation ID: {id}")

        user = _get_phoenix_user(info)
        user_id: Optional[int] = int(user.identity) if user is not None else None
        user_is_admin = user.is_admin if user is not None else False

        async with info.context.db() as session:
            anno = await session.get(models.ProjectSessionAnnotation, id_)
            if anno is None:
                raise NotFound(f"Could not find session annotation with ID: {id}")

            if not user_is_admin and anno.user_id != user_id:
                raise Unauthorized(
                    "Session annotation is not associated with the current user and "
                    "the current user is not an admin."
                )

            await session.delete(anno)

        deleted_gql_annotation = to_gql_project_session_annotation(anno)
        info.context.event_queue.put(ProjectSessionAnnotationDeleteEvent((id_,)))
        return ProjectSessionAnnotationMutationPayload(
            project_session_annotation=deleted_gql_annotation, query=Query()
        )
|
phoenix/server/api/queries.py
CHANGED
|
@@ -48,6 +48,7 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
|
|
|
48
48
|
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
49
49
|
from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
|
|
50
50
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
51
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
51
52
|
from phoenix.server.api.types.Dimension import to_gql_dimension
|
|
52
53
|
from phoenix.server.api.types.EmbeddingDimension import (
|
|
53
54
|
DEFAULT_CLUSTER_SELECTION_EPSILON,
|
|
@@ -959,6 +960,14 @@ class Query:
|
|
|
959
960
|
id_attr=example.id,
|
|
960
961
|
created_at=example.created_at,
|
|
961
962
|
)
|
|
963
|
+
elif type_name == DatasetSplit.__name__:
|
|
964
|
+
async with info.context.db() as session:
|
|
965
|
+
dataset_split = await session.scalar(
|
|
966
|
+
select(models.DatasetSplit).where(models.DatasetSplit.id == node_id)
|
|
967
|
+
)
|
|
968
|
+
if not dataset_split:
|
|
969
|
+
raise NotFound(f"Unknown dataset split: {id}")
|
|
970
|
+
return to_gql_dataset_split(dataset_split)
|
|
962
971
|
elif type_name == Experiment.__name__:
|
|
963
972
|
async with info.context.db() as session:
|
|
964
973
|
experiment = await session.scalar(
|
|
@@ -1140,6 +1149,29 @@ class Query:
|
|
|
1140
1149
|
args=args,
|
|
1141
1150
|
)
|
|
1142
1151
|
|
|
1152
|
+
@strawberry.field
async def dataset_splits(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[DatasetSplit]:
    # Relay connection over every dataset split, paginated in memory by
    # connection_from_list. Rows are ordered by primary key: without an
    # ORDER BY the database may return them in any order, so cursors
    # (presumably list positions — confirm in connection_from_list) would not
    # be stable between requests. Documented with comments rather than a
    # docstring so the GraphQL schema text is unaffected.
    args = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    stmt = select(models.DatasetSplit).order_by(models.DatasetSplit.id)
    async with info.context.db() as session:
        splits = await session.stream_scalars(stmt)
        data = [to_gql_dataset_split(split) async for split in splits]
    return connection_from_list(
        data=data,
        args=args,
    )
|
|
1174
|
+
|
|
1143
1175
|
@strawberry.field
|
|
1144
1176
|
async def annotation_configs(
|
|
1145
1177
|
self,
|
|
@@ -14,6 +14,7 @@ from .experiment_runs import router as experiment_runs_router
|
|
|
14
14
|
from .experiments import router as experiments_router
|
|
15
15
|
from .projects import router as projects_router
|
|
16
16
|
from .prompts import router as prompts_router
|
|
17
|
+
from .sessions import router as sessions_router
|
|
17
18
|
from .spans import router as spans_router
|
|
18
19
|
from .traces import router as traces_router
|
|
19
20
|
from .users import router as users_router
|
|
@@ -71,6 +72,7 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
|
|
|
71
72
|
router.include_router(evaluations_router)
|
|
72
73
|
router.include_router(prompts_router)
|
|
73
74
|
router.include_router(projects_router)
|
|
75
|
+
router.include_router(sessions_router)
|
|
74
76
|
router.include_router(documents_router)
|
|
75
77
|
router.include_router(users_router)
|
|
76
78
|
return router
|
|
@@ -14,6 +14,9 @@ from strawberry.relay import GlobalID
|
|
|
14
14
|
from phoenix.db import models
|
|
15
15
|
from phoenix.db.insertion.types import Precursors
|
|
16
16
|
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
17
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import (
|
|
18
|
+
ProjectSessionAnnotation as SessionAnnotationNodeType,
|
|
19
|
+
)
|
|
17
20
|
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation as SpanAnnotationNodeType
|
|
18
21
|
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation as TraceAnnotationNodeType
|
|
19
22
|
from phoenix.server.api.types.User import User as UserNodeType
|
|
@@ -24,6 +27,7 @@ logger = logging.getLogger(__name__)
|
|
|
24
27
|
|
|
25
28
|
SPAN_ANNOTATION_NODE_NAME = SpanAnnotationNodeType.__name__
|
|
26
29
|
TRACE_ANNOTATION_NODE_NAME = TraceAnnotationNodeType.__name__
|
|
30
|
+
SESSION_ANNOTATION_NODE_NAME = SessionAnnotationNodeType.__name__
|
|
27
31
|
MAX_TRACE_IDS = 1_000
|
|
28
32
|
USER_NODE_NAME = UserNodeType.__name__
|
|
29
33
|
MAX_SPAN_IDS = 1_000
|
|
@@ -161,6 +165,35 @@ class TraceAnnotationsResponseBody(PaginatedResponseBody[TraceAnnotation]):
|
|
|
161
165
|
pass
|
|
162
166
|
|
|
163
167
|
|
|
168
|
+
# Request payload for one session annotation submitted through the REST API.
# Documented with comments rather than a class docstring: pydantic would
# otherwise publish the docstring as the model description in OpenAPI.
class SessionAnnotationData(AnnotationData):
    session_id: str = Field(description="Session ID")

    def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.SessionAnnotation:
        """Build the insertion precursor for this session annotation.

        Args:
            user_id: Row id of the submitting user, if any; stored on the ORM
                row's ``user_id`` column.

        Returns:
            A ``Precursors.SessionAnnotation`` built from a UTC timestamp taken
            now, this payload's ``session_id``, and an unsaved
            ``models.ProjectSessionAnnotation`` whose ``source`` is ``"API"``.
            Score/label/explanation are ``None`` when ``result`` is absent;
            metadata defaults to an empty dict.
        """
        received_at = datetime.now(timezone.utc)
        result = self.result
        orm_annotation = models.ProjectSessionAnnotation(
            name=self.name,
            annotator_kind=self.annotator_kind,
            score=result.score if result else None,
            label=result.label if result else None,
            explanation=result.explanation if result else None,
            metadata_=self.metadata or {},
            identifier=self.identifier,
            source="API",
            user_id=user_id,
        )
        return Precursors.SessionAnnotation(received_at, self.session_id, orm_annotation)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
# Session annotation as returned by the REST list endpoint: the
# SessionAnnotationData payload combined with the fields of ``Annotation``
# (presumably server-assigned ones such as id/created_at/updated_at/user_id —
# Annotation is not visible here; confirm). Comments instead of a docstring so
# the OpenAPI model description is unchanged.
class SessionAnnotation(SessionAnnotationData, Annotation):
    pass
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# Paginated response body for the session-annotations list endpoint
# (``data`` plus ``next_cursor``, as constructed by list_session_annotations).
# Comments instead of a docstring so the OpenAPI model description is unchanged.
class SessionAnnotationsResponseBody(PaginatedResponseBody[SessionAnnotation]):
    pass
|
|
195
|
+
|
|
196
|
+
|
|
164
197
|
@router.get(
|
|
165
198
|
"/projects/{project_identifier}/span_annotations",
|
|
166
199
|
operation_id="listSpanAnnotationsBySpanIds",
|
|
@@ -304,3 +337,290 @@ async def list_span_annotations(
|
|
|
304
337
|
]
|
|
305
338
|
|
|
306
339
|
return SpanAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@router.get(
    "/projects/{project_identifier}/trace_annotations",
    operation_id="listTraceAnnotationsByTraceIds",
    summary="Get trace annotations for a list of trace_ids.",
    status_code=HTTP_200_OK,
    responses=add_errors_to_responses(
        [
            {"status_code": HTTP_404_NOT_FOUND, "description": "Project or traces not found"},
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
        ]
    ),
)
async def list_trace_annotations(
    request: Request,
    project_identifier: str = Path(
        description=(
            "The project identifier: either project ID or project name. If using a project name as "
            "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
            "characters."
        )
    ),
    trace_ids: list[str] = Query(
        ..., min_length=1, description="One or more trace id to fetch annotations for"
    ),
    include_annotation_names: Optional[list[str]] = Query(
        default=None,
        description=(
            "Optional list of annotation names to include. If provided, only annotations with "
            "these names will be returned. 'note' annotations are excluded by default unless "
            "explicitly included in this list."
        ),
    ),
    exclude_annotation_names: Optional[list[str]] = Query(
        default=None, description="Optional list of annotation names to exclude from results."
    ),
    cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
    limit: int = Query(
        default=10,
        gt=0,
        le=10000,
        description="The maximum number of annotations to return in a single request",
    ),
) -> TraceAnnotationsResponseBody:
    # List annotations on the given traces of one project, newest first
    # (descending annotation row id), at most `limit` per page. `next_cursor`
    # is the GlobalID of the first annotation of the next page; passing it
    # back as `cursor` resumes from that row (inclusive).
    # Raises 422 for too many trace ids or an invalid cursor, 404 when the
    # project is unknown or none of the trace ids exist in it.
    # (Comments rather than a docstring: FastAPI would publish a docstring as
    # the OpenAPI operation description.)
    trace_ids = list({*trace_ids})
    if len(trace_ids) > MAX_TRACE_IDS:
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Too many trace_ids supplied: {len(trace_ids)} (max {MAX_TRACE_IDS})",
        )

    async with request.app.state.db() as session:
        project = await _get_project_by_identifier(session, project_identifier)
        if not project:
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Project with identifier {project_identifier} not found",
            )

        # Build the base query
        where_conditions = [
            models.Project.id == project.id,
            models.Trace.trace_id.in_(trace_ids),
        ]

        # Add annotation name filtering.
        # NOTE(review): the include_annotation_names description says 'note'
        # annotations are excluded by default, but no default exclusion is
        # applied here — confirm whether clients are expected to pass it.
        if include_annotation_names:
            where_conditions.append(models.TraceAnnotation.name.in_(include_annotation_names))

        if exclude_annotation_names:
            where_conditions.append(models.TraceAnnotation.name.not_in(exclude_annotation_names))

        stmt = (
            select(models.Trace.trace_id, models.TraceAnnotation)
            .join(models.Project, models.Trace.project_rowid == models.Project.id)
            .join(models.TraceAnnotation, models.TraceAnnotation.trace_rowid == models.Trace.id)
            .where(*where_conditions)
            .order_by(models.TraceAnnotation.id.desc())
            .limit(limit + 1)
        )

        if cursor:
            try:
                cursor_gid = GlobalID.from_id(cursor)
                cursor_id = int(cursor_gid.node_id)
            except ValueError:
                raise HTTPException(
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Invalid cursor value",
                )
            # Reject cursors minted for another node type (e.g. a span
            # annotation cursor); their ids would silently bound the wrong table.
            if cursor_gid.type_name != TRACE_ANNOTATION_NODE_NAME:
                raise HTTPException(
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Invalid cursor value",
                )
            stmt = stmt.where(models.TraceAnnotation.id <= cursor_id)

        rows: list[tuple[str, models.TraceAnnotation]] = [
            r async for r in (await session.stream(stmt))
        ]

        # One extra row was fetched to detect whether another page exists.
        next_cursor: Optional[str] = None
        if len(rows) == limit + 1:
            *rows, extra = rows
            next_cursor = str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(extra[1].id)))

        if not rows:
            # Distinguish "traces exist but have no matching annotations"
            # (empty 200) from "none of these traces exist" (404).
            traces_exist = await session.scalar(
                select(
                    exists().where(
                        models.Trace.trace_id.in_(trace_ids),
                        models.Trace.project_rowid == project.id,
                    )
                )
            )
            if not traces_exist:
                raise HTTPException(
                    detail="None of the supplied trace_ids exist in this project",
                    status_code=HTTP_404_NOT_FOUND,
                )

            return TraceAnnotationsResponseBody(data=[], next_cursor=None)

        data = [
            TraceAnnotation(
                id=str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(anno.id))),
                trace_id=trace_id,
                name=anno.name,
                result=AnnotationResult(
                    label=anno.label,
                    score=anno.score,
                    explanation=anno.explanation,
                ),
                metadata=anno.metadata_,
                annotator_kind=anno.annotator_kind,
                created_at=anno.created_at,
                updated_at=anno.updated_at,
                identifier=anno.identifier,
                source=anno.source,
                user_id=(
                    str(GlobalID(USER_NODE_NAME, str(anno.user_id)))
                    if anno.user_id is not None
                    else None
                ),
            )
            for trace_id, anno in rows
        ]

        return TraceAnnotationsResponseBody(data=data, next_cursor=next_cursor)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@router.get(
    "/projects/{project_identifier}/session_annotations",
    operation_id="listSessionAnnotationsBySessionIds",
    summary="Get session annotations for a list of session_ids.",
    status_code=HTTP_200_OK,
    responses=add_errors_to_responses(
        [
            {"status_code": HTTP_404_NOT_FOUND, "description": "Project or sessions not found"},
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
        ]
    ),
)
async def list_session_annotations(
    request: Request,
    project_identifier: str = Path(
        description=(
            "The project identifier: either project ID or project name. If using a project name as "
            "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
            "characters."
        )
    ),
    session_ids: list[str] = Query(
        ..., min_length=1, description="One or more session id to fetch annotations for"
    ),
    include_annotation_names: Optional[list[str]] = Query(
        default=None,
        description=(
            "Optional list of annotation names to include. If provided, only annotations with "
            "these names will be returned. 'note' annotations are excluded by default unless "
            "explicitly included in this list."
        ),
    ),
    exclude_annotation_names: Optional[list[str]] = Query(
        default=None, description="Optional list of annotation names to exclude from results."
    ),
    cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
    limit: int = Query(
        default=10,
        gt=0,
        le=10000,
        description="The maximum number of annotations to return in a single request",
    ),
) -> SessionAnnotationsResponseBody:
    # Returns annotations on the requested sessions of one project, newest
    # first (descending annotation row id), at most `limit` per page.
    # `next_cursor` is the GlobalID of the first annotation on the following
    # page; sending it back as `cursor` resumes from that row (inclusive).
    # 422: too many session ids or an unparsable cursor.
    # 404: unknown project, or none of the session ids exist in it.
    # (Comments rather than a docstring: FastAPI would publish a docstring as
    # the OpenAPI operation description.)
    session_ids = list(set(session_ids))
    # NOTE(review): MAX_SESSION_IDS is not among the module constants visible
    # in this diff (MAX_TRACE_IDS / MAX_SPAN_IDS are) — confirm it is defined.
    if len(session_ids) > MAX_SESSION_IDS:
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Too many session_ids supplied: {len(session_ids)} (max {MAX_SESSION_IDS})",
        )

    async with request.app.state.db() as session:
        project = await _get_project_by_identifier(session, project_identifier)
        if not project:
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Project with identifier {project_identifier} not found",
            )

        annotation_table = models.ProjectSessionAnnotation
        filters = [
            models.Project.id == project.id,
            models.ProjectSession.session_id.in_(session_ids),
        ]
        # NOTE(review): no default exclusion of 'note' annotations happens
        # here despite the parameter description — confirm intended behavior.
        if include_annotation_names:
            filters.append(annotation_table.name.in_(include_annotation_names))
        if exclude_annotation_names:
            filters.append(annotation_table.name.not_in(exclude_annotation_names))

        if cursor:
            try:
                cursor_id = int(GlobalID.from_id(cursor).node_id)
            except ValueError:
                raise HTTPException(
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Invalid cursor value",
                )
            # Keyset pagination: resume at the cursor row, inclusive.
            filters.append(annotation_table.id <= cursor_id)

        query = (
            select(models.ProjectSession.session_id, annotation_table)
            .join(models.Project, models.ProjectSession.project_id == models.Project.id)
            .join(
                annotation_table,
                annotation_table.project_session_id == models.ProjectSession.id,
            )
            .where(*filters)
            .order_by(annotation_table.id.desc())
            .limit(limit + 1)  # one extra row signals that another page exists
        )
        result_stream = await session.stream(query)
        rows: list[tuple[str, models.ProjectSessionAnnotation]] = [
            row async for row in result_stream
        ]

        next_cursor: Optional[str] = None
        if len(rows) > limit:
            overflow_row = rows.pop()
            next_cursor = str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(overflow_row[1].id)))

        if not rows:
            # Empty page: 404 only when none of the sessions exist at all.
            any_session_found = await session.scalar(
                select(
                    exists().where(
                        models.ProjectSession.session_id.in_(session_ids),
                        models.ProjectSession.project_id == project.id,
                    )
                )
            )
            if not any_session_found:
                raise HTTPException(
                    detail="None of the supplied session_ids exist in this project",
                    status_code=HTTP_404_NOT_FOUND,
                )
            return SessionAnnotationsResponseBody(data=[], next_cursor=None)

        def to_response_item(
            session_id: str, anno: models.ProjectSessionAnnotation
        ) -> SessionAnnotation:
            # Convert one ORM row into its REST representation.
            owner_gid = (
                str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None
            )
            return SessionAnnotation(
                id=str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(anno.id))),
                session_id=session_id,
                name=anno.name,
                result=AnnotationResult(
                    label=anno.label,
                    score=anno.score,
                    explanation=anno.explanation,
                ),
                metadata=anno.metadata_,
                annotator_kind=anno.annotator_kind,
                created_at=anno.created_at,
                updated_at=anno.updated_at,
                identifier=anno.identifier,
                source=anno.source,
                user_id=owner_gid,
            )

        return SessionAnnotationsResponseBody(
            data=[to_response_item(sid, anno) for sid, anno in rows],
            next_cursor=next_cursor,
        )
|
|
@@ -48,6 +48,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVer
|
|
|
48
48
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
49
49
|
from phoenix.server.api.utils import delete_projects, delete_traces
|
|
50
50
|
from phoenix.server.authorization import is_not_locked
|
|
51
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
51
52
|
from phoenix.server.dml_event import DatasetInsertEvent
|
|
52
53
|
|
|
53
54
|
from .models import V1RoutesBaseModel
|
|
@@ -478,6 +479,9 @@ async def upload_dataset(
|
|
|
478
479
|
detail="Invalid request Content-Type",
|
|
479
480
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
480
481
|
)
|
|
482
|
+
user_id: Optional[int] = None
|
|
483
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
484
|
+
user_id = int(request.user.identity)
|
|
481
485
|
operation = cast(
|
|
482
486
|
Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
|
|
483
487
|
partial(
|
|
@@ -486,6 +490,7 @@ async def upload_dataset(
|
|
|
486
490
|
action=action,
|
|
487
491
|
name=name,
|
|
488
492
|
description=description,
|
|
493
|
+
user_id=user_id,
|
|
489
494
|
),
|
|
490
495
|
)
|
|
491
496
|
if sync:
|
|
@@ -15,10 +15,14 @@ from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESS
|
|
|
15
15
|
from strawberry.relay import GlobalID
|
|
16
16
|
|
|
17
17
|
from phoenix.db import models
|
|
18
|
-
from phoenix.db.helpers import
|
|
18
|
+
from phoenix.db.helpers import (
|
|
19
|
+
SupportedSQLDialect,
|
|
20
|
+
insert_experiment_with_examples_snapshot,
|
|
21
|
+
)
|
|
19
22
|
from phoenix.db.insertion.helpers import insert_on_conflict
|
|
20
23
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
21
24
|
from phoenix.server.authorization import is_not_locked
|
|
25
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
22
26
|
from phoenix.server.dml_event import ExperimentInsertEvent
|
|
23
27
|
from phoenix.server.experiments.utils import generate_experiment_project_name
|
|
24
28
|
|
|
@@ -157,6 +161,9 @@ async def create_experiment(
|
|
|
157
161
|
detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
|
|
158
162
|
status_code=HTTP_404_NOT_FOUND,
|
|
159
163
|
)
|
|
164
|
+
user_id: Optional[int] = None
|
|
165
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
166
|
+
user_id = int(request.user.identity)
|
|
160
167
|
|
|
161
168
|
# generate a semi-unique name for the experiment
|
|
162
169
|
experiment_name = request_body.name or _generate_experiment_name(dataset_name)
|
|
@@ -172,9 +179,9 @@ async def create_experiment(
|
|
|
172
179
|
repetitions=request_body.repetitions,
|
|
173
180
|
metadata_=request_body.metadata or {},
|
|
174
181
|
project_name=project_name,
|
|
182
|
+
user_id=user_id,
|
|
175
183
|
)
|
|
176
|
-
session
|
|
177
|
-
await session.flush()
|
|
184
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
178
185
|
|
|
179
186
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
180
187
|
project_rowid = await session.scalar(
|