arize_phoenix-7.12.3-py3-none-any.whl → arize_phoenix-8.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the registry's release page for more details.

Files changed (80)
  1. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/METADATA +31 -28
  2. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/RECORD +70 -47
  3. phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
  4. phoenix/db/models.py +307 -0
  5. phoenix/db/types/__init__.py +0 -0
  6. phoenix/db/types/identifier.py +7 -0
  7. phoenix/db/types/model_provider.py +8 -0
  8. phoenix/server/api/context.py +2 -0
  9. phoenix/server/api/dataloaders/__init__.py +2 -0
  10. phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
  11. phoenix/server/api/helpers/jsonschema.py +135 -0
  12. phoenix/server/api/helpers/playground_clients.py +15 -15
  13. phoenix/server/api/helpers/playground_spans.py +9 -0
  14. phoenix/server/api/helpers/prompts/__init__.py +0 -0
  15. phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
  16. phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
  17. phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
  18. phoenix/server/api/helpers/prompts/models.py +575 -0
  19. phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
  20. phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
  21. phoenix/server/api/input_types/PromptVersionInput.py +133 -0
  22. phoenix/server/api/mutations/__init__.py +6 -0
  23. phoenix/server/api/mutations/chat_mutations.py +18 -16
  24. phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
  25. phoenix/server/api/mutations/prompt_mutations.py +312 -0
  26. phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
  27. phoenix/server/api/mutations/user_mutations.py +7 -6
  28. phoenix/server/api/openapi/schema.py +1 -0
  29. phoenix/server/api/queries.py +84 -31
  30. phoenix/server/api/routers/oauth2.py +3 -2
  31. phoenix/server/api/routers/v1/__init__.py +2 -0
  32. phoenix/server/api/routers/v1/datasets.py +1 -1
  33. phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
  34. phoenix/server/api/routers/v1/experiment_runs.py +1 -1
  35. phoenix/server/api/routers/v1/experiments.py +1 -1
  36. phoenix/server/api/routers/v1/models.py +45 -0
  37. phoenix/server/api/routers/v1/prompts.py +415 -0
  38. phoenix/server/api/routers/v1/spans.py +1 -1
  39. phoenix/server/api/routers/v1/traces.py +1 -1
  40. phoenix/server/api/routers/v1/utils.py +1 -1
  41. phoenix/server/api/subscriptions.py +21 -24
  42. phoenix/server/api/types/GenerativeProvider.py +4 -4
  43. phoenix/server/api/types/Identifier.py +15 -0
  44. phoenix/server/api/types/Project.py +5 -7
  45. phoenix/server/api/types/Prompt.py +134 -0
  46. phoenix/server/api/types/PromptLabel.py +41 -0
  47. phoenix/server/api/types/PromptVersion.py +148 -0
  48. phoenix/server/api/types/PromptVersionTag.py +27 -0
  49. phoenix/server/api/types/PromptVersionTemplate.py +148 -0
  50. phoenix/server/api/types/ResponseFormat.py +9 -0
  51. phoenix/server/api/types/ToolDefinition.py +9 -0
  52. phoenix/server/app.py +3 -0
  53. phoenix/server/static/.vite/manifest.json +45 -45
  54. phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
  55. phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
  56. phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
  57. phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
  58. phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
  59. phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
  60. phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
  61. phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
  62. phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
  63. phoenix/session/client.py +25 -21
  64. phoenix/utilities/client.py +6 -0
  65. phoenix/version.py +1 -1
  66. phoenix/server/api/input_types/TemplateOptions.py +0 -10
  67. phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
  68. phoenix/server/api/types/TemplateLanguage.py +0 -10
  69. phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
  70. phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
  71. phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
  72. phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
  73. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
  74. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
  75. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
  76. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/LICENSE +0 -0
  80. phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
@@ -9,6 +9,7 @@ from .evaluations import router as evaluations_router
9
9
  from .experiment_evaluations import router as experiment_evaluations_router
10
10
  from .experiment_runs import router as experiment_runs_router
11
11
  from .experiments import router as experiments_router
12
+ from .prompts import router as prompts_router
12
13
  from .spans import router as spans_router
13
14
  from .traces import router as traces_router
14
15
  from .utils import add_errors_to_responses
