arize-phoenix 8.21.0__py3-none-any.whl → 8.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See the release details for more information.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arize-phoenix
3
- Version: 8.21.0
3
+ Version: 8.22.0
4
4
  Summary: AI Observability and Evaluation
5
5
  Project-URL: Documentation, https://docs.arize.com/phoenix/
6
6
  Project-URL: Issues, https://github.com/Arize-ai/phoenix/issues
@@ -6,7 +6,7 @@ phoenix/exceptions.py,sha256=n2L2KKuecrdflB9MsCdAYCiSEvGJptIsfRkXMoJle7A,169
6
6
  phoenix/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
7
7
  phoenix/services.py,sha256=kpW1WL0kiB8XJsO6XycvZVJ-lBkNoenhQ7atCvBoSe8,5365
8
8
  phoenix/settings.py,sha256=x87BX7hWGQQZbrW_vrYqFR_izCGfO9gFc--JXUG4Tdk,754
9
- phoenix/version.py,sha256=IcFew5OfUbzsX2QRcX00Ngf0cInzeuF_IhyKYRXmTmI,23
9
+ phoenix/version.py,sha256=87-3KQINFNZAUcsrNLppOae3DX76fIASs9wL99nLugI,23
10
10
  phoenix/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  phoenix/core/embedding_dimension.py,sha256=zKGbcvwOXgLf-yrJBpQyKtd-LEOPRKHnUToyAU8Owis,87
12
12
  phoenix/core/model.py,sha256=qBFraOtmwCCnWJltKNP18DDG0mULXigytlFsa6YOz6k,4837
@@ -27,7 +27,7 @@ phoenix/db/insertion/constants.py,sha256=8wifm7X-1XvroZ__R2Gc96NsgLhTDn0zXl4lehl
27
27
  phoenix/db/insertion/dataset.py,sha256=I9OC1ouVx7m6BH_c8hvcxW1dWGRAtpvXee29yBTuFkg,7136
28
28
  phoenix/db/insertion/document_annotation.py,sha256=vnszF9L0qHjUY-eWgBVRG5OTpz-ECs1BoCfy66ynzjo,5997
29
29
  phoenix/db/insertion/evaluation.py,sha256=SoI85N3MYUSeNgjKa5WzFw14OfNjNTjExv-2m3sxaR8,6371
30
- phoenix/db/insertion/helpers.py,sha256=-DyRcxzJnjSJFhscPoqiNiQn8fBvGqI8IcNJEu-79Vw,3455
30
+ phoenix/db/insertion/helpers.py,sha256=PlYtLI6SfmQwoLZ8UYrkAffygjMwiuZz1FPs6g7Vchs,3437
31
31
  phoenix/db/insertion/span.py,sha256=02hpGo5ZY6N-n6Z-far-AC_yVAloXpJyt-CIHOeor0k,8126
32
32
  phoenix/db/insertion/span_annotation.py,sha256=EtzcjS8GR2rDM_cxIQloGc2SZqfdrACyAPwvyFBa2Ac,5273
33
33
  phoenix/db/insertion/trace_annotation.py,sha256=6yzWuU0Fh-mC-rJJ96rH0IYTJg1EQFnj-GZhUwolUKI,5338
@@ -44,7 +44,7 @@ phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py,sha256=rq-bwg0g
44
44
  phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py,sha256=fkpmh5PgMZJiZpvLbZIaqlI2cucVpVbbNYpQ-Tznil8,5180
45
45
  phoenix/db/migrations/versions/cf03bd6bae1d_init.py,sha256=ZNQzTUyb3p9Bkq7rjd5MRxGQV8W08zpY8E3GlzSJ2cM,8630
46
46
  phoenix/db/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
- phoenix/db/types/identifier.py,sha256=TSm3o4CDnd7vNyQj3c8CKtrx04ZNaDFDbYLFHTi0_7I,181
47
+ phoenix/db/types/identifier.py,sha256=Opr3_1di6e5ncrBDn30WfBSr-jN_VGBnkkA4BMuSoyc,244
48
48
  phoenix/db/types/model_provider.py,sha256=96UMeqiy5X9PmYMOWA6dZAmI_BSV3yVxt9HEVYGe5Ns,157
49
49
  phoenix/experiments/__init__.py,sha256=6JGwgUd7xCbGpuHqYZlsmErmYvVgv7N_j43bn3dUqsk,123
50
50
  phoenix/experiments/functions.py,sha256=QOwYNMO1qxG2bah4ZjJKMk7agFkZRWGvwyazLcwRCys,32573
@@ -99,7 +99,7 @@ phoenix/server/api/auth.py,sha256=nywpmfMI1trZTbZRD3oBj4kFjzg_vnxDljcM431T1eY,12
99
99
  phoenix/server/api/context.py,sha256=OopBkMnY48TzulTtfuay3MXmhbFIWPtPoAsbhCW-Pl0,6459
100
100
  phoenix/server/api/exceptions.py,sha256=TA0JuY2YRnj35qGuMSQ8d0ToHum9gWm9W--3fSKHrX0,1171
101
101
  phoenix/server/api/interceptor.py,sha256=ykDnoC_apUd-llVli3m1CW18kNSIgjz2qZ6m5JmPDu8,1294
102
- phoenix/server/api/queries.py,sha256=e_zt4VKQj3dnLIrVEFbxIOhFo8bup_vVN9xAqVW009w,35726
102
+ phoenix/server/api/queries.py,sha256=Xd1K-6bu_DuUxczyXzf0M6EiRL7Ahnx47QVqPmo2-58,36167
103
103
  phoenix/server/api/schema.py,sha256=fcs36xQwFF_Qe41_5cWR8wYpDvOrnbcyTeo5WNMbDsA,1702
104
104
  phoenix/server/api/subscriptions.py,sha256=DSIgQF6lQqkbc7D0AaI5R4g3hIHbU04H5Y2UIpwmpy0,22989
105
105
  phoenix/server/api/utils.py,sha256=quCBRcusc6PUq9tJq7M8PgwFZp7nXgVAxtbw8feribY,833
@@ -217,7 +217,7 @@ phoenix/server/api/routers/v1/experiment_evaluations.py,sha256=vx4CKlE84sAL1vtPi
217
217
  phoenix/server/api/routers/v1/experiment_runs.py,sha256=bInuasRv7ogiYf8fq-LwpJ5tptmMQsBNDlJAqwdymko,6378
218
218
  phoenix/server/api/routers/v1/experiments.py,sha256=V9_sxqLTE1MKGFu9H3FEdGKr70lYMbGZx813MGaavfQ,20430
219
219
  phoenix/server/api/routers/v1/models.py,sha256=r0nM2kFJ3mxDqgc5vFr1cjNuyOPs3RIKE_DS2VMdF48,1749
220
- phoenix/server/api/routers/v1/prompts.py,sha256=ytK8HnOZNxUMDtC7XAFxzaTSM9DMMua13vWsqqd4PAw,14986
220
+ phoenix/server/api/routers/v1/prompts.py,sha256=aBOUBwLDzZDIzJQkxJcR8ZKnakNJOLMwzsLKINSs1mA,26545
221
221
  phoenix/server/api/routers/v1/spans.py,sha256=uoU_bwIgz86fuvPjP5sX8goDyuCcnsTig-x3f17p60U,9625
222
222
  phoenix/server/api/routers/v1/traces.py,sha256=hSv35QIB4mwFgp53rOpz3zWIiSwbZzQnjafD790QuJU,7908
223
223
  phoenix/server/api/routers/v1/utils.py,sha256=SoRl0Dc8By15ZckhNcXg2QRrqYjMvgTjVcqrZ6MwVmo,3065
@@ -273,7 +273,7 @@ phoenix/server/api/types/Prompt.py,sha256=ccP4eq1e38xbF0afclGWLOuDpBVpNbJ3AOSRCl
273
273
  phoenix/server/api/types/PromptLabel.py,sha256=g3IDSPYRZwb0qpMAk93R6J96jgYULUYGOciTnpeh3sI,1321
274
274
  phoenix/server/api/types/PromptResponse.py,sha256=Q8HKtpp8GpUOcxPCzZpkkokidDd6u0aZOv_SuPZZd5Q,630
275
275
  phoenix/server/api/types/PromptVersion.py,sha256=hLPwOwj1h0gH4fqpLS-xwcC2TQYLCmvaCiEkyXvTS1g,5494
276
- phoenix/server/api/types/PromptVersionTag.py,sha256=K2pbK-pIjzCTfwcwPpyk91m52XMhMlH0dLcPnaxyb4M,836
276
+ phoenix/server/api/types/PromptVersionTag.py,sha256=dSumwkBv2P-yDLxE0tnu8O87M2ZcuhLrmnBb4CkuyC4,1389
277
277
  phoenix/server/api/types/PromptVersionTemplate.py,sha256=jlHJS079OhI8IHZoWQn0aIz5JnHOuYc8iD8UkytyzEQ,4627
278
278
  phoenix/server/api/types/ResponseFormat.py,sha256=ymBsPWGViiaEo4vyw50ht1K6BveKcOZ2YNce50PAoew,182
279
279
  phoenix/server/api/types/Retrieval.py,sha256=OhMK2ncjoyp5h1yjKhjlKpoTbQrMHuxmgSFw-AO1rWw,285
@@ -364,9 +364,9 @@ phoenix/utilities/project.py,sha256=auVpARXkDb-JgeX5f2aStyFIkeKvGwN9l7qrFeJMVxI,
364
364
  phoenix/utilities/re.py,sha256=6YyUWIkv0zc2SigsxfOWIHzdpjKA_TZo2iqKq7zJKvw,2081
365
365
  phoenix/utilities/span_store.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
366
366
  phoenix/utilities/template_formatters.py,sha256=gh9PJD6WEGw7TEYXfSst1UR4pWWwmjxMLrDVQ_CkpkQ,2779
367
- arize_phoenix-8.21.0.dist-info/METADATA,sha256=0DkKUK_8BHpdYpiMIfmCLhA8bRmpQeHmmh0PgqoMmUU,21378
368
- arize_phoenix-8.21.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
369
- arize_phoenix-8.21.0.dist-info/entry_points.txt,sha256=Pgpn8Upxx9P8z8joPXZWl2LlnAlGc3gcQoVchb06X1Q,94
370
- arize_phoenix-8.21.0.dist-info/licenses/IP_NOTICE,sha256=JBqyyCYYxGDfzQ0TtsQgjts41IJoa-hiwDrBjCb9gHM,469
371
- arize_phoenix-8.21.0.dist-info/licenses/LICENSE,sha256=HFkW9REuMOkvKRACuwLPT0hRydHb3zNg-fdFt94td18,3794
372
- arize_phoenix-8.21.0.dist-info/RECORD,,
367
+ arize_phoenix-8.22.0.dist-info/METADATA,sha256=ahp8w8u1QgKzZCz7LglKC0DXLYJS-yZjFH7igLe2yO0,21378
368
+ arize_phoenix-8.22.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
369
+ arize_phoenix-8.22.0.dist-info/entry_points.txt,sha256=Pgpn8Upxx9P8z8joPXZWl2LlnAlGc3gcQoVchb06X1Q,94
370
+ arize_phoenix-8.22.0.dist-info/licenses/IP_NOTICE,sha256=JBqyyCYYxGDfzQ0TtsQgjts41IJoa-hiwDrBjCb9gHM,469
371
+ arize_phoenix-8.22.0.dist-info/licenses/LICENSE,sha256=HFkW9REuMOkvKRACuwLPT0hRydHb3zNg-fdFt94td18,3794
372
+ arize_phoenix-8.22.0.dist-info/RECORD,,
@@ -78,7 +78,7 @@ def _clean(
78
78
  kv: Iterable[tuple[str, KeyedColumnElement[Any]]],
79
79
  ) -> Iterator[tuple[str, KeyedColumnElement[Any]]]:
80
80
  for k, v in kv:
81
- if v.primary_key or v.foreign_keys or k == "created_at":
81
+ if v.primary_key or k == "created_at":
82
82
  continue
83
83
  if k == "metadata_":
84
84
  yield "metadata", v
@@ -5,3 +5,6 @@ from pydantic import Field, RootModel
5
5
 
6
6
class Identifier(RootModel[str]):
    # A lowercase slug: starts and ends with [a-z0-9]; "_" and "-" may appear only in between,
    # so a single character is valid but a leading/trailing "_" or "-" is not.
    root: Annotated[str, Field(pattern=r"^[a-z0-9]([_a-z0-9-]*[a-z0-9])?$")]

    def __hash__(self) -> int:
        # Hash by the wrapped string so identifiers holding the same value hash alike and can
        # be used as set members or dict keys.
        value: str = self.root
        return value.__hash__()
@@ -69,6 +69,7 @@ from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_proje
69
69
  from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
70
70
  from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
71
71
  from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
72
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
72
73
  from phoenix.server.api.types.SortDir import SortDir
73
74
  from phoenix.server.api.types.Span import Span
74
75
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
@@ -596,6 +597,11 @@ class Query:
596
597
  ):
597
598
  raise NotFound(f"Unknown prompt label: {id}")
598
599
  return to_gql_prompt_label(prompt_label)
600
+ elif type_name == PromptVersionTag.__name__:
601
+ async with info.context.db() as session:
602
+ if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
603
+ raise NotFound(f"Unknown prompt version tag: {id}")
604
+ return to_gql_prompt_version_tag(prompt_version_tag)
599
605
  raise NotFound(f"Unknown node type: {type_name}")
600
606
 
601
607
  @strawberry.field
@@ -6,11 +6,13 @@ from pydantic import ValidationError, model_validator
6
6
  from sqlalchemy import select
7
7
  from sqlalchemy.sql import Select
8
8
  from starlette.requests import Request
9
- from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
9
+ from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
10
10
  from strawberry.relay import GlobalID
11
11
  from typing_extensions import Self, TypeAlias, assert_never
12
12
 
13
13
  from phoenix.db import models
14
+ from phoenix.db.helpers import SupportedSQLDialect
15
+ from phoenix.db.insertion.helpers import OnConflict, insert_on_conflict
14
16
  from phoenix.db.types.identifier import Identifier
15
17
  from phoenix.db.types.model_provider import ModelProvider
16
18
  from phoenix.server.api.helpers.prompts.models import (
@@ -22,10 +24,15 @@ from phoenix.server.api.helpers.prompts.models import (
22
24
  PromptTools,
23
25
  )
24
26
  from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
25
- from phoenix.server.api.routers.v1.utils import ResponseBody, add_errors_to_responses
27
+ from phoenix.server.api.routers.v1.utils import (
28
+ PaginatedResponseBody,
29
+ ResponseBody,
30
+ add_errors_to_responses,
31
+ )
26
32
  from phoenix.server.api.types.node import from_global_id_with_expected_type
27
33
  from phoenix.server.api.types.Prompt import Prompt as PromptNodeType
28
34
  from phoenix.server.api.types.PromptVersion import PromptVersion as PromptVersionNodeType
35
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag as PromptVersionTagNodeType
29
36
  from phoenix.server.bearer_auth import PhoenixUser
30
37
 
31
38
  logger = logging.getLogger(__name__)
@@ -73,11 +80,11 @@ class GetPromptResponseBody(ResponseBody[PromptVersion]):
73
80
  pass
74
81
 
75
82
 
76
- class GetPromptsResponseBody(ResponseBody[list[Prompt]]):
83
+ class GetPromptsResponseBody(PaginatedResponseBody[Prompt]):
77
84
  pass
78
85
 
79
86
 
80
- class GetPromptVersionsResponseBody(ResponseBody[list[PromptVersion]]):
87
+ class GetPromptVersionsResponseBody(PaginatedResponseBody[PromptVersion]):
81
88
  pass
82
89
 
83
90
 
@@ -96,7 +103,10 @@ router = APIRouter(tags=["prompts"])
96
103
  @router.get(
97
104
  "/prompts",
98
105
  operation_id="getPrompts",
99
- summary="Get all prompts",
106
+ summary="List all prompts",
107
+ description="Retrieve a paginated list of all prompts in the system. A prompt can have "
108
+ "multiple versions.",
109
+ response_description="A list of prompts with pagination information",
100
110
  responses=add_errors_to_responses(
101
111
  [
102
112
  HTTP_422_UNPROCESSABLE_ENTITY,
@@ -113,7 +123,27 @@ async def get_prompts(
113
123
  default=100, description="The max number of prompts to return at a time.", gt=0
114
124
  ),
115
125
  ) -> GetPromptsResponseBody:
126
+ """
127
+ Retrieve a paginated list of all prompts in the system.
128
+
129
+ Args:
130
+ request (Request): The FastAPI request object.
131
+ cursor (Optional[str]): Pagination cursor (base64-encoded prompt ID).
132
+ limit (int): Maximum number of prompts to return per request.
133
+
134
+ Returns:
135
+ GetPromptsResponseBody: Response containing a list of prompts and pagination information.
136
+
137
+ Raises:
138
+ HTTPException: If the cursor format is invalid.
139
+ """
116
140
  async with request.app.state.db() as session:
141
+ # First check if any prompts exist
142
+ if not cursor:
143
+ prompt_exists = await session.scalar(select(models.Prompt.id).limit(1))
144
+ if not prompt_exists:
145
+ return GetPromptsResponseBody(next_cursor=None, data=[])
146
+
117
147
  query = select(models.Prompt).order_by(models.Prompt.id.desc())
118
148
 
119
149
  if cursor:
@@ -146,8 +176,11 @@ async def get_prompts(
146
176
  @router.get(
147
177
  "/prompts/{prompt_identifier}/versions",
148
178
  operation_id="listPromptVersions",
149
- summary="List all prompt versions for a given prompt",
150
- responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
179
+ summary="List prompt versions",
180
+ description="Retrieve all versions of a specific prompt with pagination support. Each prompt "
181
+ "can have multiple versions with different configurations.",
182
+ response_description="A list of prompt versions with pagination information",
183
+ responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY, HTTP_404_NOT_FOUND]),
151
184
  response_model_by_alias=True,
152
185
  response_model_exclude_defaults=True,
153
186
  response_model_exclude_unset=True,
@@ -163,6 +196,23 @@ async def list_prompt_versions(
163
196
  default=100, description="The max number of prompt versions to return at a time.", gt=0
164
197
  ),
165
198
  ) -> GetPromptVersionsResponseBody:
199
+ """
200
+ List all versions of a specific prompt with pagination support.
201
+
202
+ Args:
203
+ request (Request): The FastAPI request object.
204
+ prompt_identifier (str): The identifier of the prompt (name or ID).
205
+ cursor (Optional[str]): Pagination cursor (base64-encoded promptVersion ID).
206
+ limit (int): Maximum number of prompt versions to return per request.
207
+
208
+ Returns:
209
+ GetPromptVersionsResponseBody: Response containing a list of prompt versions and pagination
210
+ information.
211
+
212
+ Raises:
213
+ HTTPException: If the cursor format is invalid, the prompt identifier is invalid,
214
+ or the prompt is not found.
215
+ """
166
216
  query = select(models.PromptVersion)
167
217
  query = _filter_by_prompt_identifier(query.join(models.Prompt), prompt_identifier)
168
218
  query = query.order_by(models.PromptVersion.id.desc())
@@ -198,7 +248,10 @@ async def list_prompt_versions(
198
248
  @router.get(
199
249
  "/prompt_versions/{prompt_version_id}",
200
250
  operation_id="getPromptVersionByPromptVersionId",
201
- summary="Get prompt by prompt version ID",
251
+ summary="Get prompt version by ID",
252
+ description="Retrieve a specific prompt version using its unique identifier. A prompt version "
253
+ "contains the actual template and configuration.",
254
+ response_description="The requested prompt version",
202
255
  responses=add_errors_to_responses(
203
256
  [
204
257
  HTTP_404_NOT_FOUND,
@@ -213,6 +266,19 @@ async def get_prompt_version_by_prompt_version_id(
213
266
  request: Request,
214
267
  prompt_version_id: str = Path(description="The ID of the prompt version."),
215
268
  ) -> GetPromptResponseBody:
269
+ """
270
+ Retrieve a specific prompt version by its ID.
271
+
272
+ Args:
273
+ request (Request): The FastAPI request object.
274
+ prompt_version_id (str): The ID of the prompt version to retrieve.
275
+
276
+ Returns:
277
+ GetPromptResponseBody: Response containing the requested prompt version.
278
+
279
+ Raises:
280
+ HTTPException: If the prompt version ID is invalid or the prompt version is not found.
281
+ """
216
282
  try:
217
283
  id_ = from_global_id_with_expected_type(
218
284
  GlobalID.from_id(prompt_version_id),
@@ -231,7 +297,10 @@ async def get_prompt_version_by_prompt_version_id(
231
297
  @router.get(
232
298
  "/prompts/{prompt_identifier}/tags/{tag_name}",
233
299
  operation_id="getPromptVersionByTagName",
234
- summary="Get prompt by tag name",
300
+ summary="Get prompt version by tag",
301
+ description="Retrieve a specific prompt version using its tag name. Tags are used to identify "
302
+ "specific versions of a prompt.",
303
+ response_description="The prompt version with the specified tag",
235
304
  responses=add_errors_to_responses(
236
305
  [
237
306
  HTTP_404_NOT_FOUND,
@@ -247,6 +316,20 @@ async def get_prompt_version_by_tag_name(
247
316
  prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
248
317
  tag_name: str = Path(description="The tag of the prompt version"),
249
318
  ) -> GetPromptResponseBody:
319
+ """
320
+ Retrieve a specific prompt version by its tag name.
321
+
322
+ Args:
323
+ request (Request): The FastAPI request object.
324
+ prompt_identifier (str): The identifier of the prompt (name or ID).
325
+ tag_name (str): The tag name associated with the prompt version.
326
+
327
+ Returns:
328
+ GetPromptResponseBody: Response containing the prompt version with the specified tag.
329
+
330
+ Raises:
331
+ HTTPException: If the tag name is invalid or the prompt version is not found.
332
+ """
250
333
  try:
251
334
  name = Identifier.model_validate(tag_name)
252
335
  except ValidationError:
@@ -268,7 +351,9 @@ async def get_prompt_version_by_tag_name(
268
351
  @router.get(
269
352
  "/prompts/{prompt_identifier}/latest",
270
353
  operation_id="getPromptVersionLatest",
271
- summary="Get the latest prompt version",
354
+ summary="Get latest prompt version",
355
+ description="Retrieve the most recent version of a specific prompt.",
356
+ response_description="The latest version of the specified prompt",
272
357
  responses=add_errors_to_responses(
273
358
  [
274
359
  HTTP_404_NOT_FOUND,
@@ -283,6 +368,19 @@ async def get_prompt_version_by_latest(
283
368
  request: Request,
284
369
  prompt_identifier: str = Path(description="The identifier of the prompt, i.e. name or ID."),
285
370
  ) -> GetPromptResponseBody:
371
+ """
372
+ Retrieve the latest version of a specific prompt.
373
+
374
+ Args:
375
+ request (Request): The FastAPI request object.
376
+ prompt_identifier (str): The identifier of the prompt (name or ID).
377
+
378
+ Returns:
379
+ GetPromptResponseBody: Response containing the latest prompt version.
380
+
381
+ Raises:
382
+ HTTPException: If the prompt identifier is invalid or no prompt version is found.
383
+ """
286
384
  stmt = select(models.PromptVersion).order_by(models.PromptVersion.id.desc()).limit(1)
287
385
  stmt = _filter_by_prompt_identifier(stmt.join(models.Prompt), prompt_identifier)
288
386
  async with request.app.state.db() as session:
@@ -296,7 +394,9 @@ async def get_prompt_version_by_latest(
296
394
  @router.post(
297
395
  "/prompts",
298
396
  operation_id="postPromptVersion",
299
- summary="Create a prompt version",
397
+ summary="Create a new prompt",
398
+ description="Create a new prompt and its initial version. A prompt can have multiple versions.",
399
+ response_description="The newly created prompt version",
300
400
  responses=add_errors_to_responses(
301
401
  [
302
402
  HTTP_422_UNPROCESSABLE_ENTITY,
@@ -310,6 +410,20 @@ async def create_prompt(
310
410
  request: Request,
311
411
  request_body: CreatePromptRequestBody,
312
412
  ) -> CreatePromptResponseBody:
413
+ """
414
+ Create a new prompt and its initial version.
415
+
416
+ Args:
417
+ request (Request): The FastAPI request object.
418
+ request_body (CreatePromptRequestBody): The request body containing prompt and version data.
419
+
420
+ Returns:
421
+ CreatePromptResponseBody: Response containing the created prompt version.
422
+
423
+ Raises:
424
+ HTTPException: If the template type is not supported, the name identifier is invalid,
425
+ or any other validation error occurs.
426
+ """
313
427
  if (
314
428
  request_body.version.template.type.lower() != "chat"
315
429
  or request_body.version.template_type != PromptTemplateType.CHAT
@@ -358,6 +472,207 @@ async def create_prompt(
358
472
  return CreatePromptResponseBody(data=data)
359
473
 
360
474
 
475
class PromptVersionTagData(V1RoutesBaseModel):
    # Tag payload: `name` must match the Identifier slug pattern; `description` is optional.
    name: Identifier
    description: Optional[str] = None
478
+
479
+
480
class PromptVersionTag(PromptVersionTagData):
    # A stored tag as returned by the API; `id` holds the tag's GlobalID string.
    id: str
482
+
483
+
484
class GetPromptVersionTagsResponseBody(PaginatedResponseBody[PromptVersionTag]):
    # One page of tags for a prompt version, plus the cursor for the following page.
    ...
486
+
487
+
488
@router.get(
    "/prompt_versions/{prompt_version_id}/tags",
    operation_id="getPromptVersionTags",
    summary="List prompt version tags",
    description="Retrieve all tags associated with a specific prompt version. Tags are used to "
    "identify and categorize different versions of a prompt.",
    response_description="A list of tags associated with the prompt version",
    responses=add_errors_to_responses(
        [
            HTTP_404_NOT_FOUND,
            HTTP_422_UNPROCESSABLE_ENTITY,
        ]
    ),
    response_model_by_alias=True,
    response_model_exclude_defaults=True,
    response_model_exclude_unset=True,
)
async def list_prompt_version_tags(
    request: Request,
    prompt_version_id: str = Path(description="The ID of the prompt version."),
    cursor: Optional[str] = Query(
        default=None,
        description="Cursor for pagination (base64-encoded promptVersionTag ID)",
    ),
    limit: int = Query(
        default=100, description="The max number of tags to return at a time.", gt=0
    ),
) -> GetPromptVersionTagsResponseBody:
    """
    Get tags for a specific prompt version.

    Tags are returned newest-first (descending tag ID). The cursor is the GlobalID of the
    first tag of the requested page, and the bound is inclusive.

    Args:
        request (Request): The request object.
        prompt_version_id (str): The ID of the prompt version.
        cursor (Optional[str]): Pagination cursor (base64-encoded promptVersionTag ID).
        limit (int): Maximum number of tags to return per request.

    Returns:
        GetPromptVersionTagsResponseBody: The tags on this page. `next_cursor` is the ID of the
        first tag of the next page, or None when this is the last page.

    Raises:
        HTTPException: 422 if the prompt version ID or the cursor is malformed; 404 if the
            prompt version does not exist.
    """
    try:
        id_ = from_global_id_with_expected_type(
            GlobalID.from_id(prompt_version_id),
            PromptVersionNodeType.__name__,
        )
    except ValueError:
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt version ID")

    join_condition = models.PromptVersionTag.prompt_version_id == models.PromptVersion.id
    if cursor:
        try:
            cursor_id = int(GlobalID.from_id(cursor).node_id)
        except ValueError:
            raise HTTPException(
                detail=f"Invalid cursor format: {cursor}",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        # The cursor bound belongs in the ON clause, not WHERE. Filtering the outer-joined tag
        # column in WHERE would also drop the NULL-padded row that proves the prompt version
        # exists, so "no more tags" would be reported as a 404.
        join_condition = join_condition & (models.PromptVersionTag.id <= cursor_id)

    stmt = (
        select(
            models.PromptVersion.id,
            models.PromptVersionTag.id,
            models.PromptVersionTag.name,
            models.PromptVersionTag.description,
        )
        .outerjoin_from(models.PromptVersion, models.PromptVersionTag, join_condition)
        .where(models.PromptVersion.id == id_)
        .order_by(models.PromptVersionTag.id.desc())
        # Fetch one extra row to learn whether another page exists.
        .limit(limit + 1)
    )

    async with request.app.state.db() as session:
        rows = (await session.execute(stmt)).all()

    # The outer join yields at least one row whenever the prompt version exists.
    if not rows:
        raise HTTPException(HTTP_404_NOT_FOUND, "Prompt version not found")

    # A row with a NULL tag ID is the outer-join padding for a version with no matching tags.
    tag_rows = [
        (tag_id, name, description)
        for _, tag_id, name, description in rows
        if tag_id is not None
    ]
    if not tag_rows:
        return GetPromptVersionTagsResponseBody(next_cursor=None, data=[])

    next_cursor = None
    if len(tag_rows) > limit:
        # The look-ahead row is the first tag of the next page. Because the cursor bound is
        # inclusive, it (not the last tag shown here) must become the cursor, otherwise the
        # last tag of this page would be repeated on the next one.
        next_tag_id = tag_rows[limit][0]
        next_cursor = str(GlobalID(PromptVersionTagNodeType.__name__, str(next_tag_id)))
        tag_rows = tag_rows[:limit]

    data = [
        PromptVersionTag(
            id=str(GlobalID(PromptVersionTagNodeType.__name__, str(tag_id))),
            name=name,
            description=description,
        )
        for tag_id, name, description in tag_rows
    ]

    return GetPromptVersionTagsResponseBody(next_cursor=next_cursor, data=data)
601
+
602
+
603
@router.post(
    "/prompt_versions/{prompt_version_id}/tags",
    operation_id="createPromptVersionTag",
    summary="Add tag to prompt version",
    description="Add a new tag to a specific prompt version. Tags help identify and categorize "
    "different versions of a prompt.",
    response_description="No content returned on successful tag creation",
    status_code=HTTP_204_NO_CONTENT,
    responses=add_errors_to_responses(
        [
            HTTP_404_NOT_FOUND,
            HTTP_422_UNPROCESSABLE_ENTITY,
        ]
    ),
    response_model_by_alias=True,
    response_model_exclude_defaults=True,
    response_model_exclude_unset=True,
)
async def create_prompt_version_tag(
    request: Request,
    request_body: PromptVersionTagData,
    prompt_version_id: str = Path(description="The ID of the prompt version."),
) -> None:
    """
    Add a tag to a specific prompt version.

    The tag is written with an upsert keyed on (name, prompt_id) using
    OnConflict.DO_UPDATE. If the prompt already has a tag with this name, the existing
    row is updated instead of a duplicate being inserted.

    Args:
        request (Request): The FastAPI request object.
        request_body (PromptVersionTagData): The tag data to be added.
        prompt_version_id (str): The ID of the prompt version to tag.

    Returns:
        None: Returns a 204 No Content response on success.

    Raises:
        HTTPException: 422 if the prompt version ID is invalid; 404 if the prompt version is
            not found.
    """
    try:
        id_ = from_global_id_with_expected_type(
            GlobalID.from_id(prompt_version_id),
            PromptVersionNodeType.__name__,
        )
    except ValueError:
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt version ID")
    # Record the creating user only when authentication is on.
    # NOTE(review): assumes the auth layer always populates a PhoenixUser in that case.
    user_id: Optional[int] = None
    if request.app.state.authentication_enabled:
        assert isinstance(user := request.user, PhoenixUser)
        user_id = int(user.identity)
    async with request.app.state.db() as session:
        # Resolve the owning prompt; this doubles as the existence check for the version.
        prompt_id = await session.scalar(select(models.PromptVersion.prompt_id).filter_by(id=id_))
        if prompt_id is None:
            raise HTTPException(HTTP_404_NOT_FOUND, "Prompt version not found")
        dialect = SupportedSQLDialect(session.bind.dialect.name)
        values = dict(
            name=request_body.name,
            description=request_body.description,
            prompt_id=prompt_id,
            prompt_version_id=id_,
            user_id=user_id,
        )
        await session.execute(
            insert_on_conflict(
                values,
                dialect=dialect,
                table=models.PromptVersionTag,
                unique_by=("name", "prompt_id"),
                on_conflict=OnConflict.DO_UPDATE,
            )
        )
    return None
674
+
675
+
361
676
  class _PromptId(int): ...
362
677
 
363
678
 
@@ -1,19 +1,31 @@
1
1
  from typing import Optional
2
2
 
3
3
  import strawberry
4
+ from strawberry import Info
4
5
  from strawberry.relay import GlobalID, Node, NodeID
5
6
 
6
7
  from phoenix.db import models
8
+ from phoenix.server.api.context import Context
7
9
  from phoenix.server.api.types.Identifier import Identifier
10
+ from phoenix.server.api.types.User import User, to_gql_user
8
11
 
9
12
 
10
13
@strawberry.type
class PromptVersionTag(Node):
    id_attr: NodeID[int]
    # ID of the user who created the tag (None if not recorded). strawberry.Private keeps it out
    # of the GraphQL schema; clients reach the user through the `user` field instead.
    user_id: strawberry.Private[Optional[int]]
    prompt_version_id: GlobalID
    name: Identifier
    description: Optional[str] = None

    @strawberry.field
    async def user(self, info: Info[Context, None]) -> Optional[User]:
        # Look up the creating user; yields None when no user ID was stored or the row is gone.
        creator_id = self.user_id
        if creator_id is not None:
            async with info.context.db() as session:
                orm_user = await session.get(models.User, creator_id)
                # Convert while the session is still open so ORM attributes remain loadable.
                if orm_user is not None:
                    return to_gql_user(orm_user)
        return None
28
+
17
29
 
18
30
  def to_gql_prompt_version_tag(prompt_version_tag: models.PromptVersionTag) -> PromptVersionTag:
19
31
  from phoenix.server.api.types.PromptVersion import PromptVersion
@@ -24,4 +36,5 @@ def to_gql_prompt_version_tag(prompt_version_tag: models.PromptVersionTag) -> Pr
24
36
  prompt_version_id=version_gid,
25
37
  name=Identifier(prompt_version_tag.name.root),
26
38
  description=prompt_version_tag.description,
39
+ user_id=prompt_version_tag.user_id,
27
40
  )
phoenix/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "8.21.0"
1
+ __version__ = "8.22.0"