arize-phoenix 7.12.3__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/METADATA +31 -28
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/RECORD +70 -47
- phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
- phoenix/db/models.py +307 -0
- phoenix/db/types/__init__.py +0 -0
- phoenix/db/types/identifier.py +7 -0
- phoenix/db/types/model_provider.py +8 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
- phoenix/server/api/helpers/jsonschema.py +135 -0
- phoenix/server/api/helpers/playground_clients.py +15 -15
- phoenix/server/api/helpers/playground_spans.py +9 -0
- phoenix/server/api/helpers/prompts/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
- phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
- phoenix/server/api/helpers/prompts/models.py +575 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
- phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
- phoenix/server/api/input_types/PromptVersionInput.py +133 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +18 -16
- phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
- phoenix/server/api/mutations/prompt_mutations.py +312 -0
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
- phoenix/server/api/mutations/user_mutations.py +7 -6
- phoenix/server/api/openapi/schema.py +1 -0
- phoenix/server/api/queries.py +84 -31
- phoenix/server/api/routers/oauth2.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/datasets.py +1 -1
- phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/routers/v1/models.py +45 -0
- phoenix/server/api/routers/v1/prompts.py +412 -0
- phoenix/server/api/routers/v1/spans.py +1 -1
- phoenix/server/api/routers/v1/traces.py +1 -1
- phoenix/server/api/routers/v1/utils.py +1 -1
- phoenix/server/api/subscriptions.py +21 -24
- phoenix/server/api/types/GenerativeProvider.py +4 -4
- phoenix/server/api/types/Identifier.py +15 -0
- phoenix/server/api/types/Project.py +5 -7
- phoenix/server/api/types/Prompt.py +134 -0
- phoenix/server/api/types/PromptLabel.py +41 -0
- phoenix/server/api/types/PromptVersion.py +148 -0
- phoenix/server/api/types/PromptVersionTag.py +27 -0
- phoenix/server/api/types/PromptVersionTemplate.py +148 -0
- phoenix/server/api/types/ResponseFormat.py +9 -0
- phoenix/server/api/types/ToolDefinition.py +9 -0
- phoenix/server/app.py +3 -0
- phoenix/server/static/.vite/manifest.json +45 -45
- phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
- phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
- phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
- phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
- phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
- phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
- phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
- phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
- phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
- phoenix/session/client.py +25 -21
- phoenix/utilities/client.py +6 -0
- phoenix/version.py +1 -1
- phoenix/server/api/input_types/TemplateOptions.py +0 -10
- phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
- phoenix/server/api/types/TemplateLanguage.py +0 -10
- phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
- phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
- phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
- phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
phoenix/server/api/routers/v1/__init__.py
@@ -9,6 +9,7 @@ from .evaluations import router as evaluations_router
 from .experiment_evaluations import router as experiment_evaluations_router
 from .experiment_runs import router as experiment_runs_router
 from .experiments import router as experiments_router
+from .prompts import router as prompts_router
 from .spans import router as spans_router
 from .traces import router as traces_router
 from .utils import add_errors_to_responses
@@ -61,4 +62,5 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
     router.include_router(traces_router)
     router.include_router(spans_router)
     router.include_router(evaluations_router)
+    router.include_router(prompts_router)
     return router

phoenix/server/api/routers/v1/datasets.py
@@ -48,7 +48,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.utils import delete_projects, delete_traces
 from phoenix.server.dml_event import DatasetInsertEvent
 
-from .
+from .models import V1RoutesBaseModel
 from .utils import (
     PaginatedResponseBody,
     ResponseBody,

phoenix/server/api/routers/v1/experiment_evaluations.py
@@ -13,7 +13,7 @@ from phoenix.db.insertion.helpers import insert_on_conflict
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.dml_event import ExperimentRunAnnotationInsertEvent
 
-from .
+from .models import V1RoutesBaseModel
 from .utils import ResponseBody, add_errors_to_responses
 
 router = APIRouter(tags=["experiments"], include_in_schema=False)

phoenix/server/api/routers/v1/experiment_runs.py
@@ -13,7 +13,7 @@ from phoenix.db.models import ExperimentRunOutput
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.dml_event import ExperimentRunInsertEvent
 
-from .
+from .models import V1RoutesBaseModel
 from .utils import ResponseBody, add_errors_to_responses
 
 router = APIRouter(tags=["experiments"], include_in_schema=False)

phoenix/server/api/routers/v1/experiments.py
@@ -15,7 +15,7 @@ from phoenix.db.insertion.helpers import insert_on_conflict
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.dml_event import ExperimentInsertEvent
 
-from .
+from .models import V1RoutesBaseModel
 from .utils import ResponseBody, add_errors_to_responses
 
 router = APIRouter(tags=["experiments"], include_in_schema=True)

phoenix/server/api/routers/v1/models.py (new file)
@@ -0,0 +1,45 @@
+from datetime import datetime
+
+from pydantic import BaseModel, ConfigDict
+
+
+def datetime_encoder(dt: datetime) -> str:
+    """
+    Encodes a `datetime` object to an ISO-formatted timestamp string.
+
+    By default, Pydantic v2 serializes `datetime` objects in a format that
+    cannot be parsed by `datetime.fromisoformat`. Adding this encoder to the
+    `json_encoders` config for a Pydantic model ensures that the serialized
+    `datetime` objects are parseable.
+    """
+    return dt.isoformat()
+
+
+# `json_encoders` is a configuration setting from Pydantic v1 that was
+# removed in Pydantic v2.0.* but restored in Pydantic v2.1.0 with a
+# deprecation warning. At this time, it remains the simplest way to
+# configure custom JSON serialization for specific data types.
+#
+# For details, see:
+# - https://github.com/pydantic/pydantic/pull/6811
+# - https://github.com/pydantic/pydantic/releases/tag/v2.1.0
+#
+# The assertion below is added in case a future release of Pydantic v2 fully
+# removes the `json_encoders` parameter.
+assert "json_encoders" in ConfigDict.__annotations__, (
+    "If you encounter this error with `pydantic<2.1.0`, "
+    "please upgrade `pydantic` with `pip install -U pydantic>=2.1.0`. "
+    "If you encounter this error with `pydantic>=2.1.0`, "
+    "please upgrade `arize-phoenix` with `pip install -U arize-phoenix`, "
+    "or downgrade `pydantic` to a version that supports the `json_encoders` config setting."
+)
+
+
+class V1RoutesBaseModel(BaseModel):
+    model_config = ConfigDict(
+        json_encoders={datetime: datetime_encoder},
+        validate_assignment=True,
+        protected_namespaces=tuple(
+            []
+        ), # suppress warnings about protected namespaces starting with `model_` on pydantic 2.9
+    )

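For context on the new models.py above, here is a minimal sketch of what the `json_encoders` configuration buys the v1 routes, assuming `arize-phoenix==8.0.0` and `pydantic>=2.1` are installed; `ExampleBody` is a hypothetical model used only for illustration and is not part of the package.

import json
from datetime import datetime, timezone

from phoenix.server.api.routers.v1.models import V1RoutesBaseModel


class ExampleBody(V1RoutesBaseModel):  # hypothetical model, illustration only
    created_at: datetime


body = ExampleBody(created_at=datetime.now(timezone.utc))
serialized = body.model_dump_json()  # datetime rendered via datetime_encoder, i.e. dt.isoformat()
parsed = datetime.fromisoformat(json.loads(serialized)["created_at"])  # round-trips cleanly

The `protected_namespaces` override likewise silences the `model_`-prefix warnings that fields such as `model_name` and `model_provider` in the new prompt schemas would otherwise trigger.
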
phoenix/server/api/routers/v1/prompts.py (new file)
@@ -0,0 +1,412 @@
+import logging
+from typing import Any, Optional, Union
+
+from fastapi import APIRouter, HTTPException, Path, Query
+from pydantic import ValidationError
+from sqlalchemy import select
+from sqlalchemy.sql import Select
+from starlette.requests import Request
+from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
+from strawberry.relay import GlobalID
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.db import models
+from phoenix.db.types.identifier import Identifier
+from phoenix.db.types.model_provider import ModelProvider
+from phoenix.server.api.helpers.prompts.models import (
+    PromptInvocationParameters,
+    PromptResponseFormat,
+    PromptTemplate,
+    PromptTemplateFormat,
+    PromptTemplateType,
+    PromptTools,
+)
+from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
+from phoenix.server.api.routers.v1.utils import ResponseBody, add_errors_to_responses
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.Prompt import Prompt as PromptNodeType
+from phoenix.server.api.types.PromptVersion import PromptVersion as PromptVersionNodeType
+from phoenix.server.bearer_auth import PhoenixUser
+
+logger = logging.getLogger(__name__)
+
+
+class PromptData(V1RoutesBaseModel):
+    name: Identifier
+    description: Optional[str] = None
+    source_prompt_id: Optional[str] = None
+
+
+class Prompt(PromptData):
+    id: str
+
+
+class PromptVersionData(V1RoutesBaseModel):
+    description: Optional[str] = None
+    model_provider: ModelProvider
+    model_name: str
+    template: PromptTemplate
+    template_type: PromptTemplateType
+    template_format: PromptTemplateFormat
+    invocation_parameters: PromptInvocationParameters
+    tools: Optional[PromptTools] = None
+    response_format: Optional[PromptResponseFormat] = None
+
+
+class PromptVersion(PromptVersionData):
+    id: str
+
+
+class GetPromptResponseBody(ResponseBody[PromptVersion]):
+    pass
+
+
+class GetPromptsResponseBody(ResponseBody[list[Prompt]]):
+    pass
+
+
+class GetPromptVersionsResponseBody(ResponseBody[list[PromptVersion]]):
+    pass
+
+
+class CreatePromptRequestBody(V1RoutesBaseModel):
+    prompt: PromptData
+    version: PromptVersionData
+
+
+class CreatePromptResponseBody(ResponseBody[PromptVersion]):
+    pass
+
+
+router = APIRouter(tags=["prompts"])
+
+
+@router.get(
+    "/prompts",
+    operation_id="getPrompts",
+    summary="Get all prompts",
+    responses=add_errors_to_responses(
+        [
+            HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+)
+async def get_prompts(
+    request: Request,
+    cursor: Optional[str] = Query(
+        default=None,
+        description="Cursor for pagination (base64-encoded prompt ID)",
+    ),
+    limit: int = Query(
+        default=100, description="The max number of prompts to return at a time.", gt=0
+    ),
+) -> GetPromptsResponseBody:
+    async with request.app.state.db() as session:
+        query = select(models.Prompt).order_by(models.Prompt.id.desc())
+
+        if cursor:
+            try:
+                cursor_id = GlobalID.from_id(cursor).node_id
+                query = query.filter(models.Prompt.id <= int(cursor_id))
+            except ValueError:
+                raise HTTPException(
+                    detail=f"Invalid cursor format: {cursor}",
+                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                )
+
+        query = query.limit(limit + 1)
+        result = await session.execute(query)
+        orm_prompts = result.scalars().all()
+
+        if not orm_prompts:
+            return GetPromptsResponseBody(next_cursor=None, data=[])
+
+        next_cursor = None
+        if len(orm_prompts) == limit + 1:
+            last_prompt = orm_prompts[-1]
+            next_cursor = str(GlobalID(PromptNodeType.__name__, str(last_prompt.id)))
+            orm_prompts = orm_prompts[:-1]
+
+        prompts = [_prompt_from_orm_prompt(orm_prompt) for orm_prompt in orm_prompts]
+        return GetPromptsResponseBody(next_cursor=next_cursor, data=prompts)
+
+
+@router.get(
+    "/prompts/{prompt_identifier}/versions",
+    operation_id="listPromptVersions",
+    summary="List all prompt versions for a given prompt",
+    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+    response_model_by_alias=True,
+    response_model_exclude_defaults=True,
+    response_model_exclude_unset=True,
+)
+async def list_prompt_versions(
+    request: Request,
+    prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
+    cursor: Optional[str] = Query(
+        default=None,
+        description="Cursor for pagination (base64-encoded promptVersion ID)",
+    ),
+    limit: int = Query(
+        default=100, description="The max number of prompt versions to return at a time.", gt=0
+    ),
+) -> GetPromptVersionsResponseBody:
+    query = select(models.PromptVersion)
+    query = _filter_by_prompt_identifier(query.join(models.Prompt), prompt_identifier)
+    query = query.order_by(models.PromptVersion.id.desc())
+
+    async with request.app.state.db() as session:
+        if cursor:
+            try:
+                cursor_id = GlobalID.from_id(cursor).node_id
+                query = query.filter(models.PromptVersion.id <= int(cursor_id))
+            except ValueError:
+                raise HTTPException(
+                    detail=f"Invalid cursor format: {cursor}",
+                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                )
+
+        query = query.limit(limit + 1)
+        result = await session.execute(query)
+        orm_versions = result.scalars().all()
+
+        if not orm_versions:
+            return GetPromptVersionsResponseBody(next_cursor=None, data=[])
+
+        next_cursor = None
+        if len(orm_versions) == limit + 1:
+            last_version = orm_versions[-1]
+            next_cursor = str(GlobalID(PromptVersionNodeType.__name__, str(last_version.id)))
+            orm_versions = orm_versions[:-1]
+
+        versions = [_prompt_version_from_orm_version(orm_version) for orm_version in orm_versions]
+        return GetPromptVersionsResponseBody(next_cursor=next_cursor, data=versions)
+
+
+@router.get(
+    "/prompt_versions/{prompt_version_id}",
+    operation_id="getPromptVersionByPromptVersionId",
+    summary="Get prompt by prompt version ID",
+    responses=add_errors_to_responses(
+        [
+            HTTP_404_NOT_FOUND,
+            HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+    response_model_by_alias=True,
+    response_model_exclude_defaults=True,
+    response_model_exclude_unset=True,
+)
+async def get_prompt_version_by_prompt_version_id(
+    request: Request,
+    prompt_version_id: str = Path(description="The ID of the prompt version."),
+) -> GetPromptResponseBody:
+    try:
+        id_ = from_global_id_with_expected_type(
+            GlobalID.from_id(prompt_version_id),
+            PromptVersionNodeType.__name__,
+        )
+    except ValueError:
+        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt version ID")
+    async with request.app.state.db() as session:
+        prompt_version = await session.get(models.PromptVersion, id_)
+        if prompt_version is None:
+            raise HTTPException(HTTP_404_NOT_FOUND)
+    data = _prompt_version_from_orm_version(prompt_version)
+    return GetPromptResponseBody(data=data)
+
+
+@router.get(
+    "/prompts/{prompt_identifier}/tags/{tag_name}",
+    operation_id="getPromptVersionByTagName",
+    summary="Get prompt by tag name",
+    responses=add_errors_to_responses(
+        [
+            HTTP_404_NOT_FOUND,
+            HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+    response_model_by_alias=True,
+    response_model_exclude_unset=True,
+    response_model_exclude_defaults=True,
+)
+async def get_prompt_version_by_tag_name(
+    request: Request,
+    prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
+    tag_name: str = Path(description="The tag of the prompt version"),
+) -> GetPromptResponseBody:
+    try:
+        name = Identifier.model_validate(tag_name)
+    except ValidationError:
+        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid tag name")
+    stmt = (
+        select(models.PromptVersion)
+        .join_from(models.PromptVersion, models.PromptVersionTag)
+        .where(models.PromptVersionTag.name == name)
+    )
+    stmt = _filter_by_prompt_identifier(stmt.join(models.Prompt), prompt_identifier)
+    async with request.app.state.db() as session:
+        prompt_version: models.PromptVersion = await session.scalar(stmt)
+        if prompt_version is None:
+            raise HTTPException(HTTP_404_NOT_FOUND)
+    data = _prompt_version_from_orm_version(prompt_version)
+    return GetPromptResponseBody(data=data)
+
+
+@router.get(
+    "/prompts/{prompt_identifier}/latest",
+    operation_id="getPromptVersionLatest",
+    summary="Get the latest prompt version",
+    responses=add_errors_to_responses(
+        [
+            HTTP_404_NOT_FOUND,
+            HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+    response_model_by_alias=True,
+    response_model_exclude_defaults=True,
+    response_model_exclude_unset=True,
+)
+async def get_prompt_version_by_latest(
+    request: Request,
+    prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
+) -> GetPromptResponseBody:
+    stmt = select(models.PromptVersion).order_by(models.PromptVersion.id.desc()).limit(1)
+    stmt = _filter_by_prompt_identifier(stmt.join(models.Prompt), prompt_identifier)
+    async with request.app.state.db() as session:
+        prompt_version: models.PromptVersion = await session.scalar(stmt)
+        if prompt_version is None:
+            raise HTTPException(HTTP_404_NOT_FOUND)
+    data = _prompt_version_from_orm_version(prompt_version)
+    return GetPromptResponseBody(data=data)
+
+
+@router.post(
+    "/prompts",
+    operation_id="postPromptVersion",
+    summary="Create a prompt version",
+    responses=add_errors_to_responses(
+        [
+            HTTP_422_UNPROCESSABLE_ENTITY,
+        ]
+    ),
+    response_model_by_alias=True,
+    response_model_exclude_defaults=True,
+    response_model_exclude_unset=True,
+)
+async def create_prompt(
+    request: Request,
+    request_body: CreatePromptRequestBody,
+) -> CreatePromptResponseBody:
+    if request_body.version.template.type.lower() != "chat":
+        raise HTTPException(
+            HTTP_422_UNPROCESSABLE_ENTITY,
+            "Only CHAT template type is supported for prompts",
+        )
+    prompt = request_body.prompt
+    try:
+        name = Identifier.model_validate(prompt.name)
+    except ValidationError as e:
+        raise HTTPException(
+            HTTP_422_UNPROCESSABLE_ENTITY,
+            "Invalid name identifier for prompt: " + e.errors()[0]["msg"],
+        )
+    version = request_body.version
+    user_id: Optional[int] = None
+    if request.app.state.authentication_enabled:
+        assert isinstance(user := request.user, PhoenixUser)
+        user_id = int(user.identity)
+    async with request.app.state.db() as session:
+        if not (prompt_id := await session.scalar(select(models.Prompt.id).filter_by(name=name))):
+            prompt_orm = models.Prompt(
+                name=name,
+                description=prompt.description,
+            )
+            session.add(prompt_orm)
+            await session.flush()
+            prompt_id = prompt_orm.id
+        version_orm = models.PromptVersion(
+            user_id=user_id,
+            prompt_id=prompt_id,
+            description=version.description,
+            model_provider=version.model_provider,
+            model_name=version.model_name,
+            template_type=version.template_type,
+            template_format=version.template_format,
+            template=version.template,
+            invocation_parameters=version.invocation_parameters,
+            tools=version.tools,
+            response_format=version.response_format,
+        )
+        session.add(version_orm)
+    data = _prompt_version_from_orm_version(version_orm)
+    return CreatePromptResponseBody(data=data)
+
+
+class _PromptId(int): ...
+
+
+_PromptIdentifier: TypeAlias = Union[_PromptId, Identifier]
+
+
+def _parse_prompt_identifier(
+    prompt_identifier: str,
+) -> _PromptIdentifier:
+    if not prompt_identifier:
+        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt identifier")
+    try:
+        prompt_id = from_global_id_with_expected_type(
+            GlobalID.from_id(prompt_identifier),
+            PromptNodeType.__name__,
+        )
+    except ValueError:
+        try:
+            return Identifier.model_validate(prompt_identifier)
+        except ValidationError:
+            raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt name")
+    return _PromptId(prompt_id)
+
+
+def _filter_by_prompt_identifier(
+    stmt: Select[tuple[models.PromptVersion]],
+    prompt_identifier: str,
+) -> Any:
+    identifier = _parse_prompt_identifier(prompt_identifier)
+    if isinstance(identifier, _PromptId):
+        return stmt.where(models.Prompt.id == int(identifier))
+    if isinstance(identifier, Identifier):
+        return stmt.where(models.Prompt.name == identifier)
+    assert_never(identifier)
+
+
+def _prompt_version_from_orm_version(
+    prompt_version: models.PromptVersion,
+) -> PromptVersion:
+    prompt_template_type = PromptTemplateType(prompt_version.template_type)
+    prompt_template_format = PromptTemplateFormat(prompt_version.template_format)
+    return PromptVersion(
+        id=str(GlobalID(PromptVersionNodeType.__name__, str(prompt_version.id))),
+        description=prompt_version.description or "",
+        model_provider=prompt_version.model_provider,
+        model_name=prompt_version.model_name,
+        template=prompt_version.template,
+        template_type=prompt_template_type,
+        template_format=prompt_template_format,
+        invocation_parameters=prompt_version.invocation_parameters,
+        tools=prompt_version.tools,
+        response_format=prompt_version.response_format,
+    )
+
+
+def _prompt_from_orm_prompt(orm_prompt: models.Prompt) -> Prompt:
+    source_prompt_id = (
+        str(GlobalID(PromptNodeType.__name__, str(orm_prompt.source_prompt_id)))
+        if orm_prompt.source_prompt_id
+        else None
+    )
+    return Prompt(
+        id=str(GlobalID(PromptNodeType.__name__, str(orm_prompt.id))),
+        source_prompt_id=source_prompt_id,
+        name=orm_prompt.name,
+        description=orm_prompt.description,
+    )

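The new prompts router above exposes the prompt registry over plain REST. Below is a hedged sketch of exercising those endpoints with httpx, assuming the v1 routes of a local Phoenix server are reachable at http://localhost:6006/v1 and that a prompt named "my-prompt" with a "production" version tag already exists; the base URL, prompt name, and tag name are illustrative assumptions, not values from the diff.

import httpx

BASE_URL = "http://localhost:6006/v1"  # assumed local server; adjust host/port/prefix as needed

with httpx.Client(base_url=BASE_URL) as client:
    # Cursor-paginated listing, mirroring the `cursor`/`limit` query parameters above.
    page = client.get("/prompts", params={"limit": 10}).json()
    prompts, next_cursor = page["data"], page.get("next_cursor")

    # A prompt can be addressed by its name or by its GlobalID.
    identifier = "my-prompt"
    latest = client.get(f"/prompts/{identifier}/latest").json()["data"]
    versions = client.get(f"/prompts/{identifier}/versions").json()["data"]
    tagged = client.get(f"/prompts/{identifier}/tags/production").json()["data"]

Creating a prompt goes through POST /prompts with a body matching CreatePromptRequestBody (a prompt plus a version payload); per the check in create_prompt, only chat templates are accepted.
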
phoenix/server/api/routers/v1/spans.py
@@ -24,7 +24,7 @@ from phoenix.server.dml_event import SpanAnnotationInsertEvent
 from phoenix.trace.dsl import SpanQuery as SpanQuery_
 from phoenix.utilities.json import encode_df_as_json_string
 
-from .
+from .models import V1RoutesBaseModel
 from .utils import RequestBody, ResponseBody, add_errors_to_responses
 
 DEFAULT_SPAN_LIMIT = 1000

phoenix/server/api/routers/v1/traces.py
@@ -30,7 +30,7 @@ from phoenix.server.dml_event import TraceAnnotationInsertEvent
 from phoenix.trace.otel import decode_otlp_span
 from phoenix.utilities.project import get_project_name
 
-from .
+from .models import V1RoutesBaseModel
 from .utils import RequestBody, ResponseBody, add_errors_to_responses
 
 router = APIRouter(tags=["traces"])

phoenix/server/api/routers/v1/utils.py
@@ -2,7 +2,7 @@ from typing import Any, Generic, Optional, TypedDict, TypeVar, Union
 
 from typing_extensions import TypeAlias, assert_never
 
-from .
+from .models import V1RoutesBaseModel
 
 StatusCode: TypeAlias = int
 DataType = TypeVar("DataType")

phoenix/server/api/subscriptions.py
@@ -41,6 +41,7 @@ from phoenix.server.api.helpers.playground_spans import (
     get_db_trace,
     streaming_llm_span,
 )
+from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
 from phoenix.server.api.input_types.ChatCompletionInput import (
     ChatCompletionInput,
     ChatCompletionOverDatasetInput,
@@ -59,7 +60,6 @@ from phoenix.server.api.types.Experiment import to_gql_experiment
 from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.Span import to_gql_span
-from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
 from phoenix.server.dml_event import SpanInsertEvent
 from phoenix.server.types import DbSessionFactory
 from phoenix.utilities.template_formatters import (
@@ -124,7 +124,7 @@ class Subscription:
         messages = list(
             _formatted_messages(
                 messages=messages,
-
+                template_format=template_options.format,
                 template_variables=template_options.variables,
            )
        )
@@ -198,9 +198,7 @@ class Subscription:
         )
         async with info.context.db() as session:
             if (
-
-                    select(models.Dataset).where(models.Dataset.id == dataset_id)
-                )
+                await session.scalar(select(models.Dataset).where(models.Dataset.id == dataset_id))
             ) is None:
                 raise NotFound(f"Could not find dataset with ID {dataset_id}")
             if version_id is None:
@@ -274,9 +272,9 @@ class Subscription:
         experiment = models.Experiment(
             dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
             dataset_version_id=resolved_version_id,
-            name=input.experiment_name
-
-
+            name=input.experiment_name
+            or _default_playground_experiment_name(input.prompt_name),
+            description=input.experiment_description,
             repetitions=1,
             metadata_=input.experiment_metadata or dict(),
             project_name=PLAYGROUND_PROJECT_NAME,
@@ -394,7 +392,7 @@ async def _stream_chat_completion_over_dataset_example(
     messages = list(
         _formatted_messages(
             messages=messages,
-
+            template_format=input.template_format,
             template_variables=revision.input,
         )
     )
@@ -472,7 +470,7 @@ def _is_result_payloads_stream(
     Checks if the given generator was instantiated from
     `_chat_completion_result_payloads`
     """
-    return stream.ag_code == _chat_completion_result_payloads.__code__
+    return stream.ag_code == _chat_completion_result_payloads.__code__ # type: ignore
 
 
 def _create_task_with_timeout(
@@ -534,13 +532,13 @@ async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
 def _formatted_messages(
     *,
     messages: Iterable[ChatCompletionMessage],
-
+    template_format: PromptTemplateFormat,
     template_variables: Mapping[str, Any],
 ) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
-    template_formatter = _template_formatter(
+    template_formatter = _template_formatter(template_format=template_format)
     (
         roles,
         templates,
@@ -555,25 +553,24 @@ def _formatted_messages(
     return formatted_messages
 
 
-def _template_formatter(
+def _template_formatter(template_format: PromptTemplateFormat) -> TemplateFormatter:
     """
-    Instantiates the appropriate template formatter for the template
+    Instantiates the appropriate template formatter for the template format
     """
-    if
+    if template_format is PromptTemplateFormat.MUSTACHE:
         return MustacheTemplateFormatter()
-    if
+    if template_format is PromptTemplateFormat.F_STRING:
         return FStringTemplateFormatter()
-    if
+    if template_format is PromptTemplateFormat.NONE:
         return NoOpFormatter()
-    assert_never(
-
-
-def _default_playground_experiment_name() -> str:
-    return "playground-experiment"
+    assert_never(template_format)
 
 
-def
-
+def _default_playground_experiment_name(prompt_name: Optional[str] = None) -> str:
+    name = "playground-experiment"
+    if prompt_name:
+        name = f"{name} prompt:{prompt_name}"
+    return name
 
 
 LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES

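For reference, the reworked default experiment naming above now folds the prompt name into the playground experiment name. A standalone restatement of that behavior (illustration only; the helper is private to subscriptions.py, and "qa-bot" is an arbitrary example name):

from typing import Optional


def default_playground_experiment_name(prompt_name: Optional[str] = None) -> str:
    # Mirrors the added helper: append the prompt name when one is supplied.
    name = "playground-experiment"
    if prompt_name:
        name = f"{name} prompt:{prompt_name}"
    return name


assert default_playground_experiment_name() == "playground-experiment"
assert default_playground_experiment_name("qa-bot") == "playground-experiment prompt:qa-bot"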