arize-phoenix 7.12.3__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/METADATA +31 -28
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/RECORD +70 -47
- phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
- phoenix/db/models.py +307 -0
- phoenix/db/types/__init__.py +0 -0
- phoenix/db/types/identifier.py +7 -0
- phoenix/db/types/model_provider.py +8 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
- phoenix/server/api/helpers/jsonschema.py +135 -0
- phoenix/server/api/helpers/playground_clients.py +15 -15
- phoenix/server/api/helpers/playground_spans.py +9 -0
- phoenix/server/api/helpers/prompts/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
- phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
- phoenix/server/api/helpers/prompts/models.py +575 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
- phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
- phoenix/server/api/input_types/PromptVersionInput.py +133 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +18 -16
- phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
- phoenix/server/api/mutations/prompt_mutations.py +312 -0
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
- phoenix/server/api/mutations/user_mutations.py +7 -6
- phoenix/server/api/openapi/schema.py +1 -0
- phoenix/server/api/queries.py +84 -31
- phoenix/server/api/routers/oauth2.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/datasets.py +1 -1
- phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/routers/v1/models.py +45 -0
- phoenix/server/api/routers/v1/prompts.py +412 -0
- phoenix/server/api/routers/v1/spans.py +1 -1
- phoenix/server/api/routers/v1/traces.py +1 -1
- phoenix/server/api/routers/v1/utils.py +1 -1
- phoenix/server/api/subscriptions.py +21 -24
- phoenix/server/api/types/GenerativeProvider.py +4 -4
- phoenix/server/api/types/Identifier.py +15 -0
- phoenix/server/api/types/Project.py +5 -7
- phoenix/server/api/types/Prompt.py +134 -0
- phoenix/server/api/types/PromptLabel.py +41 -0
- phoenix/server/api/types/PromptVersion.py +148 -0
- phoenix/server/api/types/PromptVersionTag.py +27 -0
- phoenix/server/api/types/PromptVersionTemplate.py +148 -0
- phoenix/server/api/types/ResponseFormat.py +9 -0
- phoenix/server/api/types/ToolDefinition.py +9 -0
- phoenix/server/app.py +3 -0
- phoenix/server/static/.vite/manifest.json +45 -45
- phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
- phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
- phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
- phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
- phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
- phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
- phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
- phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
- phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
- phoenix/session/client.py +25 -21
- phoenix/utilities/client.py +6 -0
- phoenix/version.py +1 -1
- phoenix/server/api/input_types/TemplateOptions.py +0 -10
- phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
- phoenix/server/api/types/TemplateLanguage.py +0 -10
- phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
- phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
- phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
- phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
from typing import Any, Optional, Union, cast
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from fastapi import Request
|
|
5
|
+
from pydantic import ValidationError
|
|
6
|
+
from sqlalchemy import delete, select, update
|
|
7
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
8
|
+
from sqlalchemy.orm import joinedload
|
|
9
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
10
|
+
from strawberry.relay.types import GlobalID
|
|
11
|
+
from strawberry.types import Info
|
|
12
|
+
|
|
13
|
+
from phoenix.db import models
|
|
14
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
15
|
+
from phoenix.db.types.model_provider import ModelProvider
|
|
16
|
+
from phoenix.server.api.context import Context
|
|
17
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
18
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
19
|
+
normalize_response_format,
|
|
20
|
+
normalize_tools,
|
|
21
|
+
validate_invocation_parameters,
|
|
22
|
+
)
|
|
23
|
+
from phoenix.server.api.input_types.PromptVersionInput import (
|
|
24
|
+
ChatPromptVersionInput,
|
|
25
|
+
to_pydantic_prompt_chat_template_v1,
|
|
26
|
+
)
|
|
27
|
+
from phoenix.server.api.mutations.prompt_version_tag_mutations import (
|
|
28
|
+
SetPromptVersionTagInput,
|
|
29
|
+
upsert_prompt_version_tag,
|
|
30
|
+
)
|
|
31
|
+
from phoenix.server.api.queries import Query
|
|
32
|
+
from phoenix.server.api.types.Identifier import Identifier
|
|
33
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
34
|
+
from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
|
|
35
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# GraphQL input for creating a new prompt together with its first chat prompt version.
@strawberry.input
class CreateChatPromptInput:
    # Name of the new prompt; duplicates are rejected with Conflict by the mutation.
    name: Identifier
    # Optional human-readable description of the prompt.
    description: Optional[str] = None
    # The initial chat prompt version to store under the new prompt.
    prompt_version: ChatPromptVersionInput
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# GraphQL input for appending a new chat prompt version to an existing prompt.
@strawberry.input
class CreateChatPromptVersionInput:
    # Global ID of the prompt that receives the new version.
    prompt_id: GlobalID
    # Contents of the new chat prompt version.
    prompt_version: ChatPromptVersionInput
    # Optional tags to upsert onto the newly created version.
    tags: Optional[list[SetPromptVersionTagInput]] = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# GraphQL input identifying the prompt to delete.
