arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -4,18 +4,11 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Query
4
4
  from pydantic import Field
5
5
  from sqlalchemy import select
6
6
  from starlette.requests import Request
7
- from starlette.status import (
8
- HTTP_204_NO_CONTENT,
9
- HTTP_403_FORBIDDEN,
10
- HTTP_404_NOT_FOUND,
11
- HTTP_422_UNPROCESSABLE_ENTITY,
12
- )
13
7
  from strawberry.relay import GlobalID
14
8
 
15
9
  from phoenix.config import DEFAULT_PROJECT_NAME
16
10
  from phoenix.db import models
17
11
  from phoenix.db.helpers import exclude_experiment_projects
18
- from phoenix.db.models import UserRoleName
19
12
  from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
20
13
  from phoenix.server.api.routers.v1.utils import (
21
14
  PaginatedResponseBody,
@@ -24,7 +17,7 @@ from phoenix.server.api.routers.v1.utils import (
24
17
  add_errors_to_responses,
25
18
  )
26
19
  from phoenix.server.api.types.Project import Project as ProjectNodeType
27
- from phoenix.server.authorization import is_not_locked
20
+ from phoenix.server.authorization import is_not_locked, require_admin
28
21
 
29
22
  router = APIRouter(tags=["projects"])
30
23
 
@@ -70,7 +63,7 @@ class UpdateProjectResponseBody(ResponseBody[Project]):
70
63
  response_description="A list of projects with pagination information", # noqa: E501
71
64
  responses=add_errors_to_responses(
72
65
  [
73
- HTTP_422_UNPROCESSABLE_ENTITY,
66
+ 422,
74
67
  ]
75
68
  ),
76
69
  )
@@ -115,7 +108,7 @@ async def get_projects(
115
108
  except ValueError:
116
109
  raise HTTPException(
117
110
  detail=f"Invalid cursor format: {cursor}",
118
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
111
+ status_code=422,
119
112
  )
120
113
 
121
114
  stmt = stmt.limit(limit + 1)
@@ -142,8 +135,8 @@ async def get_projects(
142
135
  response_description="The requested project", # noqa: E501
143
136
  responses=add_errors_to_responses(
144
137
  [
145
- HTTP_404_NOT_FOUND,
146
- HTTP_422_UNPROCESSABLE_ENTITY,
138
+ 404,
139
+ 422,
147
140
  ]
148
141
  ),
149
142
  )
@@ -182,7 +175,7 @@ async def get_project(
182
175
  response_description="The newly created project", # noqa: E501
183
176
  responses=add_errors_to_responses(
184
177
  [
185
- HTTP_422_UNPROCESSABLE_ENTITY,
178
+ 422,
186
179
  ]
187
180
  ),
188
181
  )
@@ -216,16 +209,16 @@ async def create_project(
216
209
 
217
210
  @router.put(
218
211
  "/projects/{project_identifier}",
219
- dependencies=[Depends(is_not_locked)],
212
+ dependencies=[Depends(require_admin), Depends(is_not_locked)],
220
213
  operation_id="updateProject",
221
214
  summary="Update a project by ID or name", # noqa: E501
222
215
  description="Update an existing project with new configuration. Project names cannot be changed. The project identifier is either project ID or project name. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
223
216
  response_description="The updated project", # noqa: E501
224
217
  responses=add_errors_to_responses(
225
218
  [
226
- HTTP_403_FORBIDDEN,
227
- HTTP_404_NOT_FOUND,
228
- HTTP_422_UNPROCESSABLE_ENTITY,
219
+ 403,
220
+ 404,
221
+ 422,
229
222
  ]
230
223
  ),
231
224
  )
@@ -251,20 +244,6 @@ async def update_project(
251
244
  Raises:
252
245
  HTTPException: If the project identifier format is invalid or the project is not found.
253
246
  """ # noqa: E501
254
- if request.app.state.authentication_enabled:
255
- async with request.app.state.db() as session:
256
- # Check if the user is an admin
257
- stmt = (
258
- select(models.UserRole.name)
259
- .join(models.User)
260
- .where(models.User.id == int(request.user.identity))
261
- )
262
- role_name: UserRoleName = await session.scalar(stmt)
263
- if role_name != "ADMIN" and role_name != "SYSTEM":
264
- raise HTTPException(
265
- status_code=HTTP_403_FORBIDDEN,
266
- detail="Only admins can update projects",
267
- )
268
247
  async with request.app.state.db() as session:
269
248
  project = await _get_project_by_identifier(session, project_identifier)
270
249
 
@@ -278,16 +257,17 @@ async def update_project(
278
257
 
279
258
  @router.delete(
280
259
  "/projects/{project_identifier}",
260
+ dependencies=[Depends(require_admin)],
281
261
  operation_id="deleteProject",
282
262
  summary="Delete a project by ID or name", # noqa: E501
283
263
  description="Delete an existing project and all its associated data. The project identifier is either project ID or project name. The default project cannot be deleted. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
284
264
  response_description="No content returned on successful deletion", # noqa: E501
285
- status_code=HTTP_204_NO_CONTENT,
265
+ status_code=204,
286
266
  responses=add_errors_to_responses(
287
267
  [
288
- HTTP_403_FORBIDDEN,
289
- HTTP_404_NOT_FOUND,
290
- HTTP_422_UNPROCESSABLE_ENTITY,
268
+ 403,
269
+ 404,
270
+ 422,
291
271
  ]
292
272
  ),
293
273
  )
@@ -311,27 +291,13 @@ async def delete_project(
311
291
  Raises:
312
292
  HTTPException: If the project identifier format is invalid, the project is not found, or it's the default project.
313
293
  """ # noqa: E501
314
- if request.app.state.authentication_enabled:
315
- async with request.app.state.db() as session:
316
- # Check if the user is an admin
317
- stmt = (
318
- select(models.UserRole.name)
319
- .join(models.User)
320
- .where(models.User.id == int(request.user.identity))
321
- )
322
- role_name: UserRoleName = await session.scalar(stmt)
323
- if role_name != "ADMIN" and role_name != "SYSTEM":
324
- raise HTTPException(
325
- status_code=HTTP_403_FORBIDDEN,
326
- detail="Only admins can delete projects",
327
- )
328
294
  async with request.app.state.db() as session:
329
295
  project = await _get_project_by_identifier(session, project_identifier)
330
296
 
331
297
  # The default project must not be deleted - it's forbidden
332
298
  if project.name == DEFAULT_PROJECT_NAME:
333
299
  raise HTTPException(
334
- status_code=HTTP_403_FORBIDDEN,
300
+ status_code=403,
335
301
  detail="The default project cannot be deleted",
336
302
  )
337
303
 
@@ -4,9 +4,9 @@ from typing import Any, Optional, Union
4
4
  from fastapi import APIRouter, Depends, HTTPException, Path, Query
5
5
  from pydantic import ValidationError, model_validator
6
6
  from sqlalchemy import select
7
+ from sqlalchemy.orm import joinedload
7
8
  from sqlalchemy.sql import Select
8
9
  from starlette.requests import Request
9
- from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
10
10
  from strawberry.relay import GlobalID
11
11
  from typing_extensions import Self, TypeAlias, assert_never
12
12
 
@@ -43,6 +43,7 @@ class PromptData(V1RoutesBaseModel):
43
43
  name: Identifier
44
44
  description: Optional[str] = None
45
45
  source_prompt_id: Optional[str] = None
46
+ metadata: Optional[dict[str, Any]] = None
46
47
 
47
48
 
48
49
  class Prompt(PromptData):
@@ -110,7 +111,7 @@ router = APIRouter(tags=["prompts"])
110
111
  response_description="A list of prompts with pagination information",
111
112
  responses=add_errors_to_responses(
112
113
  [
113
- HTTP_422_UNPROCESSABLE_ENTITY,
114
+ 422,
114
115
  ]
115
116
  ),
116
117
  )
@@ -154,7 +155,7 @@ async def get_prompts(
154
155
  except ValueError:
155
156
  raise HTTPException(
156
157
  detail=f"Invalid cursor format: {cursor}",
157
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
158
+ status_code=422,
158
159
  )
159
160
 
160
161
  query = query.limit(limit + 1)
@@ -181,7 +182,7 @@ async def get_prompts(
181
182
  description="Retrieve all versions of a specific prompt with pagination support. Each prompt "
182
183
  "can have multiple versions with different configurations.",
183
184
  response_description="A list of prompt versions with pagination information",
184
- responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY, HTTP_404_NOT_FOUND]),
185
+ responses=add_errors_to_responses([422, 404]),
185
186
  response_model_by_alias=True,
186
187
  response_model_exclude_defaults=True,
187
188
  response_model_exclude_unset=True,
@@ -214,7 +215,7 @@ async def list_prompt_versions(
214
215
  HTTPException: If the cursor format is invalid, the prompt identifier is invalid,
215
216
  or the prompt is not found.
216
217
  """
217
- query = select(models.PromptVersion)
218
+ query = select(models.PromptVersion).options(joinedload(models.PromptVersion.prompt))
218
219
  query = _filter_by_prompt_identifier(query.join(models.Prompt), prompt_identifier)
219
220
  query = query.order_by(models.PromptVersion.id.desc())
220
221
 
@@ -226,7 +227,7 @@ async def list_prompt_versions(
226
227
  except ValueError:
227
228
  raise HTTPException(
228
229
  detail=f"Invalid cursor format: {cursor}",
229
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
230
+ status_code=422,
230
231
  )
231
232
 
232
233
  query = query.limit(limit + 1)
@@ -255,8 +256,8 @@ async def list_prompt_versions(
255
256
  response_description="The requested prompt version",
256
257
  responses=add_errors_to_responses(
257
258
  [
258
- HTTP_404_NOT_FOUND,
259
- HTTP_422_UNPROCESSABLE_ENTITY,
259
+ 404,
260
+ 422,
260
261
  ]
261
262
  ),
262
263
  response_model_by_alias=True,
@@ -286,11 +287,16 @@ async def get_prompt_version_by_prompt_version_id(
286
287
  PromptVersionNodeType.__name__,
287
288
  )
288
289
  except ValueError:
289
- raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt version ID")
290
+ raise HTTPException(422, "Invalid prompt version ID")
290
291
  async with request.app.state.db() as session:
291
- prompt_version = await session.get(models.PromptVersion, id_)
292
+ stmt = (
293
+ select(models.PromptVersion)
294
+ .options(joinedload(models.PromptVersion.prompt))
295
+ .where(models.PromptVersion.id == id_)
296
+ )
297
+ prompt_version = await session.scalar(stmt)
292
298
  if prompt_version is None:
293
- raise HTTPException(HTTP_404_NOT_FOUND)
299
+ raise HTTPException(404)
294
300
  data = _prompt_version_from_orm_version(prompt_version)
295
301
  return GetPromptResponseBody(data=data)
296
302
 
@@ -304,8 +310,8 @@ async def get_prompt_version_by_prompt_version_id(
304
310
  response_description="The prompt version with the specified tag",
305
311
  responses=add_errors_to_responses(
306
312
  [
307
- HTTP_404_NOT_FOUND,
308
- HTTP_422_UNPROCESSABLE_ENTITY,
313
+ 404,
314
+ 422,
309
315
  ]
310
316
  ),
311
317
  response_model_by_alias=True,
@@ -334,9 +340,10 @@ async def get_prompt_version_by_tag_name(
334
340
  try:
335
341
  name = Identifier.model_validate(tag_name)
336
342
  except ValidationError:
337
- raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid tag name")
343
+ raise HTTPException(422, "Invalid tag name")
338
344
  stmt = (
339
345
  select(models.PromptVersion)
346
+ .options(joinedload(models.PromptVersion.prompt))
340
347
  .join_from(models.PromptVersion, models.PromptVersionTag)
341
348
  .where(models.PromptVersionTag.name == name)
342
349
  )
@@ -344,7 +351,7 @@ async def get_prompt_version_by_tag_name(
344
351
  async with request.app.state.db() as session:
345
352
  prompt_version: models.PromptVersion = await session.scalar(stmt)
346
353
  if prompt_version is None:
347
- raise HTTPException(HTTP_404_NOT_FOUND)
354
+ raise HTTPException(404)
348
355
  data = _prompt_version_from_orm_version(prompt_version)
349
356
  return GetPromptResponseBody(data=data)
350
357
 
@@ -357,8 +364,8 @@ async def get_prompt_version_by_tag_name(
357
364
  response_description="The latest version of the specified prompt",
358
365
  responses=add_errors_to_responses(
359
366
  [
360
- HTTP_404_NOT_FOUND,
361
- HTTP_422_UNPROCESSABLE_ENTITY,
367
+ 404,
368
+ 422,
362
369
  ]
363
370
  ),
364
371
  response_model_by_alias=True,
@@ -382,12 +389,17 @@ async def get_prompt_version_by_latest(
382
389
  Raises:
383
390
  HTTPException: If the prompt identifier is invalid or no prompt version is found.
384
391
  """
385
- stmt = select(models.PromptVersion).order_by(models.PromptVersion.id.desc()).limit(1)
392
+ stmt = (
393
+ select(models.PromptVersion)
394
+ .options(joinedload(models.PromptVersion.prompt))
395
+ .order_by(models.PromptVersion.id.desc())
396
+ .limit(1)
397
+ )
386
398
  stmt = _filter_by_prompt_identifier(stmt.join(models.Prompt), prompt_identifier)
387
399
  async with request.app.state.db() as session:
388
400
  prompt_version: models.PromptVersion = await session.scalar(stmt)
389
401
  if prompt_version is None:
390
- raise HTTPException(HTTP_404_NOT_FOUND)
402
+ raise HTTPException(404)
391
403
  data = _prompt_version_from_orm_version(prompt_version)
392
404
  return GetPromptResponseBody(data=data)
393
405
 
@@ -401,7 +413,7 @@ async def get_prompt_version_by_latest(
401
413
  response_description="The newly created prompt version",
402
414
  responses=add_errors_to_responses(
403
415
  [
404
- HTTP_422_UNPROCESSABLE_ENTITY,
416
+ 422,
405
417
  ]
406
418
  ),
407
419
  response_model_by_alias=True,
@@ -431,7 +443,7 @@ async def create_prompt(
431
443
  or request_body.version.template_type != PromptTemplateType.CHAT
432
444
  ):
433
445
  raise HTTPException(
434
- HTTP_422_UNPROCESSABLE_ENTITY,
446
+ 422,
435
447
  "Only CHAT template type is supported for prompts",
436
448
  )
437
449
  prompt = request_body.prompt
@@ -439,7 +451,7 @@ async def create_prompt(
439
451
  name = Identifier.model_validate(prompt.name)
440
452
  except ValidationError as e:
441
453
  raise HTTPException(
442
- HTTP_422_UNPROCESSABLE_ENTITY,
454
+ 422,
443
455
  "Invalid name identifier for prompt: " + e.errors()[0]["msg"],
444
456
  )
445
457
  version = request_body.version
@@ -448,17 +460,15 @@ async def create_prompt(
448
460
  assert isinstance(user := request.user, PhoenixUser)
449
461
  user_id = int(user.identity)
450
462
  async with request.app.state.db() as session:
451
- if not (prompt_id := await session.scalar(select(models.Prompt.id).filter_by(name=name))):
463
+ if not (prompt_orm := await session.scalar(select(models.Prompt).filter_by(name=name))):
452
464
  prompt_orm = models.Prompt(
453
465
  name=name,
454
466
  description=prompt.description,
467
+ metadata_=prompt.metadata or {},
455
468
  )
456
- session.add(prompt_orm)
457
- await session.flush()
458
- prompt_id = prompt_orm.id
459
469
  version_orm = models.PromptVersion(
460
470
  user_id=user_id,
461
- prompt_id=prompt_id,
471
+ prompt=prompt_orm,
462
472
  description=version.description,
463
473
  model_provider=version.model_provider,
464
474
  model_name=version.model_name,
@@ -496,8 +506,8 @@ class GetPromptVersionTagsResponseBody(PaginatedResponseBody[PromptVersionTag]):
496
506
  response_description="A list of tags associated with the prompt version",
497
507
  responses=add_errors_to_responses(
498
508
  [
499
- HTTP_404_NOT_FOUND,
500
- HTTP_422_UNPROCESSABLE_ENTITY,
509
+ 404,
510
+ 422,
501
511
  ]
502
512
  ),
503
513
  response_model_by_alias=True,
@@ -537,7 +547,7 @@ async def list_prompt_version_tags(
537
547
  PromptVersionNodeType.__name__,
538
548
  )
539
549
  except ValueError:
540
- raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt version ID")
550
+ raise HTTPException(422, "Invalid prompt version ID")
541
551
 
542
552
  # Build the query for tags
543
553
  stmt = (
@@ -560,7 +570,7 @@ async def list_prompt_version_tags(
560
570
  except ValueError:
561
571
  raise HTTPException(
562
572
  detail=f"Invalid cursor format: {cursor}",
563
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
573
+ status_code=422,
564
574
  )
565
575
 
566
576
  # Apply limit
@@ -571,7 +581,7 @@ async def list_prompt_version_tags(
571
581
 
572
582
  # Check if prompt version exists
573
583
  if not result:
574
- raise HTTPException(HTTP_404_NOT_FOUND, "Prompt version not found")
584
+ raise HTTPException(404, "Prompt version not found")
575
585
 
576
586
  # Check if there are any tags
577
587
  has_tags = any(id_ is not None for _, id_, _, _ in result)
@@ -610,11 +620,11 @@ async def list_prompt_version_tags(
610
620
  description="Add a new tag to a specific prompt version. Tags help identify and categorize "
611
621
  "different versions of a prompt.",
612
622
  response_description="No content returned on successful tag creation",
613
- status_code=HTTP_204_NO_CONTENT,
623
+ status_code=204,
614
624
  responses=add_errors_to_responses(
615
625
  [
616
- HTTP_404_NOT_FOUND,
617
- HTTP_422_UNPROCESSABLE_ENTITY,
626
+ 404,
627
+ 422,
618
628
  ]
619
629
  ),
620
630
  response_model_by_alias=True,
@@ -647,7 +657,7 @@ async def create_prompt_version_tag(
647
657
  PromptVersionNodeType.__name__,
648
658
  )
649
659
  except ValueError:
650
- raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt version ID")
660
+ raise HTTPException(422, "Invalid prompt version ID")
651
661
  user_id: Optional[int] = None
652
662
  if request.app.state.authentication_enabled:
653
663
  assert isinstance(user := request.user, PhoenixUser)
@@ -655,7 +665,7 @@ async def create_prompt_version_tag(
655
665
  async with request.app.state.db() as session:
656
666
  prompt_id = await session.scalar(select(models.PromptVersion.prompt_id).filter_by(id=id_))
657
667
  if prompt_id is None:
658
- raise HTTPException(HTTP_404_NOT_FOUND)
668
+ raise HTTPException(404)
659
669
  dialect = SupportedSQLDialect(session.bind.dialect.name)
660
670
  values = dict(
661
671
  name=request_body.name,
@@ -686,7 +696,7 @@ def _parse_prompt_identifier(
686
696
  prompt_identifier: str,
687
697
  ) -> _PromptIdentifier:
688
698
  if not prompt_identifier:
689
- raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt identifier")
699
+ raise HTTPException(422, "Invalid prompt identifier")
690
700
  try:
691
701
  prompt_id = from_global_id_with_expected_type(
692
702
  GlobalID.from_id(prompt_identifier),
@@ -696,7 +706,7 @@ def _parse_prompt_identifier(
696
706
  try:
697
707
  return Identifier.model_validate(prompt_identifier)
698
708
  except ValidationError:
699
- raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, "Invalid prompt name")
709
+ raise HTTPException(422, "Invalid prompt name")
700
710
  return _PromptId(prompt_id)
701
711
 
702
712
 
@@ -742,4 +752,5 @@ def _prompt_from_orm_prompt(orm_prompt: models.Prompt) -> Prompt:
742
752
  source_prompt_id=source_prompt_id,
743
753
  name=orm_prompt.name,
744
754
  description=orm_prompt.description,
755
+ metadata=orm_prompt.metadata_,
745
756
  )
@@ -0,0 +1,108 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import Optional
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException, Query
7
+ from pydantic import Field
8
+ from sqlalchemy import select
9
+ from starlette.requests import Request
10
+
11
+ from phoenix.db import models
12
+ from phoenix.db.helpers import SupportedSQLDialect
13
+ from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
14
+ from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
15
+ from phoenix.server.authorization import is_not_locked
16
+ from phoenix.server.bearer_auth import PhoenixUser
17
+
18
+ from .annotations import SessionAnnotationData
19
+ from .utils import RequestBody, ResponseBody, add_errors_to_responses
20
+
21
# Every route registered on this router is grouped under the "sessions" OpenAPI tag.
router = APIRouter(tags=["sessions"])


# NOTE: the models below are documented with comments, not docstrings. Pydantic copies a
# model's docstring into its JSON-schema "description", which would change the published
# OpenAPI document.


# One item of the sync-mode response: the row ID of a session annotation that was written.
class InsertedSessionAnnotation(V1RoutesBaseModel):
    id: str = Field(..., description="The ID of the inserted session annotation")


# Request envelope: `data` holds the session annotations to create.
class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
    ...


# Response envelope: `data` holds the inserted annotation IDs (left empty for async requests).
class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
    ...
34
+
35
+
36
# Documented with comments rather than a docstring: FastAPI publishes an endpoint's
# docstring as the OpenAPI operation description, which would alter the API schema.
@router.post(
    "/session_annotations",
    dependencies=[Depends(is_not_locked)],
    operation_id="annotateSessions",
    summary="Create session annotations",
    responses=add_errors_to_responses([{"status_code": 404, "description": "Session not found"}]),
    response_description="Session annotations inserted successfully",
    include_in_schema=True,
)
async def annotate_sessions(
    request: Request,
    request_body: AnnotateSessionsRequestBody,
    sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
) -> AnnotateSessionsResponseBody:
    # Create annotations on project sessions.
    #
    # - Annotations named "note" are dropped, and a server-side UserWarning is emitted.
    #   The response does not report which items were skipped.
    # - sync=False: precursors are passed to request.state.enqueue_annotations, and an
    #   empty `data` list is returned immediately.
    # - sync=True: every referenced session must exist, otherwise the endpoint returns
    #   404 before anything is written. Rows are written with insert_on_conflict, keyed
    #   on (name, project_session_id, identifier). Their IDs are returned as strings.
    #
    # Raises:
    #     HTTPException(404): in sync mode, when any referenced session ID is unknown.
    if not request_body.data:
        return AnnotateSessionsResponseBody(data=[])

    # Attribute the annotations to the caller only when auth is on and a Phoenix user
    # is attached to the request.
    user_id: Optional[int] = None
    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
        user_id = int(request.user.identity)

    session_annotations = request_body.data
    filtered_session_annotations = [d for d in session_annotations if d.name != "note"]
    if len(filtered_session_annotations) != len(session_annotations):
        warnings.warn(
            (
                "Session annotations with the name 'note' are not supported in this endpoint. "
                "They will be ignored."
            ),
            UserWarning,
        )
    precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
    if not precursors:
        # Everything was filtered out: skip the empty enqueue / DB round-trip.
        return AnnotateSessionsResponseBody(data=[])

    if not sync:
        await request.state.enqueue_annotations(*precursors)
        return AnnotateSessionsResponseBody(data=[])

    session_ids = {p.session_id for p in precursors}
    # The existence check and the inserts share one DB session. This avoids a second
    # connection and narrows the window between validating sessions and writing rows.
    async with request.app.state.db() as session:
        # Maps each requested session_id to the primary key of its ProjectSession row.
        existing_sessions = {
            session_id: rowid
            async for session_id, rowid in await session.stream(
                select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
                    models.ProjectSession.session_id.in_(session_ids)
                )
            )
        }

        missing_session_ids = session_ids.difference(existing_sessions)
        # Fail the whole request if any session is missing (sync mode is all-or-nothing).
        # Sorting keeps the error message deterministic.
        if missing_session_ids:
            raise HTTPException(
                detail=f"Sessions with IDs {', '.join(sorted(missing_session_ids))} do not exist.",
                status_code=404,
            )

        dialect = SupportedSQLDialect(session.bind.dialect.name)
        inserted_ids = []
        for p in precursors:
            values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
            # NOTE(review): assumes insert_on_conflict's default conflict handling still
            # RETURNs an id for an existing row. A DO-NOTHING strategy would yield None
            # here. Confirm against phoenix.db.insertion.helpers.
            session_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.ProjectSessionAnnotation,
                    unique_by=("name", "project_session_id", "identifier"),
                ).returning(models.ProjectSessionAnnotation.id)
            )
            inserted_ids.append(session_annotation_id)

    return AnnotateSessionsResponseBody(
        data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
    )