arize-phoenix 7.12.3__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (80)
  1. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/METADATA +31 -28
  2. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/RECORD +70 -47
  3. phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
  4. phoenix/db/models.py +307 -0
  5. phoenix/db/types/__init__.py +0 -0
  6. phoenix/db/types/identifier.py +7 -0
  7. phoenix/db/types/model_provider.py +8 -0
  8. phoenix/server/api/context.py +2 -0
  9. phoenix/server/api/dataloaders/__init__.py +2 -0
  10. phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
  11. phoenix/server/api/helpers/jsonschema.py +135 -0
  12. phoenix/server/api/helpers/playground_clients.py +15 -15
  13. phoenix/server/api/helpers/playground_spans.py +9 -0
  14. phoenix/server/api/helpers/prompts/__init__.py +0 -0
  15. phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
  16. phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
  17. phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
  18. phoenix/server/api/helpers/prompts/models.py +575 -0
  19. phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
  20. phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
  21. phoenix/server/api/input_types/PromptVersionInput.py +133 -0
  22. phoenix/server/api/mutations/__init__.py +6 -0
  23. phoenix/server/api/mutations/chat_mutations.py +18 -16
  24. phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
  25. phoenix/server/api/mutations/prompt_mutations.py +312 -0
  26. phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
  27. phoenix/server/api/mutations/user_mutations.py +7 -6
  28. phoenix/server/api/openapi/schema.py +1 -0
  29. phoenix/server/api/queries.py +84 -31
  30. phoenix/server/api/routers/oauth2.py +3 -2
  31. phoenix/server/api/routers/v1/__init__.py +2 -0
  32. phoenix/server/api/routers/v1/datasets.py +1 -1
  33. phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
  34. phoenix/server/api/routers/v1/experiment_runs.py +1 -1
  35. phoenix/server/api/routers/v1/experiments.py +1 -1
  36. phoenix/server/api/routers/v1/models.py +45 -0
  37. phoenix/server/api/routers/v1/prompts.py +415 -0
  38. phoenix/server/api/routers/v1/spans.py +1 -1
  39. phoenix/server/api/routers/v1/traces.py +1 -1
  40. phoenix/server/api/routers/v1/utils.py +1 -1
  41. phoenix/server/api/subscriptions.py +21 -24
  42. phoenix/server/api/types/GenerativeProvider.py +4 -4
  43. phoenix/server/api/types/Identifier.py +15 -0
  44. phoenix/server/api/types/Project.py +5 -7
  45. phoenix/server/api/types/Prompt.py +134 -0
  46. phoenix/server/api/types/PromptLabel.py +41 -0
  47. phoenix/server/api/types/PromptVersion.py +148 -0
  48. phoenix/server/api/types/PromptVersionTag.py +27 -0
  49. phoenix/server/api/types/PromptVersionTemplate.py +148 -0
  50. phoenix/server/api/types/ResponseFormat.py +9 -0
  51. phoenix/server/api/types/ToolDefinition.py +9 -0
  52. phoenix/server/app.py +3 -0
  53. phoenix/server/static/.vite/manifest.json +45 -45
  54. phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
  55. phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
  56. phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
  57. phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
  58. phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
  59. phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
  60. phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
  61. phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
  62. phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
  63. phoenix/session/client.py +25 -21
  64. phoenix/utilities/client.py +6 -0
  65. phoenix/version.py +1 -1
  66. phoenix/server/api/input_types/TemplateOptions.py +0 -10
  67. phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
  68. phoenix/server/api/types/TemplateLanguage.py +0 -10
  69. phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
  70. phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
  71. phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
  72. phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
  73. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
  74. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
  75. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
  76. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/LICENSE +0 -0
  80. /phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
