arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,158 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
5
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
6
+ from starlette.requests import Request
7
+ from strawberry import Info
8
+ from strawberry.relay import GlobalID
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
12
+ from phoenix.server.api.context import Context
13
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauthorized
14
+ from phoenix.server.api.helpers.annotations import get_user_identifier
15
+ from phoenix.server.api.input_types.CreateProjectSessionAnnotationInput import (
16
+ CreateProjectSessionAnnotationInput,
17
+ )
18
+ from phoenix.server.api.input_types.UpdateAnnotationInput import UpdateAnnotationInput
19
+ from phoenix.server.api.queries import Query
20
+ from phoenix.server.api.types.AnnotationSource import AnnotationSource
21
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
22
+ from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
23
+ from phoenix.server.bearer_auth import PhoenixUser
24
+ from phoenix.server.dml_event import (
25
+ ProjectSessionAnnotationDeleteEvent,
26
+ ProjectSessionAnnotationInsertEvent,
27
+ )
28
+
29
+
30
@strawberry.type
class ProjectSessionAnnotationMutationPayload:
    """Payload returned by the project-session annotation mutations."""

    # The annotation that the mutation created, updated, or deleted.
    project_session_annotation: ProjectSessionAnnotation
    # Root `Query` object returned alongside the result (presumably so clients
    # can refetch related data in the same round trip).
    query: Query
34
+
35
+
36
def _get_phoenix_user(info: Info[Context, None]) -> Optional[PhoenixUser]:
    """Return the authenticated `PhoenixUser` for the current request, if any.

    Returns ``None`` when the request scope has no ``"user"`` entry or when
    ``info.context.user`` is not a ``PhoenixUser``.
    """
    assert isinstance(request := info.context.request, Request)
    if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
        return user
    return None


@strawberry.type
class ProjectSessionAnnotationMutationMixin:
    """GraphQL mutations that create, update, and delete project-session annotations."""

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def create_project_session_annotations(
        self, info: Info[Context, None], input: CreateProjectSessionAnnotationInput
    ) -> ProjectSessionAnnotationMutationPayload:
        """Create a single annotation on a project session.

        The annotation is attributed to the requesting user when one is
        authenticated. An explicit ``input.identifier`` wins; otherwise
        app-sourced annotations get an identifier derived from the user id.

        Raises:
            BadRequest: ``input.project_session_id`` is not a ``ProjectSession`` ID.
            Conflict: the insert violates a database integrity constraint.
        """
        user = _get_phoenix_user(info)
        user_id: Optional[int] = int(user.identity) if user is not None else None

        try:
            project_session_id = from_global_id_with_expected_type(
                input.project_session_id, "ProjectSession"
            )
        except ValueError as e:
            raise BadRequest(f"Invalid session ID: {input.project_session_id}") from e

        identifier = ""
        if isinstance(input.identifier, str):
            identifier = input.identifier  # Already trimmed in __post_init__
        elif input.source == AnnotationSource.APP and user_id is not None:
            identifier = get_user_identifier(user_id)

        # There is no explicit flush here, so the INSERT (and any integrity
        # error) is expected to surface when the `db()` context exits. That is
        # why the whole `async with` sits inside the `try`.
        try:
            async with info.context.db() as session:
                anno = models.ProjectSessionAnnotation(
                    project_session_id=project_session_id,
                    name=input.name,
                    label=input.label,
                    score=input.score,
                    explanation=input.explanation,
                    annotator_kind=input.annotator_kind.value,
                    metadata_=input.metadata,
                    identifier=identifier,
                    source=input.source.value,
                    user_id=user_id,
                )
                session.add(anno)
        except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
            raise Conflict(f"Error creating annotation: {e}") from e

        info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))

        return ProjectSessionAnnotationMutationPayload(
            project_session_annotation=ProjectSessionAnnotation(id=anno.id, db_record=anno),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def update_project_session_annotations(
        self, info: Info[Context, None], input: UpdateAnnotationInput
    ) -> ProjectSessionAnnotationMutationPayload:
        """Overwrite the mutable fields of an existing session annotation.

        Only the user the annotation is associated with may update it. A
        ``None`` user matches a ``None`` owner.

        Raises:
            BadRequest: ``input.id`` is not a ``ProjectSessionAnnotation`` ID.
            NotFound: no annotation exists with that ID.
            Unauthorized: the annotation belongs to a different user.
            Conflict: the update violates a database integrity constraint.
        """
        user = _get_phoenix_user(info)
        user_id: Optional[int] = int(user.identity) if user is not None else None

        try:
            id_ = from_global_id_with_expected_type(input.id, "ProjectSessionAnnotation")
        except ValueError as e:
            raise BadRequest(f"Invalid session annotation ID: {input.id}") from e

        async with info.context.db() as session:
            if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
                raise NotFound(f"Could not find session annotation with ID: {input.id}")
            if anno.user_id != user_id:
                raise Unauthorized("Session annotation is not associated with the current user.")

            # `anno` was loaded through this session, so attribute assignments
            # are tracked without an explicit `session.add`.
            anno.name = input.name
            anno.label = input.label
            anno.score = input.score
            anno.explanation = input.explanation
            anno.annotator_kind = input.annotator_kind.value
            anno.metadata_ = input.metadata
            anno.source = input.source.value

            # Flush inside the session so integrity errors can be mapped to Conflict.
            try:
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                raise Conflict(f"Error updating annotation: {e}") from e

        # Updates are broadcast with the insert event type, as in the original code.
        info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
        return ProjectSessionAnnotationMutationPayload(
            project_session_annotation=ProjectSessionAnnotation(id=anno.id, db_record=anno),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer])  # type: ignore
    async def delete_project_session_annotation(
        self, info: Info[Context, None], id: GlobalID
    ) -> ProjectSessionAnnotationMutationPayload:
        """Delete a session annotation and return its last known state.

        Admins may delete any annotation. Other users may delete only
        annotations associated with them.

        Raises:
            BadRequest: ``id`` is not a ``ProjectSessionAnnotation`` ID.
            NotFound: no annotation exists with that ID.
            Unauthorized: the caller neither owns the annotation nor is an admin.
        """
        try:
            id_ = from_global_id_with_expected_type(id, "ProjectSessionAnnotation")
        except ValueError as e:
            raise BadRequest(f"Invalid session annotation ID: {id}") from e

        user = _get_phoenix_user(info)
        user_id: Optional[int] = None
        user_is_admin = False
        if user is not None:
            user_id = int(user.identity)
            user_is_admin = user.is_admin

        async with info.context.db() as session:
            if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
                raise NotFound(f"Could not find session annotation with ID: {id}")

            if not user_is_admin and anno.user_id != user_id:
                raise Unauthorized(
                    "Session annotation is not associated with the current user and "
                    "the current user is not an admin."
                )

            await session.delete(anno)

        # The in-memory record is still used to build the response after deletion.
        deleted_gql_annotation = ProjectSessionAnnotation(id=anno.id, db_record=anno)
        info.context.event_queue.put(ProjectSessionAnnotationDeleteEvent((id_,)))
        return ProjectSessionAnnotationMutationPayload(
            project_session_annotation=deleted_gql_annotation, query=Query()
        )
@@ -16,7 +16,7 @@ from phoenix.db.types.trace_retention import (
16
16
  TraceRetentionCronExpression,
17
17
  TraceRetentionRule,
18
18
  )
19
- from phoenix.server.api.auth import IsAdminIfAuthEnabled, IsLocked, IsNotReadOnly
19
+ from phoenix.server.api.auth import IsAdminIfAuthEnabled, IsLocked, IsNotReadOnly, IsNotViewer
20
20
  from phoenix.server.api.context import Context
21
21
  from phoenix.server.api.exceptions import BadRequest, NotFound
22
22
  from phoenix.server.api.queries import Query
@@ -113,7 +113,9 @@ class ProjectTraceRetentionPolicyMutationPayload:
113
113
 
114
114
  @strawberry.type
115
115
  class ProjectTraceRetentionPolicyMutationMixin:
116
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled, IsLocked]) # type: ignore
116
+ @strawberry.mutation(
117
+ permission_classes=[IsNotReadOnly, IsNotViewer, IsAdminIfAuthEnabled, IsLocked]
118
+ ) # type: ignore
117
119
  async def create_project_trace_retention_policy(
118
120
  self,
119
121
  info: Info[Context, None],
@@ -146,7 +148,9 @@ class ProjectTraceRetentionPolicyMutationMixin:
146
148
  node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
147
149
  )
148
150
 
149
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled, IsLocked]) # type: ignore
151
+ @strawberry.mutation(
152
+ permission_classes=[IsNotReadOnly, IsNotViewer, IsAdminIfAuthEnabled, IsLocked]
153
+ ) # type: ignore
150
154
  async def patch_project_trace_retention_policy(
151
155
  self,
152
156
  info: Info[Context, None],
@@ -204,7 +208,7 @@ class ProjectTraceRetentionPolicyMutationMixin:
204
208
  node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
205
209
  )
206
210
 
207
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled]) # type: ignore
211
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsAdminIfAuthEnabled]) # type: ignore
208
212
  async def delete_project_trace_retention_policy(
209
213
  self,
210
214
  info: Info[Context, None],
@@ -10,79 +10,89 @@ from strawberry.relay import GlobalID
10
10
  from strawberry.types import Info
11
11
 
12
12
  from phoenix.db import models
13
- from phoenix.db.types.identifier import Identifier as IdentifierModel
14
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
13
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
15
14
  from phoenix.server.api.context import Context
16
15
  from phoenix.server.api.exceptions import Conflict, NotFound
17
16
  from phoenix.server.api.queries import Query
18
- from phoenix.server.api.types.Identifier import Identifier
19
17
  from phoenix.server.api.types.node import from_global_id_with_expected_type
20
18
  from phoenix.server.api.types.Prompt import Prompt
21
- from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
19
+ from phoenix.server.api.types.PromptLabel import PromptLabel
22
20
 
23
21
 
24
22
  @strawberry.input
25
23
  class CreatePromptLabelInput:
26
- name: Identifier
24
+ name: str
27
25
  description: Optional[str] = None
26
+ color: str
28
27
 
29
28
 
30
29
  @strawberry.input
31
30
  class PatchPromptLabelInput:
32
31
  prompt_label_id: GlobalID
33
- name: Optional[Identifier] = None
32
+ name: Optional[str] = None
34
33
  description: Optional[str] = None
35
34
 
36
35
 
37
36
  @strawberry.input
38
- class DeletePromptLabelInput:
39
- prompt_label_id: GlobalID
37
+ class DeletePromptLabelsInput:
38
+ prompt_label_ids: list[GlobalID]
40
39
 
41
40
 
42
41
  @strawberry.input
43
- class SetPromptLabelInput:
42
+ class SetPromptLabelsInput:
44
43
  prompt_id: GlobalID
45
- prompt_label_id: GlobalID
44
+ prompt_label_ids: list[GlobalID]
46
45
 
47
46
 
48
47
  @strawberry.input
49
- class UnsetPromptLabelInput:
48
+ class UnsetPromptLabelsInput:
50
49
  prompt_id: GlobalID
51
- prompt_label_id: GlobalID
50
+ prompt_label_ids: list[GlobalID]
52
51
 
53
52
 
54
53
  @strawberry.type
55
54
  class PromptLabelMutationPayload:
56
- prompt_label: Optional["PromptLabel"]
55
+ prompt_labels: list["PromptLabel"]
56
+ query: "Query"
57
+
58
+
59
+ @strawberry.type
60
+ class PromptLabelDeleteMutationPayload:
61
+ deleted_prompt_label_ids: list["GlobalID"]
62
+ query: "Query"
63
+
64
+
65
+ @strawberry.type
66
+ class PromptLabelAssociationMutationPayload:
57
67
  query: "Query"
58
68
 
59
69
 
60
70
  @strawberry.type
61
71
  class PromptLabelMutationMixin:
62
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
72
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
63
73
  async def create_prompt_label(
64
74
  self, info: Info[Context, None], input: CreatePromptLabelInput
65
75
  ) -> PromptLabelMutationPayload:
66
76
  async with info.context.db() as session:
67
- name = IdentifierModel.model_validate(str(input.name))
68
- label_orm = models.PromptLabel(name=name, description=input.description)
77
+ label_orm = models.PromptLabel(
78
+ name=input.name, description=input.description, color=input.color
79
+ )
69
80
  session.add(label_orm)
70
81
 
71
82
  try:
72
83
  await session.commit()
73
84
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
74
- raise Conflict(f"A prompt label named '{name}' already exists.")
85
+ raise Conflict(f"A prompt label named '{input.name}' already exists.")
75
86
 
76
87
  return PromptLabelMutationPayload(
77
- prompt_label=to_gql_prompt_label(label_orm),
88
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
78
89
  query=Query(),
79
90
  )
80
91
 
81
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
92
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
82
93
  async def patch_prompt_label(
83
94
  self, info: Info[Context, None], input: PatchPromptLabelInput
84
95
  ) -> PromptLabelMutationPayload:
85
- validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
86
96
  async with info.context.db() as session:
87
97
  label_id = from_global_id_with_expected_type(
88
98
  input.prompt_label_id, PromptLabel.__name__
@@ -92,8 +102,8 @@ class PromptLabelMutationMixin:
92
102
  if not label_orm:
93
103
  raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
94
104
 
95
- if validated_name is not None:
96
- label_orm.name = validated_name.root
105
+ if input.name is not None:
106
+ label_orm.name = input.name
97
107
  if input.description is not None:
98
108
  label_orm.description = input.description
99
109
 
@@ -103,46 +113,48 @@ class PromptLabelMutationMixin:
103
113
  raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
104
114
 
105
115
  return PromptLabelMutationPayload(
106
- prompt_label=to_gql_prompt_label(label_orm),
116
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
107
117
  query=Query(),
108
118
  )
109
119
 
110
- @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
111
- async def delete_prompt_label(
112
- self, info: Info[Context, None], input: DeletePromptLabelInput
113
- ) -> PromptLabelMutationPayload:
120
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
121
+ async def delete_prompt_labels(
122
+ self, info: Info[Context, None], input: DeletePromptLabelsInput
123
+ ) -> PromptLabelDeleteMutationPayload:
114
124
  """
115
125
  Deletes a PromptLabel (and any crosswalk references).
116
126
  """
117
127
  async with info.context.db() as session:
118
- label_id = from_global_id_with_expected_type(
119
- input.prompt_label_id, PromptLabel.__name__
120
- )
121
- stmt = delete(models.PromptLabel).where(models.PromptLabel.id == label_id)
122
- result = await session.execute(stmt)
123
-
124
- if result.rowcount == 0:
125
- raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
128
+ label_ids = [
129
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
130
+ for prompt_label_id in input.prompt_label_ids
131
+ ]
132
+ stmt = delete(models.PromptLabel).where(models.PromptLabel.id.in_(label_ids))
133
+ await session.execute(stmt)
126
134
 
127
135
  await session.commit()
128
136
 
129
- return PromptLabelMutationPayload(
130
- prompt_label=None,
137
+ return PromptLabelDeleteMutationPayload(
138
+ deleted_prompt_label_ids=input.prompt_label_ids,
131
139
  query=Query(),
132
140
  )
133
141
 
134
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
135
- async def set_prompt_label(
136
- self, info: Info[Context, None], input: SetPromptLabelInput
137
- ) -> PromptLabelMutationPayload:
142
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
143
+ async def set_prompt_labels(
144
+ self, info: Info[Context, None], input: SetPromptLabelsInput
145
+ ) -> PromptLabelAssociationMutationPayload:
138
146
  async with info.context.db() as session:
139
147
  prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
140
- label_id = from_global_id_with_expected_type(
141
- input.prompt_label_id, PromptLabel.__name__
142
- )
148
+ label_ids = [
149
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
150
+ for prompt_label_id in input.prompt_label_ids
151
+ ]
143
152
 
144
- crosswalk = models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
145
- session.add(crosswalk)
153
+ crosswalk_items = [
154
+ models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
155
+ for label_id in label_ids
156
+ ]
157
+ session.add_all(crosswalk_items)
146
158
 
147
159
  try:
148
160
  await session.commit()
@@ -152,41 +164,38 @@ class PromptLabelMutationMixin:
152
164
  # - Foreign key violation => prompt_id or label_id doesn't exist
153
165
  raise Conflict("Failed to associate PromptLabel with Prompt.") from e
154
166
 
155
- label_orm = await session.get(models.PromptLabel, label_id)
156
- if not label_orm:
157
- raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
158
-
159
- return PromptLabelMutationPayload(
160
- prompt_label=to_gql_prompt_label(label_orm),
167
+ return PromptLabelAssociationMutationPayload(
161
168
  query=Query(),
162
169
  )
163
170
 
164
- @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
165
- async def unset_prompt_label(
166
- self, info: Info[Context, None], input: UnsetPromptLabelInput
167
- ) -> PromptLabelMutationPayload:
171
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
172
+ async def unset_prompt_labels(
173
+ self, info: Info[Context, None], input: UnsetPromptLabelsInput
174
+ ) -> PromptLabelAssociationMutationPayload:
168
175
  """
169
176
  Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
170
177
  """
171
178
  async with info.context.db() as session:
172
179
  prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
173
- label_id = from_global_id_with_expected_type(
174
- input.prompt_label_id, PromptLabel.__name__
175
- )
180
+ label_ids = [
181
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
182
+ for prompt_label_id in input.prompt_label_ids
183
+ ]
176
184
 
177
185
  stmt = delete(models.PromptPromptLabel).where(
178
186
  (models.PromptPromptLabel.prompt_id == prompt_id)
179
- & (models.PromptPromptLabel.prompt_label_id == label_id)
187
+ & (models.PromptPromptLabel.prompt_label_id.in_(label_ids))
180
188
  )
181
189
  result = await session.execute(stmt)
182
190
 
183
- if result.rowcount == 0:
184
- raise NotFound(f"No association between prompt={prompt_id} and label={label_id}.")
191
+ if result.rowcount != len(label_ids): # type: ignore[attr-defined]
192
+ label_ids_str = ", ".join(str(i) for i in label_ids)
193
+ raise NotFound(
194
+ f"No association between prompt={prompt_id} and labels={label_ids_str}."
195
+ )
185
196
 
186
197
  await session.commit()
187
198
 
188
- label_orm = await session.get(models.PromptLabel, label_id)
189
- return PromptLabelMutationPayload(
190
- prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
199
+ return PromptLabelAssociationMutationPayload(
191
200
  query=Query(),
192
201
  )