arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,18 +4,11 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
|
|
4
4
|
from pydantic import Field
|
|
5
5
|
from sqlalchemy import select
|
|
6
6
|
from starlette.requests import Request
|
|
7
|
-
from starlette.status import (
|
|
8
|
-
HTTP_204_NO_CONTENT,
|
|
9
|
-
HTTP_403_FORBIDDEN,
|
|
10
|
-
HTTP_404_NOT_FOUND,
|
|
11
|
-
HTTP_422_UNPROCESSABLE_ENTITY,
|
|
12
|
-
)
|
|
13
7
|
from strawberry.relay import GlobalID
|
|
14
8
|
|
|
15
9
|
from phoenix.config import DEFAULT_PROJECT_NAME
|
|
16
10
|
from phoenix.db import models
|
|
17
11
|
from phoenix.db.helpers import exclude_experiment_projects
|
|
18
|
-
from phoenix.db.models import UserRoleName
|
|
19
12
|
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
20
13
|
from phoenix.server.api.routers.v1.utils import (
|
|
21
14
|
PaginatedResponseBody,
|
|
@@ -24,7 +17,7 @@ from phoenix.server.api.routers.v1.utils import (
|
|
|
24
17
|
add_errors_to_responses,
|
|
25
18
|
)
|
|
26
19
|
from phoenix.server.api.types.Project import Project as ProjectNodeType
|
|
27
|
-
from phoenix.server.authorization import is_not_locked
|
|
20
|
+
from phoenix.server.authorization import is_not_locked, require_admin
|
|
28
21
|
|
|
29
22
|
router = APIRouter(tags=["projects"])
|
|
30
23
|
|
|
@@ -70,7 +63,7 @@ class UpdateProjectResponseBody(ResponseBody[Project]):
|
|
|
70
63
|
response_description="A list of projects with pagination information", # noqa: E501
|
|
71
64
|
responses=add_errors_to_responses(
|
|
72
65
|
[
|
|
73
|
-
|
|
66
|
+
422,
|
|
74
67
|
]
|
|
75
68
|
),
|
|
76
69
|
)
|
|
@@ -115,7 +108,7 @@ async def get_projects(
|
|
|
115
108
|
except ValueError:
|
|
116
109
|
raise HTTPException(
|
|
117
110
|
detail=f"Invalid cursor format: {cursor}",
|
|
118
|
-
status_code=
|
|
111
|
+
status_code=422,
|
|
119
112
|
)
|
|
120
113
|
|
|
121
114
|
stmt = stmt.limit(limit + 1)
|
|
@@ -142,8 +135,8 @@ async def get_projects(
|
|
|
142
135
|
response_description="The requested project", # noqa: E501
|
|
143
136
|
responses=add_errors_to_responses(
|
|
144
137
|
[
|
|
145
|
-
|
|
146
|
-
|
|
138
|
+
404,
|
|
139
|
+
422,
|
|
147
140
|
]
|
|
148
141
|
),
|
|
149
142
|
)
|
|
@@ -182,7 +175,7 @@ async def get_project(
|
|
|
182
175
|
response_description="The newly created project", # noqa: E501
|
|
183
176
|
responses=add_errors_to_responses(
|
|
184
177
|
[
|
|
185
|
-
|
|
178
|
+
422,
|
|
186
179
|
]
|
|
187
180
|
),
|
|
188
181
|
)
|
|
@@ -216,16 +209,16 @@ async def create_project(
|
|
|
216
209
|
|
|
217
210
|
@router.put(
|
|
218
211
|
"/projects/{project_identifier}",
|
|
219
|
-
dependencies=[Depends(is_not_locked)],
|
|
212
|
+
dependencies=[Depends(require_admin), Depends(is_not_locked)],
|
|
220
213
|
operation_id="updateProject",
|
|
221
214
|
summary="Update a project by ID or name", # noqa: E501
|
|
222
215
|
description="Update an existing project with new configuration. Project names cannot be changed. The project identifier is either project ID or project name. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
|
|
223
216
|
response_description="The updated project", # noqa: E501
|
|
224
217
|
responses=add_errors_to_responses(
|
|
225
218
|
[
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
219
|
+
403,
|
|
220
|
+
404,
|
|
221
|
+
422,
|
|
229
222
|
]
|
|
230
223
|
),
|
|
231
224
|
)
|
|
@@ -251,20 +244,6 @@ async def update_project(
|
|
|
251
244
|
Raises:
|
|
252
245
|
HTTPException: If the project identifier format is invalid or the project is not found.
|
|
253
246
|
""" # noqa: E501
|
|
254
|
-
if request.app.state.authentication_enabled:
|
|
255
|
-
async with request.app.state.db() as session:
|
|
256
|
-
# Check if the user is an admin
|
|
257
|
-
stmt = (
|
|
258
|
-
select(models.UserRole.name)
|
|
259
|
-
.join(models.User)
|
|
260
|
-
.where(models.User.id == int(request.user.identity))
|
|
261
|
-
)
|
|
262
|
-
role_name: UserRoleName = await session.scalar(stmt)
|
|
263
|
-
if role_name != "ADMIN" and role_name != "SYSTEM":
|
|
264
|
-
raise HTTPException(
|
|
265
|
-
status_code=HTTP_403_FORBIDDEN,
|
|
266
|
-
detail="Only admins can update projects",
|
|
267
|
-
)
|
|
268
247
|
async with request.app.state.db() as session:
|
|
269
248
|
project = await _get_project_by_identifier(session, project_identifier)
|
|
270
249
|
|
|
@@ -278,16 +257,17 @@ async def update_project(
|
|
|
278
257
|
|
|
279
258
|
@router.delete(
|
|
280
259
|
"/projects/{project_identifier}",
|
|
260
|
+
dependencies=[Depends(require_admin)],
|
|
281
261
|
operation_id="deleteProject",
|
|
282
262
|
summary="Delete a project by ID or name", # noqa: E501
|
|
283
263
|
description="Delete an existing project and all its associated data. The project identifier is either project ID or project name. The default project cannot be deleted. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
|
|
284
264
|
response_description="No content returned on successful deletion", # noqa: E501
|
|
285
|
-
status_code=
|
|
265
|
+
status_code=204,
|
|
286
266
|
responses=add_errors_to_responses(
|
|
287
267
|
[
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
268
|
+
403,
|
|
269
|
+
404,
|
|
270
|
+
422,
|
|
291
271
|
]
|
|
292
272
|
),
|
|
293
273
|
)
|
|
@@ -311,27 +291,13 @@ async def delete_project(
|
|
|
311
291
|
Raises:
|
|
312
292
|
HTTPException: If the project identifier format is invalid, the project is not found, or it's the default project.
|
|
313
293
|
""" # noqa: E501
|
|
314
|
-
if request.app.state.authentication_enabled:
|
|
315
|
-
async with request.app.state.db() as session:
|
|
316
|
-
# Check if the user is an admin
|
|
317
|
-
stmt = (
|
|
318
|
-
select(models.UserRole.name)
|
|
319
|
-
.join(models.User)
|
|
320
|
-
.where(models.User.id == int(request.user.identity))
|
|
321
|
-
)
|
|
322
|
-
role_name: UserRoleName = await session.scalar(stmt)
|
|
323
|
-
if role_name != "ADMIN" and role_name != "SYSTEM":
|
|
324
|
-
raise HTTPException(
|
|
325
|
-
status_code=HTTP_403_FORBIDDEN,
|
|
326
|
-
detail="Only admins can delete projects",
|
|
327
|
-
)
|
|
328
294
|
async with request.app.state.db() as session:
|
|
329
295
|
project = await _get_project_by_identifier(session, project_identifier)
|
|
330
296
|
|
|
331
297
|
# The default project must not be deleted - it's forbidden
|
|
332
298
|
if project.name == DEFAULT_PROJECT_NAME:
|
|
333
299
|
raise HTTPException(
|
|
334
|
-
status_code=
|
|
300
|
+
status_code=403,
|
|
335
301
|
detail="The default project cannot be deleted",
|
|
336
302
|
)
|
|
337
303
|
|
|
@@ -4,9 +4,9 @@ from typing import Any, Optional, Union
|
|
|
4
4
|
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
|
5
5
|
from pydantic import ValidationError, model_validator
|
|
6
6
|
from sqlalchemy import select
|
|
7
|
+
from sqlalchemy.orm import joinedload
|
|
7
8
|
from sqlalchemy.sql import Select
|
|
8
9
|
from starlette.requests import Request
|
|
9
|
-
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
|
|
10
10
|
from strawberry.relay import GlobalID
|
|
11
11
|
from typing_extensions import Self, TypeAlias, assert_never
|
|
12
12
|
|
|
@@ -43,6 +43,7 @@ class PromptData(V1RoutesBaseModel):
|
|
|
43
43
|
name: Identifier
|
|
44
44
|
description: Optional[str] = None
|
|
45
45
|
source_prompt_id: Optional[str] = None
|
|
46
|
+
metadata: Optional[dict[str, Any]] = None
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
class Prompt(PromptData):
|
|
@@ -110,7 +111,7 @@ router = APIRouter(tags=["prompts"])
|
|
|
110
111
|
response_description="A list of prompts with pagination information",
|
|
111
112
|
responses=add_errors_to_responses(
|
|
112
113
|
[
|
|
113
|
-
|
|
114
|
+
422,
|
|
114
115
|
]
|
|
115
116
|
),
|
|
116
117
|
)
|
|
@@ -154,7 +155,7 @@ async def get_prompts(
|
|
|
154
155
|
except ValueError:
|
|
155
156
|
raise HTTPException(
|
|
156
157
|
detail=f"Invalid cursor format: {cursor}",
|
|
157
|
-
status_code=
|
|
158
|
+
status_code=422,
|
|
158
159
|
)
|
|
159
160
|
|
|
160
161
|
query = query.limit(limit + 1)
|
|
@@ -181,7 +182,7 @@ async def get_prompts(
|
|
|
181
182
|
description="Retrieve all versions of a specific prompt with pagination support. Each prompt "
|
|
182
183
|
"can have multiple versions with different configurations.",
|
|
183
184
|
response_description="A list of prompt versions with pagination information",
|
|
184
|
-
responses=add_errors_to_responses([
|
|
185
|
+
responses=add_errors_to_responses([422, 404]),
|
|
185
186
|
response_model_by_alias=True,
|
|
186
187
|
response_model_exclude_defaults=True,
|
|
187
188
|
response_model_exclude_unset=True,
|
|
@@ -214,7 +215,7 @@ async def list_prompt_versions(
|
|
|
214
215
|
HTTPException: If the cursor format is invalid, the prompt identifier is invalid,
|
|
215
216
|
or the prompt is not found.
|
|
216
217
|
"""
|
|
217
|
-
query = select(models.PromptVersion)
|
|
218
|
+
query = select(models.PromptVersion).options(joinedload(models.PromptVersion.prompt))
|
|
218
219
|
query = _filter_by_prompt_identifier(query.join(models.Prompt), prompt_identifier)
|
|
219
220
|
query = query.order_by(models.PromptVersion.id.desc())
|
|
220
221
|
|
|
@@ -226,7 +227,7 @@ async def list_prompt_versions(
|
|
|
226
227
|
except ValueError:
|
|
227
228
|
raise HTTPException(
|
|
228
229
|
detail=f"Invalid cursor format: {cursor}",
|
|
229
|
-
status_code=
|
|
230
|
+
status_code=422,
|
|
230
231
|
)
|
|
231
232
|
|
|
232
233
|
query = query.limit(limit + 1)
|
|
@@ -255,8 +256,8 @@ async def list_prompt_versions(
|
|
|
255
256
|
response_description="The requested prompt version",
|
|
256
257
|
responses=add_errors_to_responses(
|
|
257
258
|
[
|
|
258
|
-
|
|
259
|
-
|
|
259
|
+
404,
|
|
260
|
+
422,
|
|
260
261
|
]
|
|
261
262
|
),
|
|
262
263
|
response_model_by_alias=True,
|
|
@@ -286,11 +287,16 @@ async def get_prompt_version_by_prompt_version_id(
|
|
|
286
287
|
PromptVersionNodeType.__name__,
|
|
287
288
|
)
|
|
288
289
|
except ValueError:
|
|
289
|
-
raise HTTPException(
|
|
290
|
+
raise HTTPException(422, "Invalid prompt version ID")
|
|
290
291
|
async with request.app.state.db() as session:
|
|
291
|
-
|
|
292
|
+
stmt = (
|
|
293
|
+
select(models.PromptVersion)
|
|
294
|
+
.options(joinedload(models.PromptVersion.prompt))
|
|
295
|
+
.where(models.PromptVersion.id == id_)
|
|
296
|
+
)
|
|
297
|
+
prompt_version = await session.scalar(stmt)
|
|
292
298
|
if prompt_version is None:
|
|
293
|
-
raise HTTPException(
|
|
299
|
+
raise HTTPException(404)
|
|
294
300
|
data = _prompt_version_from_orm_version(prompt_version)
|
|
295
301
|
return GetPromptResponseBody(data=data)
|
|
296
302
|
|
|
@@ -304,8 +310,8 @@ async def get_prompt_version_by_prompt_version_id(
|
|
|
304
310
|
response_description="The prompt version with the specified tag",
|
|
305
311
|
responses=add_errors_to_responses(
|
|
306
312
|
[
|
|
307
|
-
|
|
308
|
-
|
|
313
|
+
404,
|
|
314
|
+
422,
|
|
309
315
|
]
|
|
310
316
|
),
|
|
311
317
|
response_model_by_alias=True,
|
|
@@ -334,9 +340,10 @@ async def get_prompt_version_by_tag_name(
|
|
|
334
340
|
try:
|
|
335
341
|
name = Identifier.model_validate(tag_name)
|
|
336
342
|
except ValidationError:
|
|
337
|
-
raise HTTPException(
|
|
343
|
+
raise HTTPException(422, "Invalid tag name")
|
|
338
344
|
stmt = (
|
|
339
345
|
select(models.PromptVersion)
|
|
346
|
+
.options(joinedload(models.PromptVersion.prompt))
|
|
340
347
|
.join_from(models.PromptVersion, models.PromptVersionTag)
|
|
341
348
|
.where(models.PromptVersionTag.name == name)
|
|
342
349
|
)
|
|
@@ -344,7 +351,7 @@ async def get_prompt_version_by_tag_name(
|
|
|
344
351
|
async with request.app.state.db() as session:
|
|
345
352
|
prompt_version: models.PromptVersion = await session.scalar(stmt)
|
|
346
353
|
if prompt_version is None:
|
|
347
|
-
raise HTTPException(
|
|
354
|
+
raise HTTPException(404)
|
|
348
355
|
data = _prompt_version_from_orm_version(prompt_version)
|
|
349
356
|
return GetPromptResponseBody(data=data)
|
|
350
357
|
|
|
@@ -357,8 +364,8 @@ async def get_prompt_version_by_tag_name(
|
|
|
357
364
|
response_description="The latest version of the specified prompt",
|
|
358
365
|
responses=add_errors_to_responses(
|
|
359
366
|
[
|
|
360
|
-
|
|
361
|
-
|
|
367
|
+
404,
|
|
368
|
+
422,
|
|
362
369
|
]
|
|
363
370
|
),
|
|
364
371
|
response_model_by_alias=True,
|
|
@@ -382,12 +389,17 @@ async def get_prompt_version_by_latest(
|
|
|
382
389
|
Raises:
|
|
383
390
|
HTTPException: If the prompt identifier is invalid or no prompt version is found.
|
|
384
391
|
"""
|
|
385
|
-
stmt =
|
|
392
|
+
stmt = (
|
|
393
|
+
select(models.PromptVersion)
|
|
394
|
+
.options(joinedload(models.PromptVersion.prompt))
|
|
395
|
+
.order_by(models.PromptVersion.id.desc())
|
|
396
|
+
.limit(1)
|
|
397
|
+
)
|
|
386
398
|
stmt = _filter_by_prompt_identifier(stmt.join(models.Prompt), prompt_identifier)
|
|
387
399
|
async with request.app.state.db() as session:
|
|
388
400
|
prompt_version: models.PromptVersion = await session.scalar(stmt)
|
|
389
401
|
if prompt_version is None:
|
|
390
|
-
raise HTTPException(
|
|
402
|
+
raise HTTPException(404)
|
|
391
403
|
data = _prompt_version_from_orm_version(prompt_version)
|
|
392
404
|
return GetPromptResponseBody(data=data)
|
|
393
405
|
|
|
@@ -401,7 +413,7 @@ async def get_prompt_version_by_latest(
|
|
|
401
413
|
response_description="The newly created prompt version",
|
|
402
414
|
responses=add_errors_to_responses(
|
|
403
415
|
[
|
|
404
|
-
|
|
416
|
+
422,
|
|
405
417
|
]
|
|
406
418
|
),
|
|
407
419
|
response_model_by_alias=True,
|
|
@@ -431,7 +443,7 @@ async def create_prompt(
|
|
|
431
443
|
or request_body.version.template_type != PromptTemplateType.CHAT
|
|
432
444
|
):
|
|
433
445
|
raise HTTPException(
|
|
434
|
-
|
|
446
|
+
422,
|
|
435
447
|
"Only CHAT template type is supported for prompts",
|
|
436
448
|
)
|
|
437
449
|
prompt = request_body.prompt
|
|
@@ -439,7 +451,7 @@ async def create_prompt(
|
|
|
439
451
|
name = Identifier.model_validate(prompt.name)
|
|
440
452
|
except ValidationError as e:
|
|
441
453
|
raise HTTPException(
|
|
442
|
-
|
|
454
|
+
422,
|
|
443
455
|
"Invalid name identifier for prompt: " + e.errors()[0]["msg"],
|
|
444
456
|
)
|
|
445
457
|
version = request_body.version
|
|
@@ -448,17 +460,15 @@ async def create_prompt(
|
|
|
448
460
|
assert isinstance(user := request.user, PhoenixUser)
|
|
449
461
|
user_id = int(user.identity)
|
|
450
462
|
async with request.app.state.db() as session:
|
|
451
|
-
if not (
|
|
463
|
+
if not (prompt_orm := await session.scalar(select(models.Prompt).filter_by(name=name))):
|
|
452
464
|
prompt_orm = models.Prompt(
|
|
453
465
|
name=name,
|
|
454
466
|
description=prompt.description,
|
|
467
|
+
metadata_=prompt.metadata or {},
|
|
455
468
|
)
|
|
456
|
-
session.add(prompt_orm)
|
|
457
|
-
await session.flush()
|
|
458
|
-
prompt_id = prompt_orm.id
|
|
459
469
|
version_orm = models.PromptVersion(
|
|
460
470
|
user_id=user_id,
|
|
461
|
-
|
|
471
|
+
prompt=prompt_orm,
|
|
462
472
|
description=version.description,
|
|
463
473
|
model_provider=version.model_provider,
|
|
464
474
|
model_name=version.model_name,
|
|
@@ -496,8 +506,8 @@ class GetPromptVersionTagsResponseBody(PaginatedResponseBody[PromptVersionTag]):
|
|
|
496
506
|
response_description="A list of tags associated with the prompt version",
|
|
497
507
|
responses=add_errors_to_responses(
|
|
498
508
|
[
|
|
499
|
-
|
|
500
|
-
|
|
509
|
+
404,
|
|
510
|
+
422,
|
|
501
511
|
]
|
|
502
512
|
),
|
|
503
513
|
response_model_by_alias=True,
|
|
@@ -537,7 +547,7 @@ async def list_prompt_version_tags(
|
|
|
537
547
|
PromptVersionNodeType.__name__,
|
|
538
548
|
)
|
|
539
549
|
except ValueError:
|
|
540
|
-
raise HTTPException(
|
|
550
|
+
raise HTTPException(422, "Invalid prompt version ID")
|
|
541
551
|
|
|
542
552
|
# Build the query for tags
|
|
543
553
|
stmt = (
|
|
@@ -560,7 +570,7 @@ async def list_prompt_version_tags(
|
|
|
560
570
|
except ValueError:
|
|
561
571
|
raise HTTPException(
|
|
562
572
|
detail=f"Invalid cursor format: {cursor}",
|
|
563
|
-
status_code=
|
|
573
|
+
status_code=422,
|
|
564
574
|
)
|
|
565
575
|
|
|
566
576
|
# Apply limit
|
|
@@ -571,7 +581,7 @@ async def list_prompt_version_tags(
|
|
|
571
581
|
|
|
572
582
|
# Check if prompt version exists
|
|
573
583
|
if not result:
|
|
574
|
-
raise HTTPException(
|
|
584
|
+
raise HTTPException(404, "Prompt version not found")
|
|
575
585
|
|
|
576
586
|
# Check if there are any tags
|
|
577
587
|
has_tags = any(id_ is not None for _, id_, _, _ in result)
|
|
@@ -610,11 +620,11 @@ async def list_prompt_version_tags(
|
|
|
610
620
|
description="Add a new tag to a specific prompt version. Tags help identify and categorize "
|
|
611
621
|
"different versions of a prompt.",
|
|
612
622
|
response_description="No content returned on successful tag creation",
|
|
613
|
-
status_code=
|
|
623
|
+
status_code=204,
|
|
614
624
|
responses=add_errors_to_responses(
|
|
615
625
|
[
|
|
616
|
-
|
|
617
|
-
|
|
626
|
+
404,
|
|
627
|
+
422,
|
|
618
628
|
]
|
|
619
629
|
),
|
|
620
630
|
response_model_by_alias=True,
|
|
@@ -647,7 +657,7 @@ async def create_prompt_version_tag(
|
|
|
647
657
|
PromptVersionNodeType.__name__,
|
|
648
658
|
)
|
|
649
659
|
except ValueError:
|
|
650
|
-
raise HTTPException(
|
|
660
|
+
raise HTTPException(422, "Invalid prompt version ID")
|
|
651
661
|
user_id: Optional[int] = None
|
|
652
662
|
if request.app.state.authentication_enabled:
|
|
653
663
|
assert isinstance(user := request.user, PhoenixUser)
|
|
@@ -655,7 +665,7 @@ async def create_prompt_version_tag(
|
|
|
655
665
|
async with request.app.state.db() as session:
|
|
656
666
|
prompt_id = await session.scalar(select(models.PromptVersion.prompt_id).filter_by(id=id_))
|
|
657
667
|
if prompt_id is None:
|
|
658
|
-
raise HTTPException(
|
|
668
|
+
raise HTTPException(404)
|
|
659
669
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
660
670
|
values = dict(
|
|
661
671
|
name=request_body.name,
|
|
@@ -686,7 +696,7 @@ def _parse_prompt_identifier(
|
|
|
686
696
|
prompt_identifier: str,
|
|
687
697
|
) -> _PromptIdentifier:
|
|
688
698
|
if not prompt_identifier:
|
|
689
|
-
raise HTTPException(
|
|
699
|
+
raise HTTPException(422, "Invalid prompt identifier")
|
|
690
700
|
try:
|
|
691
701
|
prompt_id = from_global_id_with_expected_type(
|
|
692
702
|
GlobalID.from_id(prompt_identifier),
|
|
@@ -696,7 +706,7 @@ def _parse_prompt_identifier(
|
|
|
696
706
|
try:
|
|
697
707
|
return Identifier.model_validate(prompt_identifier)
|
|
698
708
|
except ValidationError:
|
|
699
|
-
raise HTTPException(
|
|
709
|
+
raise HTTPException(422, "Invalid prompt name")
|
|
700
710
|
return _PromptId(prompt_id)
|
|
701
711
|
|
|
702
712
|
|
|
@@ -742,4 +752,5 @@ def _prompt_from_orm_prompt(orm_prompt: models.Prompt) -> Prompt:
|
|
|
742
752
|
source_prompt_id=source_prompt_id,
|
|
743
753
|
name=orm_prompt.name,
|
|
744
754
|
description=orm_prompt.description,
|
|
755
|
+
metadata=orm_prompt.metadata_,
|
|
745
756
|
)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
from sqlalchemy import select
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
13
|
+
from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
|
|
14
|
+
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
15
|
+
from phoenix.server.authorization import is_not_locked
|
|
16
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
17
|
+
|
|
18
|
+
from .annotations import SessionAnnotationData
|
|
19
|
+
from .utils import RequestBody, ResponseBody, add_errors_to_responses
|
|
20
|
+
|
|
21
|
+
# FastAPI router for the session endpoints defined in this module; the
# "sessions" tag groups them together in the generated OpenAPI document.
router = APIRouter(tags=["sessions"])
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class InsertedSessionAnnotation(V1RoutesBaseModel):
    # One entry of the annotate_sessions response: the stringified row ID of an
    # upserted session annotation. (Comments rather than a docstring, since
    # Pydantic would copy a class docstring into the OpenAPI schema.)
    id: str = Field(
        description="The ID of the inserted session annotation",
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
    # Request payload for annotate_sessions: ``data`` holds the list of session
    # annotations to write. Adds nothing beyond the generic RequestBody wrapper.
    ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
    # Response payload for annotate_sessions: ``data`` lists the inserted
    # annotation IDs (empty when the request was queued asynchronously).
    ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@router.post(
    "/session_annotations",
    dependencies=[Depends(is_not_locked)],
    operation_id="annotateSessions",
    summary="Create session annotations",
    responses=add_errors_to_responses([{"status_code": 404, "description": "Session not found"}]),
    response_description="Session annotations inserted successfully",
    include_in_schema=True,
)
async def annotate_sessions(
    request: Request,
    request_body: AnnotateSessionsRequestBody,
    sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
) -> AnnotateSessionsResponseBody:
    # Create or update annotations on project sessions.
    #
    # Documented with comments rather than a docstring on purpose: FastAPI
    # publishes an endpoint's docstring as its OpenAPI description.
    #
    # Annotations named "note" are dropped with a UserWarning.
    # - sync=False: the remaining annotations are passed to
    #   ``request.state.enqueue_annotations`` and an empty list is returned.
    # - sync=True: they are upserted immediately, keyed on
    #   (name, project_session_id, identifier), and the resulting row IDs are
    #   returned as strings.
    #
    # Raises HTTPException(404) in sync mode when any referenced session_id does
    # not exist; nothing is written in that case.
    if not request_body.data:
        return AnnotateSessionsResponseBody(data=[])

    # Attribute the annotations to the caller when authentication is enabled.
    user_id: Optional[int] = None
    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
        user_id = int(request.user.identity)

    session_annotations = request_body.data
    filtered_session_annotations = [d for d in session_annotations if d.name != "note"]
    if len(filtered_session_annotations) != len(session_annotations):
        warnings.warn(
            (
                "Session annotations with the name 'note' are not supported in this endpoint. "
                "They will be ignored."
            ),
            UserWarning,
        )
    precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
    if not sync:
        await request.state.enqueue_annotations(*precursors)
        return AnnotateSessionsResponseBody(data=[])

    if not precursors:
        # Everything was filtered out; there is nothing to validate or write.
        return AnnotateSessionsResponseBody(data=[])

    session_ids = {p.session_id for p in precursors}
    # Validate and insert inside one DB session so the existence check and the
    # writes share a single transaction instead of two independent ones.
    async with request.app.state.db() as session:
        existing_sessions = {
            session_id: rowid
            async for session_id, rowid in await session.stream(
                select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
                    models.ProjectSession.session_id.in_(session_ids)
                )
            )
        }

        missing_session_ids = session_ids.difference(existing_sessions)
        # We prefer to fail the entire operation if there are missing sessions in sync mode.
        # IDs are sorted so the error message is deterministic.
        if missing_session_ids:
            raise HTTPException(
                detail=f"Sessions with IDs {', '.join(sorted(missing_session_ids))} do not exist.",
                status_code=404,
            )

        dialect = SupportedSQLDialect(session.bind.dialect.name)
        inserted_ids = []
        for p in precursors:
            values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
            session_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.ProjectSessionAnnotation,
                    unique_by=("name", "project_session_id", "identifier"),
                ).returning(models.ProjectSessionAnnotation.id)
            )
            inserted_ids.append(session_annotation_id)

    return AnnotateSessionsResponseBody(
        data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
    )
|