@strawberry.input
class DeletePromptInput:
    # Global ID of the prompt to remove.
    prompt_id: GlobalID
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# GraphQL input for copying an existing prompt (and all of its versions) under a new name.
@strawberry.input
class ClonePromptInput:
    # Name given to the cloned prompt.
    name: Identifier
    # Optional description for the cloned prompt.
    description: Optional[str] = None
    # Global ID of the prompt being cloned.
    prompt_id: GlobalID
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# GraphQL input for updating a prompt's description in place.
@strawberry.input
class PatchPromptInput:
    # Global ID of the prompt to update.
    prompt_id: GlobalID
    # Replacement description for the prompt.
    description: str
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Result of the delete_prompt mutation.
@strawberry.type
class DeletePromptMutationPayload:
    # Root Query object returned after the deletion.
    query: Query
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _get_user_id(info: Info[Context, None]) -> Optional[int]:
    """Return the id of the user attached to the current request, if any.

    Returns None when the request scope carries no ``"user"`` entry.
    """
    assert isinstance(request := info.context.request, Request)
    if "user" not in request.scope:
        return None
    assert isinstance(user := request.user, PhoenixUser)
    return int(user.identity)


def _to_orm_chat_prompt_version(
    input_prompt_version: ChatPromptVersionInput,
    user_id: Optional[int],
) -> models.PromptVersion:
    """Validate a chat prompt version input and build an unsaved ``PromptVersion``.

    ``tool_choice`` is removed from the invocation parameters and passed to
    ``normalize_tools`` instead. The returned object has no ``prompt_id``; the
    caller attaches it to a prompt and adds it to a session.

    Raises:
        BadRequest: if the tools, template, response format, or invocation
            parameters fail pydantic validation.
    """
    # Work on a copy so that removing ``tool_choice`` does not mutate the GraphQL input.
    invocation_parameters_input = dict(
        cast(dict[str, Any], input_prompt_version.invocation_parameters)
    )
    tool_choice = cast(
        Optional[Union[str, dict[str, Any]]],
        invocation_parameters_input.pop("tool_choice", None),
    )
    tool_definitions = [tool.definition for tool in input_prompt_version.tools]
    model_provider = ModelProvider(input_prompt_version.model_provider)
    try:
        tools = (
            normalize_tools(tool_definitions, model_provider, tool_choice)
            if tool_definitions
            else None
        )
        template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
        response_format = (
            normalize_response_format(
                input_prompt_version.response_format.definition,
                model_provider,
            )
            if input_prompt_version.response_format
            else None
        )
        invocation_parameters = validate_invocation_parameters(
            invocation_parameters_input,
            model_provider,
        )
    except ValidationError as error:
        raise BadRequest(str(error))
    return models.PromptVersion(
        description=input_prompt_version.description,
        user_id=user_id,
        template_type="CHAT",
        template_format=input_prompt_version.template_format,
        template=template,
        invocation_parameters=invocation_parameters,
        tools=tools,
        response_format=response_format,
        model_provider=input_prompt_version.model_provider,
        model_name=input_prompt_version.model_name,
    )