@@ -61,4 +62,5 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
61
62
  router.include_router(traces_router)
62
63
  router.include_router(spans_router)
63
64
  router.include_router(evaluations_router)
65
+ router.include_router(prompts_router)
64
66
  return router
@@ -48,7 +48,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
48
48
  from phoenix.server.api.utils import delete_projects, delete_traces
49
49
  from phoenix.server.dml_event import DatasetInsertEvent
50
50
 
51
- from .pydantic_compat import V1RoutesBaseModel
51
+ from .models import V1RoutesBaseModel
52
52
  from .utils import (
53
53
  PaginatedResponseBody,
54
54
  ResponseBody,
@@ -13,7 +13,7 @@ from phoenix.db.insertion.helpers import insert_on_conflict
13
13
  from phoenix.server.api.types.node import from_global_id_with_expected_type
14
14
  from phoenix.server.dml_event import ExperimentRunAnnotationInsertEvent
15
15
 
16
- from .pydantic_compat import V1RoutesBaseModel
16
+ from .models import V1RoutesBaseModel
17
17
  from .utils import ResponseBody, add_errors_to_responses
18
18
 
19
19
  router = APIRouter(tags=["experiments"], include_in_schema=False)
@@ -13,7 +13,7 @@ from phoenix.db.models import ExperimentRunOutput
13
13
  from phoenix.server.api.types.node import from_global_id_with_expected_type
14
14
  from phoenix.server.dml_event import ExperimentRunInsertEvent
15
15
 
16
- from .pydantic_compat import V1RoutesBaseModel
16
+ from .models import V1RoutesBaseModel
17
17
  from .utils import ResponseBody, add_errors_to_responses
18
18
 
19
19
  router = APIRouter(tags=["experiments"], include_in_schema=False)
@@ -15,7 +15,7 @@ from phoenix.db.insertion.helpers import insert_on_conflict
15
15
  from phoenix.server.api.types.node import from_global_id_with_expected_type
16
16
  from phoenix.server.dml_event import ExperimentInsertEvent
17
17
 
18
- from .pydantic_compat import V1RoutesBaseModel
18
+ from .models import V1RoutesBaseModel
19
19
  from .utils import ResponseBody, add_errors_to_responses
20
20
 
21
21
  router = APIRouter(tags=["experiments"], include_in_schema=True)
@@ -0,0 +1,45 @@
1
+ from datetime import datetime
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+
6
def datetime_encoder(dt: datetime) -> str:
    """
    Serialize ``dt`` as an ISO 8601 string using ``dt.isoformat()``.

    This is registered under Pydantic's ``json_encoders`` config so that the
    emitted timestamps round-trip through ``datetime.fromisoformat``. Pydantic
    v2's default datetime output is not always accepted by ``fromisoformat``;
    for example, a trailing ``Z`` for UTC is only accepted from Python 3.11.
    """
    iso_text = dt.isoformat()
    return iso_text
16
+
17
+
18
# `json_encoders` is a configuration setting from Pydantic v1 that was
# removed in Pydantic v2.0.* but restored in Pydantic v2.1.0 with a
# deprecation warning. At this time, it remains the simplest way to
# configure custom JSON serialization for specific data types.
#
# For details, see:
# - https://github.com/pydantic/pydantic/pull/6811
# - https://github.com/pydantic/pydantic/releases/tag/v2.1.0
#
# The check below guards against a future release of Pydantic v2 fully
# removing the `json_encoders` parameter. It is an explicit `raise` rather
# than an `assert` so that it still runs under `python -O`, which strips
# assert statements.
#
# The version specifiers in the suggested pip commands are quoted: unquoted,
# a POSIX shell parses `>=2.1.0` as an output redirection, so the command
# would install the latest pydantic and create a file named `=2.1.0`.
if "json_encoders" not in ConfigDict.__annotations__:
    raise ImportError(
        "If you encounter this error with `pydantic<2.1.0`, "
        "please upgrade `pydantic` with `pip install -U 'pydantic>=2.1.0'`. "
        "If you encounter this error with `pydantic>=2.1.0`, "
        "please upgrade `arize-phoenix` with `pip install -U arize-phoenix`, "
        "or downgrade `pydantic` to a version that supports the `json_encoders` config setting."
    )
36
+
37
+
38
class V1RoutesBaseModel(BaseModel):
    # Common base model for the v1 REST routes' request and response bodies.
    model_config = ConfigDict(
        # Emit datetimes via `datetime_encoder` (i.e. `datetime.isoformat`).
        json_encoders={datetime: datetime_encoder},
        # Re-validate fields when they are assigned after construction.
        validate_assignment=True,
        # An empty tuple disables protected-namespace checks, suppressing the
        # warnings about fields starting with `model_` seen on pydantic 2.9.
        protected_namespaces=(),
    )
@@ -0,0 +1,415 @@
1
+ import logging
2
+ from typing import Any, Optional, Union
3
+
4
+ from fastapi import APIRouter, HTTPException, Path, Query
5
+ from pydantic import ValidationError
6
+ from sqlalchemy import select
7
+ from sqlalchemy.sql import Select
8
+ from starlette.requests import Request
9
+ from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
10
+ from strawberry.relay import GlobalID
11
+ from typing_extensions import TypeAlias, assert_never
12
+
13
+ from phoenix.db import models
14
+ from phoenix.db.types.identifier import Identifier
15
+ from phoenix.db.types.model_provider import ModelProvider
16
+ from phoenix.server.api.helpers.prompts.models import (
17
+ PromptInvocationParameters,
18
+ PromptResponseFormat,
19
+ PromptTemplate,
20
+ PromptTemplateFormat,
21
+ PromptTemplateType,
22
+ PromptTools,
23
+ )
24
+ from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
25
+ from phoenix.server.api.routers.v1.utils import ResponseBody, add_errors_to_responses
26
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
27
+ from phoenix.server.api.types.Prompt import Prompt as PromptNodeType
28
+ from phoenix.server.api.types.PromptVersion import PromptVersion as PromptVersionNodeType
29
+ from phoenix.server.bearer_auth import PhoenixUser
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
# Request/response schemas for the v1 prompts REST routes.
# NOTE: these use `#` comments instead of class docstrings on purpose. A
# Pydantic model's docstring becomes its JSON-schema description, so adding
# one would change the generated OpenAPI document.


class PromptData(V1RoutesBaseModel):
    # Client-facing attributes of a prompt.
    name: Identifier
    description: Optional[str] = None
    source_prompt_id: Optional[str] = None


class Prompt(PromptData):
    # A stored prompt; `id` holds its relay global ID as a string.
    id: str


class PromptVersionData(V1RoutesBaseModel):
    # Client-facing attributes of a single prompt version.
    description: Optional[str] = None
    model_provider: ModelProvider
    model_name: str
    template: PromptTemplate
    template_type: PromptTemplateType
    template_format: PromptTemplateFormat
    invocation_parameters: PromptInvocationParameters
    tools: Optional[PromptTools] = None
    response_format: Optional[PromptResponseFormat] = None


class PromptVersion(PromptVersionData):
    # A stored prompt version; `id` holds its relay global ID as a string.
    id: str


# Envelope returned by the single-version GET endpoints.
class GetPromptResponseBody(ResponseBody[PromptVersion]): ...


# Envelope for a page of prompts.
class GetPromptsResponseBody(ResponseBody[list[Prompt]]): ...


# Envelope for a page of prompt versions.
class GetPromptVersionsResponseBody(ResponseBody[list[PromptVersion]]): ...


# Payload for POST /prompts: the prompt plus the version to add to it.
class CreatePromptRequestBody(V1RoutesBaseModel):
    prompt: PromptData
    version: PromptVersionData


# Envelope returned after creating a prompt version.
class CreatePromptResponseBody(ResponseBody[PromptVersion]): ...


router = APIRouter(tags=["prompts"])
82
+
83
+
84
@router.get(
    "/prompts",
    operation_id="getPrompts",
    summary="Get all prompts",
    responses=add_errors_to_responses(
        [
            HTTP_422_UNPROCESSABLE_ENTITY,
        ]
    ),
)
async def get_prompts(
    request: Request,
    cursor: Optional[str] = Query(
        default=None,
        description="Cursor for pagination (base64-encoded prompt ID)",
    ),
    limit: int = Query(
        default=100, description="The max number of prompts to return at a time.", gt=0
    ),
) -> GetPromptsResponseBody:
    """
    List prompts in descending ID order, using keyset pagination on the ID.

    The query fetches ``limit + 1`` rows, so a further page can be detected
    without a separate count. When the extra row is present, its global ID is
    returned as ``next_cursor`` and the row is dropped from this page. Because
    the cursor filter uses ``<=``, that row then opens the next page.

    Raises:
        HTTPException(422): if ``cursor`` is not a valid global ID of a Prompt node.
    """
    async with request.app.state.db() as session:
        query = select(models.Prompt).order_by(models.Prompt.id.desc())

        if cursor:
            try:
                # Check the node type as well as the encoding. Otherwise a
                # cursor minted for another node type (e.g. a prompt version)
                # would be silently reinterpreted as a prompt ID.
                cursor_id = from_global_id_with_expected_type(
                    GlobalID.from_id(cursor),
                    PromptNodeType.__name__,
                )
            except ValueError:
                raise HTTPException(
                    detail=f"Invalid cursor format: {cursor}",
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                )
            query = query.filter(models.Prompt.id <= cursor_id)

        # One extra row tells us whether another page exists.
        query = query.limit(limit + 1)
        result = await session.execute(query)
        orm_prompts = result.scalars().all()

        if not orm_prompts:
            return GetPromptsResponseBody(next_cursor=None, data=[])

        next_cursor = None
        if len(orm_prompts) == limit + 1:
            last_prompt = orm_prompts[-1]
            next_cursor = str(GlobalID(PromptNodeType.__name__, str(last_prompt.id)))
            orm_prompts = orm_prompts[:-1]

        prompts = [_prompt_from_orm_prompt(orm_prompt) for orm_prompt in orm_prompts]
        return GetPromptsResponseBody(next_cursor=next_cursor, data=prompts)
132
+
133
+
134
@router.get(
    "/prompts/{prompt_identifier}/versions",
    operation_id="listPromptVersions",
    summary="List all prompt versions for a given prompt",
    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
    response_model_by_alias=True,
    response_model_exclude_defaults=True,
    response_model_exclude_unset=True,
)
async def list_prompt_versions(
    request: Request,
    prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
    cursor: Optional[str] = Query(
        default=None,
        description="Cursor for pagination (base64-encoded promptVersion ID)",
    ),
    limit: int = Query(
        default=100, description="The max number of prompt versions to return at a time.", gt=0
    ),
) -> GetPromptVersionsResponseBody:
    """
    List one prompt's versions in descending ID order, with keyset pagination.

    ``prompt_identifier`` may be the prompt's name or its global ID. An unknown
    prompt yields an empty page rather than a 404. The query fetches
    ``limit + 1`` rows; when the extra row is present, its global ID is returned
    as ``next_cursor`` and, because the cursor filter uses ``<=``, that row
    opens the next page.

    Raises:
        HTTPException(422): if ``prompt_identifier`` is malformed, or ``cursor``
            is not a valid global ID of a PromptVersion node.
    """
    query = select(models.PromptVersion)
    query = _filter_by_prompt_identifier(query.join(models.Prompt), prompt_identifier)
    query = query.order_by(models.PromptVersion.id.desc())

    async with request.app.state.db() as session:
        if cursor:
            try:
                # Check the node type as well as the encoding, so that a cursor
                # for a different node type is rejected rather than misread.
                cursor_id = from_global_id_with_expected_type(
                    GlobalID.from_id(cursor),
                    PromptVersionNodeType.__name__,
                )
            except ValueError:
                raise HTTPException(
                    detail=f"Invalid cursor format: {cursor}",
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                )
            query = query.filter(models.PromptVersion.id <= cursor_id)

        # One extra row tells us whether another page exists.
        query = query.limit(limit + 1)
        result = await session.execute(query)
        orm_versions = result.scalars().all()

        if not orm_versions:
            return GetPromptVersionsResponseBody(next_cursor=None, data=[])

        next_cursor = None
        if len(orm_versions) == limit + 1:
            last_version = orm_versions[-1]
            next_cursor = str(GlobalID(PromptVersionNodeType.__name__, str(last_version.id)))
            orm_versions = orm_versions[:-1]

        versions = [_prompt_version_from_orm_version(orm_version) for orm_version in orm_versions]
        return GetPromptVersionsResponseBody(next_cursor=next_cursor, data=versions)
184
+
185
+
186
@router.get(
    "/prompt_versions/{prompt_version_id}",
    operation_id="getPromptVersionByPromptVersionId",
    summary="Get prompt by prompt version ID",
    responses=add_errors_to_responses([HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY]),
    response_model_by_alias=True,
    response_model_exclude_defaults=True,
    response_model_exclude_unset=True,
)
async def get_prompt_version_by_prompt_version_id(
    request: Request,
    prompt_version_id: str = Path(description="The ID of the prompt version."),
) -> GetPromptResponseBody:
    """
    Fetch a single prompt version by its global ID.

    Raises:
        HTTPException(422): if ``prompt_version_id`` is not a PromptVersion global ID.
        HTTPException(404): if no prompt version has that ID.
    """
    try:
        global_id = GlobalID.from_id(prompt_version_id)
        version_rowid = from_global_id_with_expected_type(
            global_id, PromptVersionNodeType.__name__
        )
    except ValueError:
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt version ID")
    async with request.app.state.db() as session:
        orm_version = await session.get(models.PromptVersion, version_rowid)
        if orm_version is None:
            raise HTTPException(HTTP_404_NOT_FOUND)
        payload = _prompt_version_from_orm_version(orm_version)
    return GetPromptResponseBody(data=payload)
217
+
218
+
219
@router.get(
    "/prompts/{prompt_identifier}/tags/{tag_name}",
    operation_id="getPromptVersionByTagName",
    summary="Get prompt by tag name",
    responses=add_errors_to_responses([HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY]),
    response_model_by_alias=True,
    response_model_exclude_unset=True,
    response_model_exclude_defaults=True,
)
async def get_prompt_version_by_tag_name(
    request: Request,
    prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
    tag_name: str = Path(description="The tag of the prompt version"),
) -> GetPromptResponseBody:
    """
    Fetch the version of a prompt that carries the tag ``tag_name``.

    Raises:
        HTTPException(422): if ``tag_name`` or ``prompt_identifier`` is invalid.
        HTTPException(404): if no matching tagged version exists.
    """
    try:
        tag_identifier = Identifier.model_validate(tag_name)
    except ValidationError:
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid tag name")
    tagged = (
        select(models.PromptVersion)
        .join_from(models.PromptVersion, models.PromptVersionTag)
        .where(models.PromptVersionTag.name == tag_identifier)
    )
    # Parsing the prompt identifier may raise a 422 before touching the DB.
    scoped = _filter_by_prompt_identifier(tagged.join(models.Prompt), prompt_identifier)
    async with request.app.state.db() as session:
        orm_version: Optional[models.PromptVersion] = await session.scalar(scoped)
        if orm_version is None:
            raise HTTPException(HTTP_404_NOT_FOUND)
        payload = _prompt_version_from_orm_version(orm_version)
    return GetPromptResponseBody(data=payload)
254
+
255
+
256
@router.get(
    "/prompts/{prompt_identifier}/latest",
    operation_id="getPromptVersionLatest",
    summary="Get the latest prompt version",
    responses=add_errors_to_responses([HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY]),
    response_model_by_alias=True,
    response_model_exclude_defaults=True,
    response_model_exclude_unset=True,
)
async def get_prompt_version_by_latest(
    request: Request,
    prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
) -> GetPromptResponseBody:
    """
    Fetch the prompt's version with the highest ID.

    Raises:
        HTTPException(422): if ``prompt_identifier`` is malformed.
        HTTPException(404): if the prompt has no versions or does not exist.
    """
    newest_first = select(models.PromptVersion).order_by(models.PromptVersion.id.desc())
    newest_for_prompt = _filter_by_prompt_identifier(
        newest_first.join(models.Prompt), prompt_identifier
    ).limit(1)
    async with request.app.state.db() as session:
        orm_version: Optional[models.PromptVersion] = await session.scalar(newest_for_prompt)
        if orm_version is None:
            raise HTTPException(HTTP_404_NOT_FOUND)
        payload = _prompt_version_from_orm_version(orm_version)
    return GetPromptResponseBody(data=payload)
282
+
283
+
284
@router.post(
    "/prompts",
    operation_id="postPromptVersion",
    summary="Create a prompt version",
    responses=add_errors_to_responses(
        [
            HTTP_422_UNPROCESSABLE_ENTITY,
        ]
    ),
    response_model_by_alias=True,
    response_model_exclude_defaults=True,
    response_model_exclude_unset=True,
)
async def create_prompt(
    request: Request,
    request_body: CreatePromptRequestBody,
) -> CreatePromptResponseBody:
    """
    Add a new version to a prompt, creating the prompt first if needed.

    If no prompt named ``request_body.prompt.name`` exists, one is created with
    the supplied description. Otherwise the version is attached to the existing
    prompt, and that prompt's description is left unchanged. When
    authentication is enabled, the version records the requesting user's ID.

    Raises:
        HTTPException(422): if the template is not a chat template, or the
            prompt name is not a valid identifier.
    """
    if (
        request_body.version.template.type.lower() != "chat"
        or request_body.version.template_type != PromptTemplateType.CHAT
    ):
        raise HTTPException(
            HTTP_422_UNPROCESSABLE_ENTITY,
            "Only CHAT template type is supported for prompts",
        )
    prompt = request_body.prompt
    try:
        name = Identifier.model_validate(prompt.name)
    except ValidationError as e:
        raise HTTPException(
            HTTP_422_UNPROCESSABLE_ENTITY,
            "Invalid name identifier for prompt: " + e.errors()[0]["msg"],
        )
    version = request_body.version
    user_id: Optional[int] = None
    if request.app.state.authentication_enabled:
        assert isinstance(user := request.user, PhoenixUser)
        user_id = int(user.identity)
    async with request.app.state.db() as session:
        prompt_id = await session.scalar(select(models.Prompt.id).filter_by(name=name))
        if prompt_id is None:
            prompt_orm = models.Prompt(
                name=name,
                description=prompt.description,
            )
            session.add(prompt_orm)
            # Flush so the database assigns the new prompt's primary key.
            await session.flush()
            prompt_id = prompt_orm.id
        version_orm = models.PromptVersion(
            user_id=user_id,
            prompt_id=prompt_id,
            description=version.description,
            model_provider=version.model_provider,
            model_name=version.model_name,
            template_type=version.template_type,
            template_format=version.template_format,
            template=version.template,
            invocation_parameters=version.invocation_parameters,
            tools=version.tools,
            response_format=version.response_format,
        )
        session.add(version_orm)
        # The primary key is assigned only on flush. Without this, `version_orm.id`
        # would still be None and the response would carry a global ID built
        # from the string "None".
        await session.flush()
        data = _prompt_version_from_orm_version(version_orm)
    return CreatePromptResponseBody(data=data)
347
+
348
+
349
+ class _PromptId(int): ...
350
+
351
+
352
# Result of parsing a `prompt_identifier` path parameter: either a numeric
# prompt ID (`_PromptId`) or a prompt name (`Identifier`).
_PromptIdentifier: TypeAlias = Union[_PromptId, Identifier]
353
+
354
+
355
def _parse_prompt_identifier(
    prompt_identifier: str,
) -> _PromptIdentifier:
    """
    Interpret a path parameter as either a Prompt global ID or a prompt name.

    A valid Prompt global ID yields a ``_PromptId``. Anything else must validate
    as an ``Identifier`` and is returned as a name.

    Raises:
        HTTPException(422): if the value is empty, or is neither a Prompt global
            ID nor a valid identifier.
    """
    if not prompt_identifier:
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt identifier")
    try:
        global_id = GlobalID.from_id(prompt_identifier)
        numeric_id = from_global_id_with_expected_type(global_id, PromptNodeType.__name__)
    except ValueError:
        pass  # Not a Prompt global ID; fall back to treating it as a name.
    else:
        return _PromptId(numeric_id)
    try:
        return Identifier.model_validate(prompt_identifier)
    except ValidationError:
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt name")
371
+
372
+
373
def _filter_by_prompt_identifier(
    stmt: Select[tuple[models.PromptVersion]],
    prompt_identifier: str,
) -> Any:
    """
    Restrict ``stmt`` (which must already join ``models.Prompt``) to one prompt.

    The identifier is matched against ``Prompt.id`` when it parses as a Prompt
    global ID, and against ``Prompt.name`` otherwise.

    Raises:
        HTTPException(422): propagated from ``_parse_prompt_identifier``.
    """
    parsed = _parse_prompt_identifier(prompt_identifier)
    if isinstance(parsed, _PromptId):
        condition = models.Prompt.id == int(parsed)
    elif isinstance(parsed, Identifier):
        condition = models.Prompt.name == parsed
    else:
        assert_never(parsed)
    return stmt.where(condition)
383
+
384
+
385
def _prompt_version_from_orm_version(
    prompt_version: models.PromptVersion,
) -> PromptVersion:
    """
    Convert an ORM prompt-version row into the ``PromptVersion`` API model.

    The row ID is exposed as a relay global ID. The stored template type and
    format are coerced into their enum types, and a falsy description is
    rendered as an empty string.
    """
    template_type = PromptTemplateType(prompt_version.template_type)
    template_format = PromptTemplateFormat(prompt_version.template_format)
    global_id = GlobalID(PromptVersionNodeType.__name__, str(prompt_version.id))
    return PromptVersion(
        id=str(global_id),
        description=prompt_version.description or "",
        model_provider=prompt_version.model_provider,
        model_name=prompt_version.model_name,
        template=prompt_version.template,
        template_type=template_type,
        template_format=template_format,
        invocation_parameters=prompt_version.invocation_parameters,
        tools=prompt_version.tools,
        response_format=prompt_version.response_format,
    )
402
+
403
+
404
def _prompt_from_orm_prompt(orm_prompt: models.Prompt) -> Prompt:
    """
    Convert an ORM prompt row into the ``Prompt`` API model.

    Both the prompt's own ID and its optional ``source_prompt_id`` are exposed as
    relay global IDs of the Prompt node type.
    """
    # Compare against None explicitly: a truthiness test would treat a
    # (legal, if unusual) foreign-key value of 0 as "no source prompt".
    source_prompt_id = (
        str(GlobalID(PromptNodeType.__name__, str(orm_prompt.source_prompt_id)))
        if orm_prompt.source_prompt_id is not None
        else None
    )
    return Prompt(
        id=str(GlobalID(PromptNodeType.__name__, str(orm_prompt.id))),
        source_prompt_id=source_prompt_id,
        name=orm_prompt.name,
        description=orm_prompt.description,
    )
@@ -24,7 +24,7 @@ from phoenix.server.dml_event import SpanAnnotationInsertEvent
24
24
  from phoenix.trace.dsl import SpanQuery as SpanQuery_
25
25
  from phoenix.utilities.json import encode_df_as_json_string
26
26
 
27
- from .pydantic_compat import V1RoutesBaseModel
27
+ from .models import V1RoutesBaseModel
28
28
  from .utils import RequestBody, ResponseBody, add_errors_to_responses
29
29
 
30
30
  DEFAULT_SPAN_LIMIT = 1000
@@ -30,7 +30,7 @@ from phoenix.server.dml_event import TraceAnnotationInsertEvent
30
30
  from phoenix.trace.otel import decode_otlp_span
31
31
  from phoenix.utilities.project import get_project_name
32
32
 
33
- from .pydantic_compat import V1RoutesBaseModel
33
+ from .models import V1RoutesBaseModel
34
34
  from .utils import RequestBody, ResponseBody, add_errors_to_responses
35
35
 
36
36
  router = APIRouter(tags=["traces"])
@@ -2,7 +2,7 @@ from typing import Any, Generic, Optional, TypedDict, TypeVar, Union
2
2
 
3
3
  from typing_extensions import TypeAlias, assert_never
4
4
 
5
- from .pydantic_compat import V1RoutesBaseModel
5
+ from .models import V1RoutesBaseModel
6
6
 
7
7
  StatusCode: TypeAlias = int
8
8
  DataType = TypeVar("DataType")
@@ -41,6 +41,7 @@ from phoenix.server.api.helpers.playground_spans import (
41
41
  get_db_trace,
42
42
  streaming_llm_span,
43
43
  )
44
+ from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
44
45
  from phoenix.server.api.input_types.ChatCompletionInput import (
45
46
  ChatCompletionInput,
46
47
  ChatCompletionOverDatasetInput,
@@ -59,7 +60,6 @@ from phoenix.server.api.types.Experiment import to_gql_experiment
59
60
  from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
60
61
  from phoenix.server.api.types.node import from_global_id_with_expected_type
61
62
  from phoenix.server.api.types.Span import to_gql_span
62
- from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
63
63
  from phoenix.server.dml_event import SpanInsertEvent
64
64
  from phoenix.server.types import DbSessionFactory
65
65
  from phoenix.utilities.template_formatters import (
@@ -124,7 +124,7 @@ class Subscription:
124
124
  messages = list(
125
125
  _formatted_messages(
126
126
  messages=messages,
127
- template_language=template_options.language,
127
+ template_format=template_options.format,
128
128
  template_variables=template_options.variables,
129
129
  )
130
130
  )
@@ -198,9 +198,7 @@ class Subscription:
198
198
  )
199
199
  async with info.context.db() as session:
200
200
  if (
201
- dataset := await session.scalar(
202
- select(models.Dataset).where(models.Dataset.id == dataset_id)
203
- )
201
+ await session.scalar(select(models.Dataset).where(models.Dataset.id == dataset_id))
204
202
  ) is None:
205
203
  raise NotFound(f"Could not find dataset with ID {dataset_id}")
206
204
  if version_id is None:
@@ -274,9 +272,9 @@ class Subscription:
274
272
  experiment = models.Experiment(
275
273
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
276
274
  dataset_version_id=resolved_version_id,
277
- name=input.experiment_name or _default_playground_experiment_name(),
278
- description=input.experiment_description
279
- or _default_playground_experiment_description(dataset_name=dataset.name),
275
+ name=input.experiment_name
276
+ or _default_playground_experiment_name(input.prompt_name),
277
+ description=input.experiment_description,
280
278
  repetitions=1,
281
279
  metadata_=input.experiment_metadata or dict(),
282
280
  project_name=PLAYGROUND_PROJECT_NAME,
@@ -394,7 +392,7 @@ async def _stream_chat_completion_over_dataset_example(
394
392
  messages = list(
395
393
  _formatted_messages(
396
394
  messages=messages,
397
- template_language=input.template_language,
395
+ template_format=input.template_format,
398
396
  template_variables=revision.input,
399
397
  )
400
398
  )
@@ -472,7 +470,7 @@ def _is_result_payloads_stream(
472
470
  Checks if the given generator was instantiated from
473
471
  `_chat_completion_result_payloads`
474
472
  """
475
- return stream.ag_code == _chat_completion_result_payloads.__code__
473
+ return stream.ag_code == _chat_completion_result_payloads.__code__ # type: ignore
476
474
 
477
475
 
478
476
  def _create_task_with_timeout(
@@ -534,13 +532,13 @@ async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
534
532
  def _formatted_messages(
535
533
  *,
536
534
  messages: Iterable[ChatCompletionMessage],
537
- template_language: TemplateLanguage,
535
+ template_format: PromptTemplateFormat,
538
536
  template_variables: Mapping[str, Any],
539
537
  ) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
540
538
  """
541
539
  Formats the messages using the given template options.
542
540
  """
543
- template_formatter = _template_formatter(template_language=template_language)
541
+ template_formatter = _template_formatter(template_format=template_format)
544
542
  (
545
543
  roles,
546
544
  templates,
@@ -555,25 +553,24 @@ def _formatted_messages(
555
553
  return formatted_messages
556
554
 
557
555
 
558
- def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
556
def _template_formatter(template_format: PromptTemplateFormat) -> TemplateFormatter:
    """
    Return a new formatter instance for ``template_format``.

    ``MUSTACHE`` maps to ``MustacheTemplateFormatter``, ``F_STRING`` to
    ``FStringTemplateFormatter`` and ``NONE`` to ``NoOpFormatter``. Any other
    value reaches ``assert_never``.
    """
    if template_format is PromptTemplateFormat.NONE:
        return NoOpFormatter()
    if template_format is PromptTemplateFormat.F_STRING:
        return FStringTemplateFormatter()
    if template_format is PromptTemplateFormat.MUSTACHE:
        return MustacheTemplateFormatter()
    assert_never(template_format)
573
567
 
574
568
 
575
- def _default_playground_experiment_description(dataset_name: str) -> str:
576
- return f'Playground experiment for dataset "{dataset_name}"'
569
+ def _default_playground_experiment_name(prompt_name: Optional[str] = None) -> str:
570
+ name = "playground-experiment"
571
+ if prompt_name:
572
+ name = f"{name} prompt:{prompt_name}"
573
+ return name
577
574
 
578
575
 
579
576
  LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES