arize-phoenix 8.32.0__py3-none-any.whl → 9.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix has been flagged as potentially problematic; see the registry's advisory for details.

Files changed (81):
  1. {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/METADATA +3 -2
  2. {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/RECORD +78 -58
  3. phoenix/db/constants.py +1 -0
  4. phoenix/db/facilitator.py +55 -0
  5. phoenix/db/insertion/document_annotation.py +31 -13
  6. phoenix/db/insertion/evaluation.py +15 -3
  7. phoenix/db/insertion/helpers.py +2 -1
  8. phoenix/db/insertion/span_annotation.py +26 -9
  9. phoenix/db/insertion/trace_annotation.py +25 -9
  10. phoenix/db/insertion/types.py +7 -0
  11. phoenix/db/migrations/versions/2f9d1a65945f_annotation_config_migration.py +322 -0
  12. phoenix/db/migrations/versions/8a3764fe7f1a_change_jsonb_to_json_for_prompts.py +76 -0
  13. phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py +77 -0
  14. phoenix/db/models.py +151 -10
  15. phoenix/db/types/annotation_configs.py +97 -0
  16. phoenix/db/types/db_models.py +41 -0
  17. phoenix/db/types/trace_retention.py +267 -0
  18. phoenix/experiments/functions.py +5 -1
  19. phoenix/server/api/auth.py +9 -0
  20. phoenix/server/api/context.py +5 -0
  21. phoenix/server/api/dataloaders/__init__.py +4 -0
  22. phoenix/server/api/dataloaders/annotation_summaries.py +203 -24
  23. phoenix/server/api/dataloaders/project_ids_by_trace_retention_policy_id.py +42 -0
  24. phoenix/server/api/dataloaders/trace_retention_policy_id_by_project_id.py +34 -0
  25. phoenix/server/api/helpers/annotations.py +9 -0
  26. phoenix/server/api/helpers/prompts/models.py +34 -67
  27. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +9 -0
  28. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +3 -0
  29. phoenix/server/api/input_types/PatchAnnotationInput.py +3 -0
  30. phoenix/server/api/input_types/SpanAnnotationFilter.py +67 -0
  31. phoenix/server/api/mutations/__init__.py +6 -0
  32. phoenix/server/api/mutations/annotation_config_mutations.py +413 -0
  33. phoenix/server/api/mutations/dataset_mutations.py +62 -39
  34. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +245 -0
  35. phoenix/server/api/mutations/span_annotations_mutations.py +272 -70
  36. phoenix/server/api/mutations/trace_annotations_mutations.py +203 -74
  37. phoenix/server/api/queries.py +86 -0
  38. phoenix/server/api/routers/v1/__init__.py +4 -0
  39. phoenix/server/api/routers/v1/annotation_configs.py +449 -0
  40. phoenix/server/api/routers/v1/annotations.py +161 -0
  41. phoenix/server/api/routers/v1/evaluations.py +6 -0
  42. phoenix/server/api/routers/v1/projects.py +1 -50
  43. phoenix/server/api/routers/v1/spans.py +37 -8
  44. phoenix/server/api/routers/v1/traces.py +22 -13
  45. phoenix/server/api/routers/v1/utils.py +60 -0
  46. phoenix/server/api/types/Annotation.py +7 -0
  47. phoenix/server/api/types/AnnotationConfig.py +124 -0
  48. phoenix/server/api/types/AnnotationSource.py +9 -0
  49. phoenix/server/api/types/AnnotationSummary.py +28 -14
  50. phoenix/server/api/types/AnnotatorKind.py +1 -0
  51. phoenix/server/api/types/CronExpression.py +15 -0
  52. phoenix/server/api/types/Evaluation.py +4 -30
  53. phoenix/server/api/types/Project.py +50 -2
  54. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +110 -0
  55. phoenix/server/api/types/Span.py +78 -0
  56. phoenix/server/api/types/SpanAnnotation.py +24 -0
  57. phoenix/server/api/types/Trace.py +2 -2
  58. phoenix/server/api/types/TraceAnnotation.py +23 -0
  59. phoenix/server/app.py +20 -0
  60. phoenix/server/retention.py +76 -0
  61. phoenix/server/static/.vite/manifest.json +36 -36
  62. phoenix/server/static/assets/components-B2MWTXnm.js +4326 -0
  63. phoenix/server/static/assets/{index-B0CbpsxD.js → index-Bfvpea_-.js} +10 -10
  64. phoenix/server/static/assets/pages-CZ2vKu8H.js +7268 -0
  65. phoenix/server/static/assets/vendor-BRDkBC5J.js +903 -0
  66. phoenix/server/static/assets/{vendor-arizeai-CxXYQNUl.js → vendor-arizeai-BvTqp_W8.js} +3 -3
  67. phoenix/server/static/assets/{vendor-codemirror-B0NIFPOL.js → vendor-codemirror-COt9UfW7.js} +1 -1
  68. phoenix/server/static/assets/{vendor-recharts-CrrDFWK1.js → vendor-recharts-BoHX9Hvs.js} +2 -2
  69. phoenix/server/static/assets/{vendor-shiki-C5bJ-RPf.js → vendor-shiki-Cw1dsDAz.js} +1 -1
  70. phoenix/session/client.py +6 -1
  71. phoenix/trace/dsl/filter.py +25 -5
  72. phoenix/trace/dsl/query.py +93 -13
  73. phoenix/utilities/__init__.py +18 -0
  74. phoenix/version.py +1 -1
  75. phoenix/server/static/assets/components-x-gKFJ8C.js +0 -3414
  76. phoenix/server/static/assets/pages-BU4VdyeH.js +0 -5867
  77. phoenix/server/static/assets/vendor-BfhM_F1u.js +0 -902
  78. {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/WHEEL +0 -0
  79. {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/entry_points.txt +0 -0
  80. {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  81. {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,245 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ import sqlalchemy as sa
6
+ import strawberry
7
+ from strawberry import UNSET, Info
8
+ from strawberry.relay import GlobalID
9
+
10
+ from phoenix.db import models
11
+ from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
12
+ from phoenix.db.types.trace_retention import (
13
+ MaxCountRule,
14
+ MaxDaysOrCountRule,
15
+ MaxDaysRule,
16
+ TraceRetentionCronExpression,
17
+ TraceRetentionRule,
18
+ )
19
+ from phoenix.server.api.auth import IsAdminIfAuthEnabled, IsLocked, IsNotReadOnly
20
+ from phoenix.server.api.context import Context
21
+ from phoenix.server.api.exceptions import BadRequest, NotFound
22
+ from phoenix.server.api.queries import Query
23
+ from phoenix.server.api.types.CronExpression import CronExpression
24
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
25
+ from phoenix.server.api.types.Project import Project
26
+ from phoenix.server.api.types.ProjectTraceRetentionPolicy import (
27
+ ProjectTraceRetentionPolicy,
28
+ )
29
+
30
+
31
@strawberry.input
class ProjectTraceRetentionRuleMaxDaysInput:
    """GraphQL input for a day-bounded retention rule.

    Translated into ``MaxDaysRule`` by ``_gql_to_db_rule``.
    """

    max_days: float


@strawberry.input
class ProjectTraceRetentionRuleMaxCountInput:
    """GraphQL input for a count-bounded retention rule.

    Translated into ``MaxCountRule`` by ``_gql_to_db_rule``.
    """

    max_count: int


@strawberry.input
class ProjectTraceRetentionRuleMaxDaysOrCountInput(
    ProjectTraceRetentionRuleMaxDaysInput,
    ProjectTraceRetentionRuleMaxCountInput,
):
    """GraphQL input combining both ``max_days`` and ``max_count``.

    Both fields are inherited from the two single-bound inputs; the pair is
    translated into ``MaxDaysOrCountRule`` by ``_gql_to_db_rule``.
    """


@strawberry.input(one_of=True)
class ProjectTraceRetentionRuleInput:
    """One-of wrapper selecting exactly one retention rule variant.

    ``one_of=True`` declares this as a GraphQL oneOf input; ``__post_init__``
    additionally enforces the "exactly one" constraint at construction time.

    Raises:
        BadRequest: from ``__post_init__`` when zero or several variants are set.
    """

    max_days: Optional[ProjectTraceRetentionRuleMaxDaysInput] = UNSET
    max_count: Optional[ProjectTraceRetentionRuleMaxCountInput] = UNSET
    max_days_or_count: Optional[ProjectTraceRetentionRuleMaxDaysOrCountInput] = UNSET

    def __post_init__(self) -> None:
        # UNSET/None members fail their isinstance check, so only populated
        # variants are counted.
        variants = (
            (self.max_days, ProjectTraceRetentionRuleMaxDaysInput),
            (self.max_count, ProjectTraceRetentionRuleMaxCountInput),
            (self.max_days_or_count, ProjectTraceRetentionRuleMaxDaysOrCountInput),
        )
        populated = [value for value, expected in variants if isinstance(value, expected)]
        if len(populated) != 1:
            raise BadRequest("Exactly one rule must be provided")


@strawberry.input
class CreateProjectTraceRetentionPolicyInput:
    """Input for creating a trace retention policy.

    ``add_projects`` optionally lists Project global IDs that should be
    attached to the new policy.

    Raises:
        BadRequest: from ``__post_init__`` if ``name`` or ``cron_expression`` is
            empty or whitespace-only.
    """

    name: str
    cron_expression: CronExpression
    rule: ProjectTraceRetentionRuleInput
    add_projects: Optional[list[GlobalID]] = UNSET

    def __post_init__(self) -> None:
        # Checked in order: name first, then cron expression.
        required_text = (
            (self.name, "Name cannot be empty"),
            (self.cron_expression, "Cron expression cannot be empty"),
        )
        for text, error_message in required_text:
            if not text.strip():
                raise BadRequest(error_message)


@strawberry.input
class PatchProjectTraceRetentionPolicyInput:
    """Partial update for an existing trace retention policy.

    Every field except ``id`` may be left ``UNSET`` to keep its current value.

    Raises:
        BadRequest: from ``__post_init__`` if a provided name or cron expression
            is blank, or a project appears in both ``add_projects`` and
            ``remove_projects``.
    """

    id: GlobalID
    name: Optional[str] = UNSET
    cron_expression: Optional[CronExpression] = UNSET
    rule: Optional[ProjectTraceRetentionRuleInput] = UNSET
    add_projects: Optional[list[GlobalID]] = UNSET
    remove_projects: Optional[list[GlobalID]] = UNSET

    def __post_init__(self) -> None:
        new_name = self.name
        if isinstance(new_name, str) and not new_name.strip():
            raise BadRequest("Name cannot be empty")
        new_cron = self.cron_expression
        if isinstance(new_cron, str) and not new_cron.strip():
            raise BadRequest("Cron expression cannot be empty")
        to_add, to_remove = self.add_projects, self.remove_projects
        # Overlap only matters when both lists were actually supplied.
        if not (isinstance(to_add, list) and isinstance(to_remove, list)):
            return
        if not set(to_add).isdisjoint(to_remove):
            raise BadRequest("A project cannot be in both add and remove lists")


@strawberry.input
class DeleteProjectTraceRetentionPolicyInput:
    """Identifies, by relay global ID, the trace retention policy to delete."""

    id: GlobalID


@strawberry.type
class ProjectTraceRetentionPolicyMutationPayload:
    """Result of a retention-policy mutation: the affected policy plus a root query."""

    # A fresh root ``Query`` is exposed so clients can refetch alongside the
    # mutation result (presumed intent — confirm with the frontend).
    query: Query = strawberry.field(default_factory=Query)
    node: ProjectTraceRetentionPolicy


@strawberry.type
class ProjectTraceRetentionPolicyMutationMixin:
    """GraphQL mutations that create, patch and delete project trace retention policies."""

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled, IsLocked])  # type: ignore
    async def create_project_trace_retention_policy(
        self,
        info: Info[Context, None],
        input: CreateProjectTraceRetentionPolicyInput,
    ) -> ProjectTraceRetentionPolicyMutationPayload:
        """Create a retention policy and optionally assign it to projects.

        Projects listed in ``input.add_projects`` get their
        ``trace_retention_policy_id`` pointed at the new policy.

        Raises:
            BadRequest: if an ``add_projects`` entry is not a valid Project global ID.
        """
        policy = models.ProjectTraceRetentionPolicy(
            name=input.name,
            cron_expression=TraceRetentionCronExpression.model_validate(input.cron_expression),
            rule=_gql_to_db_rule(input.rule),
        )
        add_project_ids = _parse_project_rowids(input.add_projects)
        async with info.context.db() as session:
            session.add(policy)
            # Flush so the policy row exists (and ``policy.id`` is populated)
            # before projects are pointed at it.
            await session.flush()
            if add_project_ids:
                await session.execute(
                    sa.update(models.Project)
                    .where(models.Project.id.in_(set(add_project_ids)))
                    .values(trace_retention_policy_id=policy.id)
                )
        return ProjectTraceRetentionPolicyMutationPayload(
            node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled, IsLocked])  # type: ignore
    async def patch_project_trace_retention_policy(
        self,
        info: Info[Context, None],
        input: PatchProjectTraceRetentionPolicyInput,
    ) -> ProjectTraceRetentionPolicyMutationPayload:
        """Update a policy's fields and attach/detach projects.

        Only fields that are provided are changed. Projects in
        ``remove_projects`` are detached only if they currently use this policy.

        Raises:
            NotFound: if the policy does not exist.
            BadRequest: on an attempt to rename the default policy, or when an
                ``add_projects``/``remove_projects`` entry is not a valid Project
                global ID.
        """
        id_ = from_global_id_with_expected_type(input.id, ProjectTraceRetentionPolicy.__name__)
        add_project_ids = _parse_project_rowids(input.add_projects)
        remove_project_ids = _parse_project_rowids(input.remove_projects)
        async with info.context.db() as session:
            policy = await session.get(models.ProjectTraceRetentionPolicy, id_)
            if not policy:
                raise NotFound(f"ProjectTraceRetentionPolicy with ID={input.id} not found")
            if isinstance(input.name, str) and input.name != policy.name:
                if id_ == DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID:
                    raise BadRequest(
                        "Cannot change the name of the default project trace retention policy"
                    )
                policy.name = input.name
            if isinstance(input.cron_expression, str):
                # Same validation path as create_project_trace_retention_policy.
                policy.cron_expression = TraceRetentionCronExpression.model_validate(
                    input.cron_expression
                )
            if isinstance(input.rule, ProjectTraceRetentionRuleInput):
                policy.rule = _gql_to_db_rule(input.rule)
            # ``session.dirty`` is a collection of modified instances, so this must
            # be a membership test; an identity comparison (``is``) is never true.
            if policy in session.dirty:
                await session.flush()
            if add_project_ids:
                await session.execute(
                    sa.update(models.Project)
                    .where(models.Project.id.in_(set(add_project_ids)))
                    .values(trace_retention_policy_id=policy.id)
                )
            if remove_project_ids:
                # Only detach projects that actually reference this policy.
                await session.execute(
                    sa.update(models.Project)
                    .where(models.Project.trace_retention_policy_id == policy.id)
                    .where(models.Project.id.in_(set(remove_project_ids)))
                    .values(trace_retention_policy_id=None)
                )
        return ProjectTraceRetentionPolicyMutationPayload(
            node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
        )

    # NOTE(review): unlike create/patch, IsLocked is not in the permission list —
    # presumably so policies can still be deleted while writes are locked; confirm.
    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled])  # type: ignore
    async def delete_project_trace_retention_policy(
        self,
        info: Info[Context, None],
        input: DeleteProjectTraceRetentionPolicyInput,
    ) -> ProjectTraceRetentionPolicyMutationPayload:
        """Delete a non-default retention policy and return the deleted row.

        Raises:
            BadRequest: if the default policy is targeted.
            NotFound: if the policy does not exist.
        """
        id_ = from_global_id_with_expected_type(input.id, ProjectTraceRetentionPolicy.__name__)
        if id_ == DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID:
            raise BadRequest("Cannot delete the default project trace retention policy.")
        stmt = (
            sa.delete(models.ProjectTraceRetentionPolicy)
            .where(models.ProjectTraceRetentionPolicy.id == id_)
            .returning(models.ProjectTraceRetentionPolicy)
        )
        async with info.context.db() as session:
            policy = await session.scalar(stmt)
            if not policy:
                raise NotFound(f"ProjectTraceRetentionPolicy with ID={input.id} not found")
        return ProjectTraceRetentionPolicyMutationPayload(
            node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
        )


def _parse_project_rowids(global_ids: Optional[list[GlobalID]]) -> list[int]:
    """Convert optional Project global IDs into database row IDs.

    Anything that is not a list (``UNSET`` or ``None``) yields an empty list.

    Raises:
        BadRequest: if any ID is not a valid Project global ID.
    """
    if not isinstance(global_ids, list):
        return []
    rowids: list[int] = []
    for global_id in global_ids:
        try:
            rowids.append(from_global_id_with_expected_type(global_id, Project.__name__))
        except ValueError:
            raise BadRequest(f"Invalid project ID: {global_id}")
    return rowids


def _gql_to_db_rule(
    rule: ProjectTraceRetentionRuleInput,
) -> TraceRetentionRule:
    """Translate the GraphQL one-of rule input into the database rule model.

    Variants are checked in order: ``max_days``, ``max_count``,
    ``max_days_or_count``; the first populated one wins.

    Raises:
        ValueError: if no variant is populated.
    """
    days_only = rule.max_days
    if isinstance(days_only, ProjectTraceRetentionRuleMaxDaysInput):
        return TraceRetentionRule(root=MaxDaysRule(max_days=days_only.max_days))
    count_only = rule.max_count
    if isinstance(count_only, ProjectTraceRetentionRuleMaxCountInput):
        return TraceRetentionRule(root=MaxCountRule(max_count=count_only.max_count))
    combined = rule.max_days_or_count
    if not isinstance(combined, ProjectTraceRetentionRuleMaxDaysOrCountInput):
        raise ValueError("Invalid rule input")
    return TraceRetentionRule(
        root=MaxDaysOrCountRule(
            max_days=combined.max_days,
            max_count=combined.max_count,
        )
    )
@@ -1,19 +1,28 @@
1
- from collections.abc import Sequence
1
+ from datetime import datetime
2
+ from typing import Optional
2
3
 
3
4
  import strawberry
4
- from sqlalchemy import delete, insert, update
5
- from strawberry import UNSET
6
- from strawberry.types import Info
5
+ from sqlalchemy import delete, insert, select
6
+ from starlette.requests import Request
7
+ from strawberry import UNSET, Info
7
8
 
8
9
  from phoenix.db import models
9
10
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly
10
11
  from phoenix.server.api.context import Context
11
- from phoenix.server.api.input_types.CreateSpanAnnotationInput import CreateSpanAnnotationInput
12
+ from phoenix.server.api.exceptions import BadRequest, NotFound, Unauthorized
13
+ from phoenix.server.api.helpers.annotations import get_user_identifier
14
+ from phoenix.server.api.input_types.CreateSpanAnnotationInput import (
15
+ CreateSpanAnnotationInput,
16
+ CreateSpanNoteInput,
17
+ )
12
18
  from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
13
19
  from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
14
20
  from phoenix.server.api.queries import Query
21
+ from phoenix.server.api.types.AnnotationSource import AnnotationSource
22
+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
15
23
  from phoenix.server.api.types.node import from_global_id_with_expected_type
16
24
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
25
+ from phoenix.server.bearer_auth import PhoenixUser
17
26
  from phoenix.server.dml_event import SpanAnnotationDeleteEvent, SpanAnnotationInsertEvent
18
27
 
19
28
 
@@ -29,33 +38,156 @@ class SpanAnnotationMutationMixin:
29
38
  async def create_span_annotations(
30
39
  self, info: Info[Context, None], input: list[CreateSpanAnnotationInput]
31
40
  ) -> SpanAnnotationMutationPayload:
32
- inserted_annotations: Sequence[models.SpanAnnotation] = []
41
+ if not input:
42
+ raise BadRequest("No span annotations provided.")
43
+
44
+ if any(d.name == "note" for d in input):
45
+ raise BadRequest("Span notes are not supported in this endpoint.")
46
+
47
+ assert isinstance(request := info.context.request, Request)
48
+ user_id: Optional[int] = None
49
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
50
+ user_id = int(user.identity)
51
+
52
+ processed_annotations_map: dict[int, models.SpanAnnotation] = {}
53
+
54
+ span_rowids = []
55
+ for idx, annotation_input in enumerate(input):
56
+ try:
57
+ span_rowid = from_global_id_with_expected_type(annotation_input.span_id, "Span")
58
+ except ValueError:
59
+ raise BadRequest(
60
+ f"Invalid span ID for annotation at index {idx}: {annotation_input.span_id}"
61
+ )
62
+ span_rowids.append(span_rowid)
63
+
33
64
  async with info.context.db() as session:
34
- values_list = [
35
- dict(
36
- span_rowid=from_global_id_with_expected_type(annotation.span_id, "Span"),
37
- name=annotation.name,
38
- label=annotation.label,
39
- score=annotation.score,
40
- explanation=annotation.explanation,
41
- annotator_kind=annotation.annotator_kind.value,
42
- metadata_=annotation.metadata,
65
+ for idx, (span_rowid, annotation_input) in enumerate(zip(span_rowids, input)):
66
+ resolved_identifier = ""
67
+ if isinstance(annotation_input.identifier, str):
68
+ resolved_identifier = annotation_input.identifier
69
+ elif annotation_input.source == AnnotationSource.APP and user_id is not None:
70
+ resolved_identifier = get_user_identifier(user_id)
71
+ values = {
72
+ "span_rowid": span_rowid,
73
+ "name": annotation_input.name,
74
+ "label": annotation_input.label,
75
+ "score": annotation_input.score,
76
+ "explanation": annotation_input.explanation,
77
+ "annotator_kind": annotation_input.annotator_kind.value,
78
+ "metadata_": annotation_input.metadata,
79
+ "identifier": resolved_identifier,
80
+ "source": annotation_input.source.value,
81
+ "user_id": user_id,
82
+ }
83
+
84
+ processed_annotation: Optional[models.SpanAnnotation] = None
85
+
86
+ q = select(models.SpanAnnotation).where(
87
+ models.SpanAnnotation.span_rowid == span_rowid,
88
+ models.SpanAnnotation.name == annotation_input.name,
89
+ models.SpanAnnotation.identifier == resolved_identifier,
90
+ )
91
+ existing_annotation = await session.scalar(q)
92
+
93
+ if existing_annotation:
94
+ existing_annotation.name = values["name"]
95
+ existing_annotation.label = values["label"]
96
+ existing_annotation.score = values["score"]
97
+ existing_annotation.explanation = values["explanation"]
98
+ existing_annotation.metadata_ = values["metadata_"]
99
+ existing_annotation.annotator_kind = values["annotator_kind"]
100
+ existing_annotation.source = values["source"]
101
+ existing_annotation.user_id = values["user_id"]
102
+ session.add(existing_annotation)
103
+ processed_annotation = existing_annotation
104
+
105
+ if processed_annotation is None:
106
+ stmt = insert(models.SpanAnnotation).values(**values)
107
+ stmt = stmt.returning(models.SpanAnnotation)
108
+ result = await session.scalars(stmt)
109
+ processed_annotation = result.one()
110
+
111
+ processed_annotations_map[idx] = processed_annotation
112
+
113
+ # Collect the objects that were inserted or updated
114
+ processed_annotation_objects = list(processed_annotations_map.values())
115
+ processed_annotation_ids = [anno.id for anno in processed_annotation_objects]
116
+
117
+ # Commit the transaction to finalize the state in the DB
118
+ await session.flush()
119
+
120
+ # Re-fetch the annotations in a batch to get the final state including DB defaults
121
+ final_annotations_result = await session.scalars(
122
+ select(models.SpanAnnotation).where(
123
+ models.SpanAnnotation.id.in_(processed_annotation_ids)
43
124
  )
44
- for annotation in input
45
- ]
46
- stmt = (
47
- insert(models.SpanAnnotation).values(values_list).returning(models.SpanAnnotation)
48
125
  )
126
+ final_annotations_by_id = {anno.id: anno for anno in final_annotations_result.all()}
127
+
128
+ # Order the final annotations according to the input order
129
+ ordered_final_annotations = [
130
+ final_annotations_by_id[id] for id in processed_annotation_ids
131
+ ]
132
+
133
+ # Put event on queue *after* successful commit
134
+ if ordered_final_annotations:
135
+ info.context.event_queue.put(
136
+ SpanAnnotationInsertEvent(tuple(processed_annotation_ids))
137
+ )
138
+
139
+ # Convert the fully loaded annotations to GQL types
140
+ returned_annotations = [
141
+ to_gql_span_annotation(anno) for anno in ordered_final_annotations
142
+ ]
143
+
144
+ await session.commit()
145
+
146
+ return SpanAnnotationMutationPayload(
147
+ span_annotations=returned_annotations,
148
+ query=Query(),
149
+ )
150
+
151
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
152
+ async def create_span_note(
153
+ self, info: Info[Context, None], annotation_input: CreateSpanNoteInput
154
+ ) -> SpanAnnotationMutationPayload:
155
+ assert isinstance(request := info.context.request, Request)
156
+ user_id: Optional[int] = None
157
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
158
+ user_id = int(user.identity)
159
+
160
+ try:
161
+ span_rowid = from_global_id_with_expected_type(annotation_input.span_id, "Span")
162
+ except ValueError:
163
+ raise BadRequest(f"Invalid span ID: {annotation_input.span_id}")
164
+
165
+ async with info.context.db() as session:
166
+ timestamp = datetime.now().isoformat()
167
+ note_identifier = f"px-span-note:{timestamp}"
168
+ values = {
169
+ "span_rowid": span_rowid,
170
+ "name": "note",
171
+ "label": None,
172
+ "score": None,
173
+ "explanation": annotation_input.note,
174
+ "annotator_kind": AnnotatorKind.HUMAN.value,
175
+ "metadata_": dict(),
176
+ "identifier": note_identifier,
177
+ "source": AnnotationSource.APP.value,
178
+ "user_id": user_id,
179
+ }
180
+
181
+ stmt = insert(models.SpanAnnotation).values(**values)
182
+ stmt = stmt.returning(models.SpanAnnotation)
49
183
  result = await session.scalars(stmt)
50
- inserted_annotations = result.all()
51
- if inserted_annotations:
52
- info.context.event_queue.put(
53
- SpanAnnotationInsertEvent(tuple(anno.id for anno in inserted_annotations))
54
- )
184
+ processed_annotation = result.one()
185
+
186
+ info.context.event_queue.put(SpanAnnotationInsertEvent((processed_annotation.id,)))
187
+ returned_annotation = to_gql_span_annotation(processed_annotation)
188
+ await session.commit()
55
189
  return SpanAnnotationMutationPayload(
56
- span_annotations=[
57
- to_gql_span_annotation(annotation) for annotation in inserted_annotations
58
- ],
190
+ span_annotations=[returned_annotation],
59
191
  query=Query(),
60
192
  )
61
193
 
@@ -63,66 +195,136 @@ class SpanAnnotationMutationMixin:
63
195
  async def patch_span_annotations(
64
196
  self, info: Info[Context, None], input: list[PatchAnnotationInput]
65
197
  ) -> SpanAnnotationMutationPayload:
66
- patched_annotations = []
67
- async with info.context.db() as session:
68
- for annotation in input:
198
+ if not input:
199
+ raise BadRequest("No span annotations provided.")
200
+
201
+ assert isinstance(request := info.context.request, Request)
202
+ user_id: Optional[int] = None
203
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
204
+ user_id = int(user.identity)
205
+
206
+ patch_by_id = {}
207
+ for patch in input:
208
+ try:
69
209
  span_annotation_id = from_global_id_with_expected_type(
70
- annotation.annotation_id, "SpanAnnotation"
210
+ patch.annotation_id, SpanAnnotation.__name__
71
211
  )
72
- patch = {
73
- column.key: patch_value
74
- for column, patch_value, column_is_nullable in (
75
- (models.SpanAnnotation.name, annotation.name, False),
76
- (
77
- models.SpanAnnotation.annotator_kind,
78
- annotation.annotator_kind.value
79
- if annotation.annotator_kind is not None
80
- and annotation.annotator_kind is not UNSET
81
- else None,
82
- False,
83
- ),
84
- (models.SpanAnnotation.label, annotation.label, True),
85
- (models.SpanAnnotation.score, annotation.score, True),
86
- (models.SpanAnnotation.explanation, annotation.explanation, True),
87
- (models.SpanAnnotation.metadata_, annotation.metadata, False),
212
+ except ValueError:
213
+ raise BadRequest(f"Invalid span annotation ID: {patch.annotation_id}")
214
+ if span_annotation_id in patch_by_id:
215
+ raise BadRequest(f"Duplicate patch for span annotation ID: {span_annotation_id}")
216
+ patch_by_id[span_annotation_id] = patch
217
+
218
+ async with info.context.db() as session:
219
+ span_annotations_by_id = {}
220
+ for span_annotation in await session.scalars(
221
+ select(models.SpanAnnotation).where(
222
+ models.SpanAnnotation.id.in_(patch_by_id.keys())
223
+ )
224
+ ):
225
+ if span_annotation.user_id != user_id:
226
+ raise Unauthorized(
227
+ "At least one span annotation is not associated with the current user."
88
228
  )
89
- if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
90
- }
91
- span_annotation = await session.scalar(
92
- update(models.SpanAnnotation)
93
- .where(models.SpanAnnotation.id == span_annotation_id)
94
- .values(**patch)
95
- .returning(models.SpanAnnotation)
229
+ span_annotations_by_id[span_annotation.id] = span_annotation
230
+ missing_span_annotation_ids = set(patch_by_id.keys()) - set(
231
+ span_annotations_by_id.keys()
232
+ )
233
+ if missing_span_annotation_ids:
234
+ raise NotFound(
235
+ f"Could not find span annotations with IDs: {missing_span_annotation_ids}"
96
236
  )
97
- if span_annotation is not None:
98
- patched_annotations.append(to_gql_span_annotation(span_annotation))
99
- info.context.event_queue.put(SpanAnnotationInsertEvent((span_annotation.id,)))
100
- return SpanAnnotationMutationPayload(span_annotations=patched_annotations, query=Query())
237
+ for span_annotation_id, patch in patch_by_id.items():
238
+ span_annotation = span_annotations_by_id[span_annotation_id]
239
+ if patch.name:
240
+ span_annotation.name = patch.name
241
+ if patch.annotator_kind:
242
+ span_annotation.annotator_kind = patch.annotator_kind.value
243
+ if patch.label is not UNSET:
244
+ span_annotation.label = patch.label
245
+ if patch.score is not UNSET:
246
+ span_annotation.score = patch.score
247
+ if patch.explanation is not UNSET:
248
+ span_annotation.explanation = patch.explanation
249
+ if patch.metadata is not UNSET:
250
+ assert isinstance(patch.metadata, dict)
251
+ span_annotation.metadata_ = patch.metadata
252
+ if patch.identifier is not UNSET:
253
+ span_annotation.identifier = patch.identifier or ""
254
+ if patch.source:
255
+ span_annotation.source = patch.source.value
256
+ session.add(span_annotation)
257
+
258
+ patched_annotations = [
259
+ to_gql_span_annotation(span_annotation)
260
+ for span_annotation in span_annotations_by_id.values()
261
+ ]
262
+
263
+ info.context.event_queue.put(
264
+ SpanAnnotationInsertEvent(tuple(span_annotations_by_id.keys()))
265
+ )
266
+ return SpanAnnotationMutationPayload(
267
+ span_annotations=patched_annotations,
268
+ query=Query(),
269
+ )
101
270
 
102
271
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
103
272
  async def delete_span_annotations(
104
273
  self, info: Info[Context, None], input: DeleteAnnotationsInput
105
274
  ) -> SpanAnnotationMutationPayload:
106
- span_annotation_ids = [
107
- from_global_id_with_expected_type(global_id, "SpanAnnotation")
108
- for global_id in input.annotation_ids
109
- ]
275
+ if not input.annotation_ids:
276
+ raise BadRequest("No span annotation IDs provided.")
277
+
278
+ assert isinstance(request := info.context.request, Request)
279
+ user_id: Optional[int] = None
280
+ user_is_admin = False
281
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
282
+ user_id = int(user.identity)
283
+ user_is_admin = user.is_admin
284
+
285
+ span_annotation_ids: dict[int, None] = {} # use a dict to preserve ordering
286
+ for annotation_gid in input.annotation_ids:
287
+ try:
288
+ span_annotation_id = from_global_id_with_expected_type(
289
+ annotation_gid, SpanAnnotation.__name__
290
+ )
291
+ except ValueError:
292
+ raise BadRequest(f"Invalid span annotation ID: {annotation_gid}")
293
+ if span_annotation_id in span_annotation_ids:
294
+ raise BadRequest(f"Duplicate span annotation ID: {span_annotation_id}")
295
+ span_annotation_ids[span_annotation_id] = None
296
+
110
297
  async with info.context.db() as session:
111
298
  stmt = (
112
299
  delete(models.SpanAnnotation)
113
- .where(models.SpanAnnotation.id.in_(span_annotation_ids))
300
+ .where(models.SpanAnnotation.id.in_(span_annotation_ids.keys()))
114
301
  .returning(models.SpanAnnotation)
115
302
  )
116
303
  result = await session.scalars(stmt)
117
- deleted_annotations = result.all()
304
+ deleted_annotations_by_id = {annotation.id: annotation for annotation in result.all()}
118
305
 
119
- deleted_annotations_gql = [
120
- to_gql_span_annotation(annotation) for annotation in deleted_annotations
121
- ]
122
- if deleted_annotations:
123
- info.context.event_queue.put(
124
- SpanAnnotationDeleteEvent(tuple(anno.id for anno in deleted_annotations))
306
+ if not user_is_admin and any(
307
+ annotation.user_id != user_id for annotation in deleted_annotations_by_id.values()
308
+ ):
309
+ await session.rollback()
310
+ raise Unauthorized(
311
+ "At least one span annotation is not associated with the current user."
312
+ )
313
+
314
+ missing_span_annotation_ids = set(span_annotation_ids.keys()) - set(
315
+ deleted_annotations_by_id.keys()
125
316
  )
317
+ if missing_span_annotation_ids:
318
+ raise NotFound(
319
+ f"Could not find span annotations with IDs: {missing_span_annotation_ids}"
320
+ )
321
+
322
+ deleted_annotations_gql = [
323
+ to_gql_span_annotation(deleted_annotations_by_id[id]) for id in span_annotation_ids
324
+ ]
325
+ info.context.event_queue.put(
326
+ SpanAnnotationDeleteEvent(tuple(deleted_annotations_by_id.keys()))
327
+ )
126
328
  return SpanAnnotationMutationPayload(
127
329
  span_annotations=deleted_annotations_gql, query=Query()
128
330
  )