@@ -0,0 +1,312 @@
1
+ from typing import Any, Optional, Union, cast
2
+
3
+ import strawberry
4
+ from fastapi import Request
5
+ from pydantic import ValidationError
6
+ from sqlalchemy import delete, select, update
7
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
8
+ from sqlalchemy.orm import joinedload
9
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
10
+ from strawberry.relay.types import GlobalID
11
+ from strawberry.types import Info
12
+
13
+ from phoenix.db import models
14
+ from phoenix.db.types.identifier import Identifier as IdentifierModel
15
+ from phoenix.db.types.model_provider import ModelProvider
16
+ from phoenix.server.api.context import Context
17
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
+ from phoenix.server.api.helpers.prompts.models import (
19
+ normalize_response_format,
20
+ normalize_tools,
21
+ validate_invocation_parameters,
22
+ )
23
+ from phoenix.server.api.input_types.PromptVersionInput import (
24
+ ChatPromptVersionInput,
25
+ to_pydantic_prompt_chat_template_v1,
26
+ )
27
+ from phoenix.server.api.mutations.prompt_version_tag_mutations import (
28
+ SetPromptVersionTagInput,
29
+ upsert_prompt_version_tag,
30
+ )
31
+ from phoenix.server.api.queries import Query
32
+ from phoenix.server.api.types.Identifier import Identifier
33
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
34
+ from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
35
+ from phoenix.server.bearer_auth import PhoenixUser
36
+
37
+
38
@strawberry.input
class CreateChatPromptInput:
    """Input for the ``create_chat_prompt`` mutation: a new prompt plus its first chat version."""

    # Name of the prompt to create.
    name: Identifier
    description: Union[str, None] = None
    # The initial version stored under the newly created prompt.
    prompt_version: ChatPromptVersionInput
43
+
44
+
45
@strawberry.input
class CreateChatPromptVersionInput:
    """Input for the ``create_chat_prompt_version`` mutation."""

    # Global ID of the existing prompt that receives the new version.
    prompt_id: GlobalID
    prompt_version: ChatPromptVersionInput
    # Optional tags to point at the newly created version.
    tags: Union[list[SetPromptVersionTagInput], None] = None
50
+
51
+
52
@strawberry.input
class DeletePromptInput:
    """Input for the ``delete_prompt`` mutation."""

    # Global ID of the prompt to delete.
    prompt_id: GlobalID
55
+
56
+
57
@strawberry.input
class ClonePromptInput:
    """Input for the ``clone_prompt`` mutation."""

    # Name for the new (cloned) prompt.
    name: Identifier
    description: Union[str, None] = None
    # Global ID of the prompt being copied.
    prompt_id: GlobalID
62
+
63
+
64
@strawberry.input
class PatchPromptInput:
    """Input for the ``patch_prompt`` mutation; only the description can be changed."""

    # Global ID of the prompt to update.
    prompt_id: GlobalID
    # Replacement description.
    description: str
68
+
69
+
70
@strawberry.type
class DeletePromptMutationPayload:
    """Payload returned by ``delete_prompt``; carries a root ``Query`` object."""

    query: Query
73
+
74
+
75
def _current_user_id(info: Info[Context, None]) -> Optional[int]:
    """Return the requesting PhoenixUser's id, or None when ``request.scope`` has no ``user``."""
    assert isinstance(request := info.context.request, Request)
    if "user" not in request.scope:
        return None
    assert isinstance(user := request.user, PhoenixUser)
    return int(user.identity)


def _validate_prompt_name(name: Identifier) -> IdentifierModel:
    """Convert a GraphQL identifier to the DB identifier type.

    Raises:
        BadRequest: if the name fails ``IdentifierModel`` validation.
    """
    try:
        return IdentifierModel.model_validate(str(name))
    except ValidationError as error:
        raise BadRequest(str(error)) from error