@strawberry.type
class PromptMutationMixin:
    @strawberry.mutation
    async def create_chat_prompt(
        self, info: Info[Context, None], input: CreateChatPromptInput
    ) -> Prompt:
        # Create a prompt together with its first chat prompt version.
        # Raises BadRequest for invalid version input and Conflict when the commit
        # violates an integrity constraint (e.g. the prompt name is already taken).
        user_id = _get_user_id(info)
        prompt_version = _to_orm_chat_prompt_version(input.prompt_version, user_id)
        async with info.context.db() as session:
            name = IdentifierModel.model_validate(str(input.name))
            prompt = models.Prompt(
                name=name,
                description=input.description,
                prompt_versions=[prompt_version],
            )
            session.add(prompt)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A prompt named '{input.name}' already exists")
        return to_gql_prompt_from_orm(prompt)

    @strawberry.mutation
    async def create_chat_prompt_version(
        self,
        info: Info[Context, None],
        input: CreateChatPromptVersionInput,
    ) -> Prompt:
        # Append a chat prompt version to an existing prompt and upsert any requested
        # tags onto it. Raises BadRequest for invalid version input, NotFound when the
        # prompt does not exist, and Conflict on an integrity violation.
        user_id = _get_user_id(info)
        prompt_version = _to_orm_chat_prompt_version(input.prompt_version, user_id)
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )
        async with info.context.db() as session:
            prompt = await session.get(models.Prompt, prompt_id)
            if not prompt:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
            prompt_version.prompt_id = prompt_id
            session.add(prompt_version)
            try:
                # Flush so prompt_version.id is assigned before tags reference it.
                await session.flush()
                for tag in input.tags or ():
                    await upsert_prompt_version_tag(
                        session, prompt_id, prompt_version.id, tag.name, tag.description
                    )
                # Commit explicitly (as the sibling mutations do) so integrity errors
                # surface here as Conflict rather than as raw database errors.
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(
                    f"Failed to create a new version for prompt '{input.prompt_id}'"
                )
        return to_gql_prompt_from_orm(prompt)

    @strawberry.mutation
    async def delete_prompt(
        self, info: Info[Context, None], input: DeletePromptInput
    ) -> DeletePromptMutationPayload:
        # Delete a prompt by global ID; raises NotFound when no row was deleted.
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )
        async with info.context.db() as session:
            stmt = delete(models.Prompt).where(models.Prompt.id == prompt_id)
            result = await session.execute(stmt)

            if result.rowcount == 0:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")

            await session.commit()
            return DeletePromptMutationPayload(query=Query())

    @strawberry.mutation
    async def clone_prompt(self, info: Info[Context, None], input: ClonePromptInput) -> Prompt:
        # Copy a prompt and every one of its versions under a new name. The clone records
        # its origin in source_prompt_id. Raises NotFound when the source prompt does not
        # exist and Conflict on an integrity violation (e.g. a duplicate name).
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )
        async with info.context.db() as session:
            stmt = (
                select(models.Prompt)
                .options(joinedload(models.Prompt.prompt_versions))
                .where(models.Prompt.id == prompt_id)
            )
            # A joined eager load of a collection yields one row per version; SQLAlchemy
            # requires unique() to collapse those rows back into a single Prompt.
            prompt = (await session.scalars(stmt)).unique().one_or_none()

            if not prompt:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")

            name = IdentifierModel.model_validate(str(input.name))
            new_prompt = models.Prompt(
                name=name,
                description=input.description,
                source_prompt_id=prompt_id,
                # Each copy's prompt_id is populated from this relationship at flush time;
                # new_prompt.id does not exist yet, so it is not passed explicitly.
                prompt_versions=[
                    models.PromptVersion(
                        user_id=version.user_id,
                        description=version.description,
                        template_type=version.template_type,
                        template_format=version.template_format,
                        template=version.template,
                        invocation_parameters=version.invocation_parameters,
                        tools=version.tools,
                        response_format=version.response_format,
                        model_provider=version.model_provider,
                        model_name=version.model_name,
                    )
                    for version in prompt.prompt_versions
                ],
            )
            session.add(new_prompt)

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A prompt named '{input.name}' already exists")
        return to_gql_prompt_from_orm(new_prompt)

    @strawberry.mutation
    async def patch_prompt(self, info: Info[Context, None], input: PatchPromptInput) -> Prompt:
        # Replace a prompt's description; raises NotFound when the prompt does not exist.
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )

        async with info.context.db() as session:
            stmt = (
                update(models.Prompt)
                .where(models.Prompt.id == prompt_id)
                .values(description=input.description)
                .returning(models.Prompt)
            )

            result = await session.execute(stmt)
            prompt = result.scalar_one_or_none()

            if prompt is None:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")

            return to_gql_prompt_from_orm(prompt)
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
8
|
+
from strawberry.relay import GlobalID
|
|
9
|
+
from strawberry.types import Info
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
13
|
+
from phoenix.server.api.context import Context
|
|
14
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
15
|
+
from phoenix.server.api.queries import Query
|
|
16
|
+
from phoenix.server.api.types.Identifier import Identifier
|
|
17
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
18
|
+
from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
|
|
19
|
+
from phoenix.server.api.types.PromptVersion import PromptVersion
|
|
20
|
+
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# GraphQL input identifying the prompt version tag to delete.
@strawberry.input
class DeletePromptVersionTagInput:
    # Global ID of the tag to remove.
    prompt_version_tag_id: GlobalID
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# GraphQL input for pointing a named tag at a specific prompt version.
@strawberry.input
class SetPromptVersionTagInput:
    # Global ID of the prompt version the tag should reference.
    prompt_version_id: GlobalID
    # Tag name; unique per prompt (looked up by prompt and name when upserting).
    name: Identifier
    # Optional description; when None, an existing tag keeps its current description.
    description: Optional[str] = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Result shared by the prompt version tag mutations.
@strawberry.type
class PromptVersionTagMutationPayload:
    # The tag that was set, or None after a deletion.
    prompt_version_tag: Optional[PromptVersionTag]
    # The prompt owning the affected tag.
    prompt: Prompt
    # Root Query object.
    query: Query
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@strawberry.type
class PromptVersionTagMutationMixin:
    @strawberry.mutation
    async def delete_prompt_version_tag(
        self, info: Info[Context, None], input: DeletePromptVersionTagInput
    ) -> PromptVersionTagMutationPayload:
        # Delete a prompt version tag and return the prompt it belonged to.
        # Raises NotFound when no tag with the given ID exists.
        async with info.context.db() as session:
            prompt_version_tag_id = from_global_id_with_expected_type(
                input.prompt_version_tag_id, PromptVersionTag.__name__
            )
            stmt = (
                select(models.PromptVersionTag, models.Prompt)
                .join(
                    models.PromptVersion,
                    models.PromptVersion.id == models.PromptVersionTag.prompt_version_id,
                )
                .join(models.Prompt, models.Prompt.id == models.PromptVersion.prompt_id)
                .where(models.PromptVersionTag.id == prompt_version_tag_id)
            )
            row = (await session.execute(stmt)).one_or_none()
            # Check for a missing row before unpacking; otherwise the unpacked names would
            # be unbound and the mutation would fail with UnboundLocalError.
            if row is None:
                raise NotFound(f"PromptVersionTag with ID {input.prompt_version_tag_id} not found")
            # Both joins are inner joins, so a matched row always carries its Prompt.
            prompt_version_tag, prompt = row

            await session.delete(prompt_version_tag)
            await session.commit()
            return PromptVersionTagMutationPayload(
                prompt_version_tag=None, query=Query(), prompt=to_gql_prompt_from_orm(prompt)
            )

    @strawberry.mutation
    async def set_prompt_version_tag(
        self, info: Info[Context, None], input: SetPromptVersionTagInput
    ) -> PromptVersionTagMutationPayload:
        # Create the named tag on the version's prompt, or move an existing tag with that
        # name to the given version. Raises BadRequest when the version (or its prompt)
        # cannot be found and Conflict when the commit violates an integrity constraint.
        async with info.context.db() as session:
            prompt_version_id = from_global_id_with_expected_type(
                input.prompt_version_id, PromptVersion.__name__
            )
            prompt_version = await session.scalar(
                select(models.PromptVersion).where(models.PromptVersion.id == prompt_version_id)
            )
            if not prompt_version:
                raise BadRequest(f"PromptVersion with ID {input.prompt_version_id} not found.")

            prompt_id = prompt_version.prompt_id
            prompt = await session.scalar(
                select(models.Prompt).where(models.Prompt.id == prompt_id)
            )
            if not prompt:
                raise BadRequest("All prompt version tags must belong to a prompt")

            # upsert_prompt_version_tag always returns a tag, so no None check is needed.
            updated_tag = await upsert_prompt_version_tag(
                session, prompt_id, prompt_version_id, input.name, input.description
            )

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict("Failed to update PromptVersionTag.")

            version_tag = to_gql_prompt_version_tag(updated_tag)
            return PromptVersionTagMutationPayload(
                prompt_version_tag=version_tag, prompt=to_gql_prompt_from_orm(prompt), query=Query()
            )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
async def upsert_prompt_version_tag(
    session: AsyncSession,
    prompt_id: int,
    prompt_version_id: int,
    name_str: str,
    description: Optional[str] = None,
) -> models.PromptVersionTag:
    """Point the tag named *name_str* on a prompt at *prompt_version_id*.

    Tags are looked up by ``(prompt_id, name)``. If one exists, it is re-pointed at
    the given version, and its description is replaced only when *description* is
    not None. Otherwise a new tag is created and added to *session*. Nothing is
    committed here; the caller owns the transaction.

    Returns:
        The existing (updated) or newly created tag.
    """
    tag_name = IdentifierModel.model_validate(name_str)

    lookup = select(models.PromptVersionTag).where(
        models.PromptVersionTag.prompt_id == prompt_id,
        models.PromptVersionTag.name == tag_name,
    )
    tag = await session.scalar(lookup)

    if not tag:
        # No tag with this name yet: create it against the requested version.
        tag = models.PromptVersionTag(
            name=tag_name,
            description=description,
            prompt_id=prompt_id,
            prompt_version_id=prompt_version_id,
        )
        session.add(tag)
        return tag

    tag.prompt_version_id = prompt_version_id
    if description is not None:
        tag.description = description
    return tag
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import secrets
|
|
2
2
|
from contextlib import AsyncExitStack
|
|
3
3
|
from datetime import datetime, timezone
|
|
4
|
-
from typing import Literal, Optional
|
|
4
|
+
from typing import Literal, Optional, Union
|
|
5
5
|
|
|
6
6
|
import strawberry
|
|
7
7
|
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
|
|
8
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
8
9
|
from sqlalchemy.orm import joinedload
|
|
9
|
-
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
|
|
10
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
10
11
|
from strawberry import UNSET
|
|
11
12
|
from strawberry.relay import GlobalID
|
|
12
13
|
from strawberry.types import Info
|
|
@@ -108,7 +109,7 @@ class UserMutationMixin:
|
|
|
108
109
|
session.add(user)
|
|
109
110
|
try:
|
|
110
111
|
await session.flush()
|
|
111
|
-
except
|
|
112
|
+
except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
|
|
112
113
|
raise Conflict(_user_operation_error_message(error))
|
|
113
114
|
return UserMutationPayload(user=to_gql_user(user))
|
|
114
115
|
|
|
@@ -148,7 +149,7 @@ class UserMutationMixin:
|
|
|
148
149
|
assert user in session.dirty
|
|
149
150
|
try:
|
|
150
151
|
await session.flush()
|
|
151
|
-
except
|
|
152
|
+
except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
|
|
152
153
|
raise Conflict(_user_operation_error_message(error, "modify"))
|
|
153
154
|
assert user
|
|
154
155
|
if input.new_password:
|
|
@@ -186,7 +187,7 @@ class UserMutationMixin:
|
|
|
186
187
|
user.updated_at = datetime.now(timezone.utc)
|
|
187
188
|
try:
|
|
188
189
|
await session.flush()
|
|
189
|
-
except
|
|
190
|
+
except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
|
|
190
191
|
raise Conflict(_user_operation_error_message(error, "modify"))
|
|
191
192
|
assert user
|
|
192
193
|
if input.new_password:
|
|
@@ -313,7 +314,7 @@ def _select_user_by_id(user_id: int) -> Select[tuple[models.User]]:
|
|
|
313
314
|
|
|
314
315
|
|
|
315
316
|
def _user_operation_error_message(
|
|
316
|
-
error:
|
|
317
|
+
error: Union[PostgreSQLIntegrityError, SQLiteIntegrityError],
|
|
317
318
|
operation: Literal["create", "modify"] = "create",
|
|
318
319
|
) -> str:
|
|
319
320
|
"""
|
phoenix/server/api/queries.py
CHANGED
|
@@ -14,22 +14,12 @@ from strawberry.types import Info
|
|
|
14
14
|
from typing_extensions import Annotated, TypeAlias
|
|
15
15
|
|
|
16
16
|
from phoenix.db import enums, models
|
|
17
|
-
from phoenix.db.models import
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
from phoenix.db.models import
|
|
21
|
-
DatasetExampleRevision as OrmRevision,
|
|
22
|
-
)
|
|
23
|
-
from phoenix.db.models import (
|
|
24
|
-
DatasetVersion as OrmVersion,
|
|
25
|
-
)
|
|
26
|
-
from phoenix.db.models import (
|
|
27
|
-
Experiment as OrmExperiment,
|
|
28
|
-
)
|
|
17
|
+
from phoenix.db.models import DatasetExample as OrmExample
|
|
18
|
+
from phoenix.db.models import DatasetExampleRevision as OrmRevision
|
|
19
|
+
from phoenix.db.models import DatasetVersion as OrmVersion
|
|
20
|
+
from phoenix.db.models import Experiment as OrmExperiment
|
|
29
21
|
from phoenix.db.models import ExperimentRun as OrmExperimentRun
|
|
30
|
-
from phoenix.db.models import
|
|
31
|
-
Trace as OrmTrace,
|
|
32
|
-
)
|
|
22
|
+
from phoenix.db.models import Trace as OrmTrace
|
|
33
23
|
from phoenix.pointcloud.clustering import Hdbscan
|
|
34
24
|
from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
|
|
35
25
|
from phoenix.server.api.context import Context
|
|
@@ -43,14 +33,9 @@ from phoenix.server.api.helpers.experiment_run_filters import (
|
|
|
43
33
|
from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
|
|
44
34
|
from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
|
|
45
35
|
from phoenix.server.api.input_types.ClusterInput import ClusterInput
|
|
46
|
-
from phoenix.server.api.input_types.Coordinates import
|
|
47
|
-
InputCoordinate2D,
|
|
48
|
-
InputCoordinate3D,
|
|
49
|
-
)
|
|
36
|
+
from phoenix.server.api.input_types.Coordinates import InputCoordinate2D, InputCoordinate3D
|
|
50
37
|
from phoenix.server.api.input_types.DatasetSort import DatasetSort
|
|
51
|
-
from phoenix.server.api.input_types.InvocationParameters import
|
|
52
|
-
InvocationParameter,
|
|
53
|
-
)
|
|
38
|
+
from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
|
|
54
39
|
from phoenix.server.api.subscriptions import PLAYGROUND_PROJECT_NAME
|
|
55
40
|
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
56
41
|
from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
|
|
@@ -68,20 +53,16 @@ from phoenix.server.api.types.ExperimentComparison import ExperimentComparison,
|
|
|
68
53
|
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
|
|
69
54
|
from phoenix.server.api.types.Functionality import Functionality
|
|
70
55
|
from phoenix.server.api.types.GenerativeModel import GenerativeModel
|
|
71
|
-
from phoenix.server.api.types.GenerativeProvider import
|
|
72
|
-
GenerativeProvider,
|
|
73
|
-
GenerativeProviderKey,
|
|
74
|
-
)
|
|
56
|
+
from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
|
|
75
57
|
from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
|
|
76
58
|
from phoenix.server.api.types.Model import Model
|
|
77
59
|
from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
|
|
78
|
-
from phoenix.server.api.types.pagination import
|
|
79
|
-
ConnectionArgs,
|
|
80
|
-
CursorString,
|
|
81
|
-
connection_from_list,
|
|
82
|
-
)
|
|
60
|
+
from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
|
|
83
61
|
from phoenix.server.api.types.Project import Project
|
|
84
62
|
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
|
|
63
|
+
from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
|
|
64
|
+
from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
|
|
65
|
+
from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
|
|
85
66
|
from phoenix.server.api.types.SortDir import SortDir
|
|
86
67
|
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
87
68
|
from phoenix.server.api.types.SystemApiKey import SystemApiKey
|
|
@@ -587,6 +568,31 @@ class Query:
|
|
|
587
568
|
):
|
|
588
569
|
raise NotFound(f"Unknown user: {id}")
|
|
589
570
|
return to_gql_project_session(project_session)
|
|
571
|
+
elif type_name == Prompt.__name__:
|
|
572
|
+
async with info.context.db() as session:
|
|
573
|
+
if orm_prompt := await session.scalar(
|
|
574
|
+
select(models.Prompt).where(models.Prompt.id == node_id)
|
|
575
|
+
):
|
|
576
|
+
return to_gql_prompt_from_orm(orm_prompt)
|
|
577
|
+
else:
|
|
578
|
+
raise NotFound(f"Unknown prompt: {id}")
|
|
579
|
+
elif type_name == PromptVersion.__name__:
|
|
580
|
+
async with info.context.db() as session:
|
|
581
|
+
if orm_prompt_version := await session.scalar(
|
|
582
|
+
select(models.PromptVersion).where(models.PromptVersion.id == node_id)
|
|
583
|
+
):
|
|
584
|
+
return to_gql_prompt_version(orm_prompt_version)
|
|
585
|
+
else:
|
|
586
|
+
raise NotFound(f"Unknown prompt version: {id}")
|
|
587
|
+
elif type_name == PromptLabel.__name__:
|
|
588
|
+
async with info.context.db() as session:
|
|
589
|
+
if not (
|
|
590
|
+
prompt_label := await session.scalar(
|
|
591
|
+
select(models.PromptLabel).where(models.PromptLabel.id == node_id)
|
|
592
|
+
)
|
|
593
|
+
):
|
|
594
|
+
raise NotFound(f"Unknown prompt label: {id}")
|
|
595
|
+
return to_gql_prompt_label(prompt_label)
|
|
590
596
|
raise NotFound(f"Unknown node type: {type_name}")
|
|
591
597
|
|
|
592
598
|
@strawberry.field
|
|
@@ -609,6 +615,53 @@ class Query:
|
|
|
609
615
|
return None
|
|
610
616
|
return to_gql_user(user)
|
|
611
617
|
|
|
618
|
+
@strawberry.field
async def prompts(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[Prompt]:
    # List all prompts as a Relay connection.
    args = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    # connection_from_list paginates the fully fetched list, so cursors are only stable
    # across requests if rows come back in a deterministic order; order by primary key.
    stmt = select(models.Prompt).order_by(models.Prompt.id)
    async with info.context.db() as session:
        orm_prompts = await session.stream_scalars(stmt)
        data = [to_gql_prompt_from_orm(orm_prompt) async for orm_prompt in orm_prompts]
    return connection_from_list(
        data=data,
        args=args,
    )
|
|
641
|
+
|
|
642
|
+
@strawberry.field
async def prompt_labels(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[PromptLabel]:
    # List all prompt labels as a Relay connection.
    args = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    # connection_from_list paginates the fully fetched list, so a deterministic order is
    # required for stable cursors; order by primary key.
    stmt = select(models.PromptLabel).order_by(models.PromptLabel.id)
    async with info.context.db() as session:
        prompt_labels = await session.stream_scalars(stmt)
        data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels]
    return connection_from_list(
        data=data,
        args=args,
    )
|
|
664
|
+
|
|
612
665
|
@strawberry.field
|
|
613
666
|
def clusters(
|
|
614
667
|
self,
|
|
@@ -11,9 +11,10 @@ from authlib.jose import jwt
|
|
|
11
11
|
from authlib.jose.errors import JoseError
|
|
12
12
|
from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
|
|
13
13
|
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
|
|
14
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
14
15
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
15
16
|
from sqlalchemy.orm import joinedload
|
|
16
|
-
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
|
|
17
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
17
18
|
from starlette.datastructures import URL, URLPath
|
|
18
19
|
from starlette.responses import RedirectResponse
|
|
19
20
|
from starlette.routing import Router
|
|
@@ -323,7 +324,7 @@ async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: s
|
|
|
323
324
|
.values(email=email)
|
|
324
325
|
.options(joinedload(models.User.role))
|
|
325
326
|
)
|
|
326
|
-
except
|
|
327
|
+
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
327
328
|
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
|
|
328
329
|
user = await session.scalar(
|
|
329
330
|
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
|