def _to_orm_chat_prompt_version(
    input_prompt_version: ChatPromptVersionInput,
    user_id: Optional[int],
) -> models.PromptVersion:
    """Validate a chat prompt version input and build an unsaved ORM ``PromptVersion``.

    ``tool_choice`` is popped out of ``invocation_parameters`` (mutating the input dict) and
    passed to ``normalize_tools`` instead of being validated as an invocation parameter. When
    there are no tool definitions, ``tool_choice`` is discarded and ``tools`` is None.

    The returned object has no ``prompt_id``; callers attach it to a prompt.

    Raises:
        BadRequest: if tools, template, response format, or invocation parameters fail
            pydantic validation.
    """
    tool_definitions = [tool.definition for tool in input_prompt_version.tools]
    tool_choice = cast(
        Optional[Union[str, dict[str, Any]]],
        cast(dict[str, Any], input_prompt_version.invocation_parameters).pop(
            "tool_choice", None
        ),
    )
    model_provider = ModelProvider(input_prompt_version.model_provider)
    try:
        tools = (
            normalize_tools(tool_definitions, model_provider, tool_choice)
            if tool_definitions
            else None
        )
        template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
        response_format = (
            normalize_response_format(
                input_prompt_version.response_format.definition,
                model_provider,
            )
            if input_prompt_version.response_format
            else None
        )
        invocation_parameters = validate_invocation_parameters(
            input_prompt_version.invocation_parameters,
            model_provider,
        )
    except ValidationError as error:
        raise BadRequest(str(error))
    return models.PromptVersion(
        description=input_prompt_version.description,
        user_id=user_id,
        template_type="CHAT",
        template_format=input_prompt_version.template_format,
        template=template,
        invocation_parameters=invocation_parameters,
        tools=tools,
        response_format=response_format,
        model_provider=input_prompt_version.model_provider,
        model_name=input_prompt_version.model_name,
    )


@strawberry.type
class PromptMutationMixin:
    @strawberry.mutation
    async def create_chat_prompt(
        self, info: Info[Context, None], input: CreateChatPromptInput
    ) -> Prompt:
        """Create a prompt together with a single initial chat prompt version.

        Raises:
            BadRequest: if the name or the prompt version input fails validation.
            Conflict: on an integrity error at commit (reported as a duplicate name).
        """
        user_id = _current_user_id(info)
        prompt_version = _to_orm_chat_prompt_version(input.prompt_version, user_id)
        name = _validate_prompt_name(input.name)
        async with info.context.db() as session:
            prompt = models.Prompt(
                name=name,
                description=input.description,
                prompt_versions=[prompt_version],
            )
            session.add(prompt)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A prompt named '{input.name}' already exists")
        return to_gql_prompt_from_orm(prompt)

    @strawberry.mutation
    async def create_chat_prompt_version(
        self,
        info: Info[Context, None],
        input: CreateChatPromptVersionInput,
    ) -> Prompt:
        """Add a new chat prompt version to an existing prompt and optionally tag it.

        Every tag in ``input.tags`` is upserted to point at the newly created version; the
        ``prompt_version_id`` carried by each tag input is not used here.

        Raises:
            BadRequest: if the prompt version input fails validation.
            NotFound: if the prompt does not exist.
        """
        user_id = _current_user_id(info)
        # Validate before touching the DB so bad input is reported as BadRequest first.
        prompt_version = _to_orm_chat_prompt_version(input.prompt_version, user_id)
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )
        prompt_version.prompt_id = prompt_id
        # NOTE(review): no explicit commit here; this relies on info.context.db() committing
        # when the context exits — confirm.
        async with info.context.db() as session:
            prompt = await session.get(models.Prompt, prompt_id)
            if not prompt:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
            session.add(prompt_version)
            # Flush so prompt_version.id is assigned before tags reference it.
            await session.flush()
            for tag in input.tags or ():
                await upsert_prompt_version_tag(
                    session, prompt_id, prompt_version.id, tag.name, tag.description
                )
            return to_gql_prompt_from_orm(prompt)

    @strawberry.mutation
    async def delete_prompt(
        self, info: Info[Context, None], input: DeletePromptInput
    ) -> DeletePromptMutationPayload:
        """Delete a prompt by global ID.

        Raises:
            NotFound: if no prompt row was deleted.
        """
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )
        async with info.context.db() as session:
            stmt = delete(models.Prompt).where(models.Prompt.id == prompt_id)
            result = await session.execute(stmt)
            if result.rowcount == 0:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
            await session.commit()
        return DeletePromptMutationPayload(query=Query())

    @strawberry.mutation
    async def clone_prompt(self, info: Info[Context, None], input: ClonePromptInput) -> Prompt:
        """Create a new prompt that copies every version of an existing prompt.

        Only prompt versions are copied (including their ``user_id``); version tags are not.
        The new prompt records the original in ``source_prompt_id``.

        Raises:
            BadRequest: if the new name fails validation.
            NotFound: if the source prompt does not exist.
            Conflict: on an integrity error at commit (reported as a duplicate name).
        """
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )
        name = _validate_prompt_name(input.name)
        async with info.context.db() as session:
            stmt = (
                select(models.Prompt)
                .options(joinedload(models.Prompt.prompt_versions))
                .where(models.Prompt.id == prompt_id)
            )
            prompt = await session.scalar(stmt)
            if not prompt:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")

            # The versions' prompt_id is filled in from the relationship at flush time;
            # new_prompt.id does not exist yet, so it is not passed explicitly.
            new_prompt = models.Prompt(
                name=name,
                description=input.description,
                source_prompt_id=prompt_id,
                prompt_versions=[
                    models.PromptVersion(
                        user_id=version.user_id,
                        description=version.description,
                        template_type=version.template_type,
                        template_format=version.template_format,
                        template=version.template,
                        invocation_parameters=version.invocation_parameters,
                        tools=version.tools,
                        response_format=version.response_format,
                        model_provider=version.model_provider,
                        model_name=version.model_name,
                    )
                    for version in prompt.prompt_versions
                ],
            )
            session.add(new_prompt)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A prompt named '{input.name}' already exists")
        return to_gql_prompt_from_orm(new_prompt)

    @strawberry.mutation
    async def patch_prompt(self, info: Info[Context, None], input: PatchPromptInput) -> Prompt:
        """Replace a prompt's description and return the updated prompt.

        Raises:
            NotFound: if the UPDATE matched no prompt.
        """
        prompt_id = from_global_id_with_expected_type(
            global_id=input.prompt_id, expected_type_name=Prompt.__name__
        )
        # NOTE(review): no explicit commit; relies on info.context.db() committing on exit.
        async with info.context.db() as session:
            stmt = (
                update(models.Prompt)
                .where(models.Prompt.id == prompt_id)
                .values(description=input.description)
                .returning(models.Prompt)
            )
            result = await session.execute(stmt)
            prompt = result.scalar_one_or_none()
            if prompt is None:
                raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
            return to_gql_prompt_from_orm(prompt)
@@ -0,0 +1,148 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from sqlalchemy import select
5
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
8
+ from strawberry.relay import GlobalID
9
+ from strawberry.types import Info
10
+
11
+ from phoenix.db import models
12
+ from phoenix.db.types.identifier import Identifier as IdentifierModel
13
+ from phoenix.server.api.context import Context
14
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
15
+ from phoenix.server.api.queries import Query
16
+ from phoenix.server.api.types.Identifier import Identifier
17
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
18
+ from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
19
+ from phoenix.server.api.types.PromptVersion import PromptVersion
20
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
21
+
22
+
23
+ @strawberry.input
24
+ class DeletePromptVersionTagInput:
25
+ prompt_version_tag_id: GlobalID
26
+
27
+
28
+ @strawberry.input
29
+ class SetPromptVersionTagInput:
30
+ prompt_version_id: GlobalID
31
+ name: Identifier
32
+ description: Optional[str] = None
33
+
34
+
35
+ @strawberry.type
36
+ class PromptVersionTagMutationPayload:
37
+ prompt_version_tag: Optional[PromptVersionTag]
38
+ prompt: Prompt
39
+ query: Query
40
+
41
+
42
+ @strawberry.type
43
+ class PromptVersionTagMutationMixin:
44
+ @strawberry.mutation
45
+ async def delete_prompt_version_tag(
46
+ self, info: Info[Context, None], input: DeletePromptVersionTagInput
47
+ ) -> PromptVersionTagMutationPayload:
48
+ async with info.context.db() as session:
49
+ prompt_version_tag_id = from_global_id_with_expected_type(
50
+ input.prompt_version_tag_id, PromptVersionTag.__name__
51
+ )
52
+ stmt = (
53
+ select(models.PromptVersionTag, models.Prompt)
54
+ .join(
55
+ models.PromptVersion,
56
+ models.PromptVersion.id == models.PromptVersionTag.prompt_version_id,
57
+ )
58
+ .join(models.Prompt, models.Prompt.id == models.PromptVersion.prompt_id)
59
+ .where(models.PromptVersionTag.id == prompt_version_tag_id)
60
+ )
61
+ result = await session.execute(stmt)
62
+ if results := result.one_or_none():
63
+ prompt_version_tag, prompt = results
64
+
65
+ if not prompt_version_tag:
66
+ raise NotFound(f"PromptVersionTag with ID {input.prompt_version_tag_id} not found")
67
+
68
+ if not prompt:
69
+ raise BadRequest(
70
+ f"PromptVersionTag with ID {input.prompt_version_tag_id} "
71
+ "does not belong to a prompt"
72
+ )
73
+
74
+ await session.delete(prompt_version_tag)
75
+ await session.commit()
76
+ return PromptVersionTagMutationPayload(
77
+ prompt_version_tag=None, query=Query(), prompt=to_gql_prompt_from_orm(prompt)
78
+ )
79
+
80
+ @strawberry.mutation
81
+ async def set_prompt_version_tag(
82
+ self, info: Info[Context, None], input: SetPromptVersionTagInput
83
+ ) -> PromptVersionTagMutationPayload:
84
+ async with info.context.db() as session:
85
+ prompt_version_id = from_global_id_with_expected_type(
86
+ input.prompt_version_id, PromptVersion.__name__
87
+ )
88
+ prompt_version = await session.scalar(
89
+ select(models.PromptVersion).where(models.PromptVersion.id == prompt_version_id)
90
+ )
91
+ if not prompt_version:
92
+ raise BadRequest(f"PromptVersion with ID {input.prompt_version_id} not found.")
93
+
94
+ prompt_id = prompt_version.prompt_id
95
+ prompt = await session.scalar(
96
+ select(models.Prompt).where(models.Prompt.id == prompt_id)
97
+ )
98
+ if not prompt:
99
+ raise BadRequest("All prompt version tags must belong to a prompt")
100
+
101
+ updated_tag = await upsert_prompt_version_tag(
102
+ session, prompt_id, prompt_version_id, input.name, input.description
103
+ )
104
+
105
+ if not updated_tag:
106
+ raise BadRequest("Failed to create or update PromptVersionTag.")
107
+
108
+ try:
109
+ await session.commit()
110
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
111
+ raise Conflict("Failed to update PromptVersionTag.")
112
+
113
+ version_tag = to_gql_prompt_version_tag(updated_tag)
114
+ return PromptVersionTagMutationPayload(
115
+ prompt_version_tag=version_tag, prompt=to_gql_prompt_from_orm(prompt), query=Query()
116
+ )
117
+
118
+
119
+ async def upsert_prompt_version_tag(
120
+ session: AsyncSession,
121
+ prompt_id: int,
122
+ prompt_version_id: int,
123
+ name_str: str,
124
+ description: Optional[str] = None,
125
+ ) -> models.PromptVersionTag:
126
+ name = IdentifierModel.model_validate(name_str)
127
+
128
+ existing_tag = await session.scalar(
129
+ select(models.PromptVersionTag).where(
130
+ models.PromptVersionTag.prompt_id == prompt_id,
131
+ models.PromptVersionTag.name == name,
132
+ )
133
+ )
134
+
135
+ if existing_tag:
136
+ existing_tag.prompt_version_id = prompt_version_id
137
+ if description is not None:
138
+ existing_tag.description = description
139
+ return existing_tag
140
+ else:
141
+ new_tag = models.PromptVersionTag(
142
+ name=name,
143
+ description=description,
144
+ prompt_id=prompt_id,
145
+ prompt_version_id=prompt_version_id,
146
+ )
147
+ session.add(new_tag)
148
+ return new_tag
@@ -1,12 +1,13 @@
1
1
  import secrets
2
2
  from contextlib import AsyncExitStack
3
3
  from datetime import datetime, timezone
4
- from typing import Literal, Optional
4
+ from typing import Literal, Optional, Union
5
5
 
6
6
  import strawberry
7
7
  from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
8
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
8
9
  from sqlalchemy.orm import joinedload
9
- from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
10
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
10
11
  from strawberry import UNSET
11
12
  from strawberry.relay import GlobalID
12
13
  from strawberry.types import Info
@@ -108,7 +109,7 @@ class UserMutationMixin:
108
109
  session.add(user)
109
110
  try:
110
111
  await session.flush()
111
- except IntegrityError as error:
112
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
112
113
  raise Conflict(_user_operation_error_message(error))
113
114
  return UserMutationPayload(user=to_gql_user(user))
114
115
 
@@ -148,7 +149,7 @@ class UserMutationMixin:
148
149
  assert user in session.dirty
149
150
  try:
150
151
  await session.flush()
151
- except IntegrityError as error:
152
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
152
153
  raise Conflict(_user_operation_error_message(error, "modify"))
153
154
  assert user
154
155
  if input.new_password:
@@ -186,7 +187,7 @@ class UserMutationMixin:
186
187
  user.updated_at = datetime.now(timezone.utc)
187
188
  try:
188
189
  await session.flush()
189
- except IntegrityError as error:
190
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
190
191
  raise Conflict(_user_operation_error_message(error, "modify"))
191
192
  assert user
192
193
  if input.new_password:
@@ -313,7 +314,7 @@ def _select_user_by_id(user_id: int) -> Select[tuple[models.User]]:
313
314
 
314
315
 
315
316
  def _user_operation_error_message(
316
- error: IntegrityError,
317
+ error: Union[PostgreSQLIntegrityError, SQLiteIntegrityError],
317
318
  operation: Literal["create", "modify"] = "create",
318
319
  ) -> str:
319
320
  """
@@ -13,4 +13,5 @@ def get_openapi_schema() -> dict[str, Any]:
13
13
  openapi_version="3.1.0",
14
14
  description="Schema for Arize-Phoenix REST API",
15
15
  routes=v1_router.routes,
16
+ separate_input_output_schemas=False,
16
17
  )
@@ -14,22 +14,12 @@ from strawberry.types import Info
14
14
  from typing_extensions import Annotated, TypeAlias
15
15
 
16
16
  from phoenix.db import enums, models
17
- from phoenix.db.models import (
18
- DatasetExample as OrmExample,
19
- )
20
- from phoenix.db.models import (
21
- DatasetExampleRevision as OrmRevision,
22
- )
23
- from phoenix.db.models import (
24
- DatasetVersion as OrmVersion,
25
- )
26
- from phoenix.db.models import (
27
- Experiment as OrmExperiment,
28
- )
17
+ from phoenix.db.models import DatasetExample as OrmExample
18
+ from phoenix.db.models import DatasetExampleRevision as OrmRevision
19
+ from phoenix.db.models import DatasetVersion as OrmVersion
20
+ from phoenix.db.models import Experiment as OrmExperiment
29
21
  from phoenix.db.models import ExperimentRun as OrmExperimentRun
30
- from phoenix.db.models import (
31
- Trace as OrmTrace,
32
- )
22
+ from phoenix.db.models import Trace as OrmTrace
33
23
  from phoenix.pointcloud.clustering import Hdbscan
34
24
  from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
35
25
  from phoenix.server.api.context import Context
@@ -43,14 +33,9 @@ from phoenix.server.api.helpers.experiment_run_filters import (
43
33
  from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
44
34
  from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
45
35
  from phoenix.server.api.input_types.ClusterInput import ClusterInput
46
- from phoenix.server.api.input_types.Coordinates import (
47
- InputCoordinate2D,
48
- InputCoordinate3D,
49
- )
36
+ from phoenix.server.api.input_types.Coordinates import InputCoordinate2D, InputCoordinate3D
50
37
  from phoenix.server.api.input_types.DatasetSort import DatasetSort
51
- from phoenix.server.api.input_types.InvocationParameters import (
52
- InvocationParameter,
53
- )
38
+ from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
54
39
  from phoenix.server.api.subscriptions import PLAYGROUND_PROJECT_NAME
55
40
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
56
41
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
@@ -68,20 +53,16 @@ from phoenix.server.api.types.ExperimentComparison import ExperimentComparison,
68
53
  from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
69
54
  from phoenix.server.api.types.Functionality import Functionality
70
55
  from phoenix.server.api.types.GenerativeModel import GenerativeModel
71
- from phoenix.server.api.types.GenerativeProvider import (
72
- GenerativeProvider,
73
- GenerativeProviderKey,
74
- )
56
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
75
57
  from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
76
58
  from phoenix.server.api.types.Model import Model
77
59
  from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
78
- from phoenix.server.api.types.pagination import (
79
- ConnectionArgs,
80
- CursorString,
81
- connection_from_list,
82
- )
60
+ from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
83
61
  from phoenix.server.api.types.Project import Project
84
62
  from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
63
+ from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
64
+ from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
65
+ from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
85
66
  from phoenix.server.api.types.SortDir import SortDir
86
67
  from phoenix.server.api.types.Span import Span, to_gql_span
87
68
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
@@ -587,6 +568,31 @@ class Query:
587
568
  ):
588
569
  raise NotFound(f"Unknown user: {id}")
589
570
  return to_gql_project_session(project_session)
571
+ elif type_name == Prompt.__name__:
572
+ async with info.context.db() as session:
573
+ if orm_prompt := await session.scalar(
574
+ select(models.Prompt).where(models.Prompt.id == node_id)
575
+ ):
576
+ return to_gql_prompt_from_orm(orm_prompt)
577
+ else:
578
+ raise NotFound(f"Unknown prompt: {id}")
579
+ elif type_name == PromptVersion.__name__:
580
+ async with info.context.db() as session:
581
+ if orm_prompt_version := await session.scalar(
582
+ select(models.PromptVersion).where(models.PromptVersion.id == node_id)
583
+ ):
584
+ return to_gql_prompt_version(orm_prompt_version)
585
+ else:
586
+ raise NotFound(f"Unknown prompt version: {id}")
587
+ elif type_name == PromptLabel.__name__:
588
+ async with info.context.db() as session:
589
+ if not (
590
+ prompt_label := await session.scalar(
591
+ select(models.PromptLabel).where(models.PromptLabel.id == node_id)
592
+ )
593
+ ):
594
+ raise NotFound(f"Unknown prompt label: {id}")
595
+ return to_gql_prompt_label(prompt_label)
590
596
  raise NotFound(f"Unknown node type: {type_name}")
591
597
 
592
598
  @strawberry.field
@@ -609,6 +615,53 @@ class Query:
609
615
  return None
610
616
  return to_gql_user(user)
611
617
 
618
@strawberry.field
async def prompts(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[Prompt]:
    """Return a page of prompts, ordered by prompt id.

    All prompts are loaded and then paginated in Python via ``connection_from_list``. SQL
    does not guarantee row order without ORDER BY, so an explicit ordering is applied to
    keep page boundaries stable between requests.
    """
    args = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    stmt = select(models.Prompt).order_by(models.Prompt.id)
    async with info.context.db() as session:
        orm_prompts = await session.stream_scalars(stmt)
        data = [to_gql_prompt_from_orm(orm_prompt) async for orm_prompt in orm_prompts]
    return connection_from_list(
        data=data,
        args=args,
    )
641
+
642
@strawberry.field
async def prompt_labels(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[PromptLabel]:
    """Return a page of prompt labels, ordered by label id.

    All labels are loaded and then paginated in Python via ``connection_from_list``. An
    explicit ORDER BY keeps the list order (and therefore page boundaries) deterministic.
    """
    args = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    stmt = select(models.PromptLabel).order_by(models.PromptLabel.id)
    async with info.context.db() as session:
        prompt_labels = await session.stream_scalars(stmt)
        data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels]
    return connection_from_list(
        data=data,
        args=args,
    )
664
+
612
665
  @strawberry.field
613
666
  def clusters(
614
667
  self,
@@ -11,9 +11,10 @@ from authlib.jose import jwt
11
11
  from authlib.jose.errors import JoseError
12
12
  from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
13
13
  from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
14
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
14
15
  from sqlalchemy.ext.asyncio import AsyncSession
15
16
  from sqlalchemy.orm import joinedload
16
- from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
17
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
17
18
  from starlette.datastructures import URL, URLPath
18
19
  from starlette.responses import RedirectResponse
19
20
  from starlette.routing import Router
@@ -323,7 +324,7 @@ async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: s
323
324
  .values(email=email)
324
325
  .options(joinedload(models.User.role))
325
326
  )
326
- except IntegrityError:
327
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
327
328
  raise EmailAlreadyInUse(f"An account for {email} is already in use.")
328
329
  user = await session.scalar(
329
330
  select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))