arize-phoenix 8.32.0__py3-none-any.whl → 9.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/METADATA +3 -2
- {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/RECORD +78 -58
- phoenix/db/constants.py +1 -0
- phoenix/db/facilitator.py +55 -0
- phoenix/db/insertion/document_annotation.py +31 -13
- phoenix/db/insertion/evaluation.py +15 -3
- phoenix/db/insertion/helpers.py +2 -1
- phoenix/db/insertion/span_annotation.py +26 -9
- phoenix/db/insertion/trace_annotation.py +25 -9
- phoenix/db/insertion/types.py +7 -0
- phoenix/db/migrations/versions/2f9d1a65945f_annotation_config_migration.py +322 -0
- phoenix/db/migrations/versions/8a3764fe7f1a_change_jsonb_to_json_for_prompts.py +76 -0
- phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py +77 -0
- phoenix/db/models.py +151 -10
- phoenix/db/types/annotation_configs.py +97 -0
- phoenix/db/types/db_models.py +41 -0
- phoenix/db/types/trace_retention.py +267 -0
- phoenix/experiments/functions.py +5 -1
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/context.py +5 -0
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +203 -24
- phoenix/server/api/dataloaders/project_ids_by_trace_retention_policy_id.py +42 -0
- phoenix/server/api/dataloaders/trace_retention_policy_id_by_project_id.py +34 -0
- phoenix/server/api/helpers/annotations.py +9 -0
- phoenix/server/api/helpers/prompts/models.py +34 -67
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +9 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +3 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +3 -0
- phoenix/server/api/input_types/SpanAnnotationFilter.py +67 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +413 -0
- phoenix/server/api/mutations/dataset_mutations.py +62 -39
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +245 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +272 -70
- phoenix/server/api/mutations/trace_annotations_mutations.py +203 -74
- phoenix/server/api/queries.py +86 -0
- phoenix/server/api/routers/v1/__init__.py +4 -0
- phoenix/server/api/routers/v1/annotation_configs.py +449 -0
- phoenix/server/api/routers/v1/annotations.py +161 -0
- phoenix/server/api/routers/v1/evaluations.py +6 -0
- phoenix/server/api/routers/v1/projects.py +1 -50
- phoenix/server/api/routers/v1/spans.py +37 -8
- phoenix/server/api/routers/v1/traces.py +22 -13
- phoenix/server/api/routers/v1/utils.py +60 -0
- phoenix/server/api/types/Annotation.py +7 -0
- phoenix/server/api/types/AnnotationConfig.py +124 -0
- phoenix/server/api/types/AnnotationSource.py +9 -0
- phoenix/server/api/types/AnnotationSummary.py +28 -14
- phoenix/server/api/types/AnnotatorKind.py +1 -0
- phoenix/server/api/types/CronExpression.py +15 -0
- phoenix/server/api/types/Evaluation.py +4 -30
- phoenix/server/api/types/Project.py +50 -2
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +110 -0
- phoenix/server/api/types/Span.py +78 -0
- phoenix/server/api/types/SpanAnnotation.py +24 -0
- phoenix/server/api/types/Trace.py +2 -2
- phoenix/server/api/types/TraceAnnotation.py +23 -0
- phoenix/server/app.py +20 -0
- phoenix/server/retention.py +76 -0
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/components-B2MWTXnm.js +4326 -0
- phoenix/server/static/assets/{index-B0CbpsxD.js → index-Bfvpea_-.js} +10 -10
- phoenix/server/static/assets/pages-CZ2vKu8H.js +7268 -0
- phoenix/server/static/assets/vendor-BRDkBC5J.js +903 -0
- phoenix/server/static/assets/{vendor-arizeai-CxXYQNUl.js → vendor-arizeai-BvTqp_W8.js} +3 -3
- phoenix/server/static/assets/{vendor-codemirror-B0NIFPOL.js → vendor-codemirror-COt9UfW7.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-CrrDFWK1.js → vendor-recharts-BoHX9Hvs.js} +2 -2
- phoenix/server/static/assets/{vendor-shiki-C5bJ-RPf.js → vendor-shiki-Cw1dsDAz.js} +1 -1
- phoenix/session/client.py +6 -1
- phoenix/trace/dsl/filter.py +25 -5
- phoenix/trace/dsl/query.py +93 -13
- phoenix/utilities/__init__.py +18 -0
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-x-gKFJ8C.js +0 -3414
- phoenix/server/static/assets/pages-BU4VdyeH.js +0 -5867
- phoenix/server/static/assets/vendor-BfhM_F1u.js +0 -902
- {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-8.32.0.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import sqlalchemy as sa
|
|
6
|
+
import strawberry
|
|
7
|
+
from strawberry import UNSET, Info
|
|
8
|
+
from strawberry.relay import GlobalID
|
|
9
|
+
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
|
|
12
|
+
from phoenix.db.types.trace_retention import (
|
|
13
|
+
MaxCountRule,
|
|
14
|
+
MaxDaysOrCountRule,
|
|
15
|
+
MaxDaysRule,
|
|
16
|
+
TraceRetentionCronExpression,
|
|
17
|
+
TraceRetentionRule,
|
|
18
|
+
)
|
|
19
|
+
from phoenix.server.api.auth import IsAdminIfAuthEnabled, IsLocked, IsNotReadOnly
|
|
20
|
+
from phoenix.server.api.context import Context
|
|
21
|
+
from phoenix.server.api.exceptions import BadRequest, NotFound
|
|
22
|
+
from phoenix.server.api.queries import Query
|
|
23
|
+
from phoenix.server.api.types.CronExpression import CronExpression
|
|
24
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
25
|
+
from phoenix.server.api.types.Project import Project
|
|
26
|
+
from phoenix.server.api.types.ProjectTraceRetentionPolicy import (
|
|
27
|
+
ProjectTraceRetentionPolicy,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@strawberry.input
class ProjectTraceRetentionRuleMaxDaysInput:
    """GraphQL input carrying the ``max_days`` value of a retention rule."""

    # Forwarded unchanged as ``max_days`` when the database rule is built.
    # Presumably a trace age limit in days (fractions allowed by the float
    # type) -- confirm against MaxDaysRule.
    max_days: float
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@strawberry.input
class ProjectTraceRetentionRuleMaxCountInput:
    """GraphQL input carrying the ``max_count`` value of a retention rule."""

    # Forwarded unchanged as ``max_count`` when the database rule is built.
    # Presumably the number of traces to keep -- confirm against MaxCountRule.
    max_count: int
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Input that carries both ``max_days`` and ``max_count``: it declares no fields
# of its own and inherits one from each of the two base inputs.
@strawberry.input
class ProjectTraceRetentionRuleMaxDaysOrCountInput(
    ProjectTraceRetentionRuleMaxDaysInput,
    ProjectTraceRetentionRuleMaxCountInput,
): ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@strawberry.input(one_of=True)
class ProjectTraceRetentionRuleInput:
    """One-of GraphQL input selecting a single kind of retention rule.

    Exactly one of the three fields must hold an instance of its input type;
    anything else is rejected with ``BadRequest`` at construction time.
    """

    max_days: Optional[ProjectTraceRetentionRuleMaxDaysInput] = UNSET
    max_count: Optional[ProjectTraceRetentionRuleMaxCountInput] = UNSET
    max_days_or_count: Optional[ProjectTraceRetentionRuleMaxDaysOrCountInput] = UNSET

    def __post_init__(self) -> None:
        """Raise ``BadRequest`` unless exactly one rule field is populated."""
        # Pair each field value with the input type it is expected to hold.
        candidates = (
            (self.max_days, ProjectTraceRetentionRuleMaxDaysInput),
            (self.max_count, ProjectTraceRetentionRuleMaxCountInput),
            (self.max_days_or_count, ProjectTraceRetentionRuleMaxDaysOrCountInput),
        )
        provided = [value for value, expected in candidates if isinstance(value, expected)]
        if len(provided) != 1:
            raise BadRequest("Exactly one rule must be provided")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@strawberry.input
class CreateProjectTraceRetentionPolicyInput:
    """GraphQL input for creating a project trace retention policy.

    ``name`` and ``cron_expression`` must contain non-whitespace text;
    ``add_projects`` optionally lists projects to attach to the new policy.
    """

    name: str
    cron_expression: CronExpression
    rule: ProjectTraceRetentionRuleInput
    add_projects: Optional[list[GlobalID]] = UNSET

    def __post_init__(self) -> None:
        """Raise ``BadRequest`` when the name or cron expression is blank."""
        required_text = (
            (self.name, "Name cannot be empty"),
            (self.cron_expression, "Cron expression cannot be empty"),
        )
        for text, message in required_text:
            if not text.strip():
                raise BadRequest(message)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@strawberry.input
class PatchProjectTraceRetentionPolicyInput:
    """GraphQL input for partially updating a project trace retention policy.

    Every field except ``id`` defaults to ``UNSET``. Validation only applies to
    fields that were actually supplied with a value of the expected type.
    """

    id: GlobalID
    name: Optional[str] = UNSET
    cron_expression: Optional[CronExpression] = UNSET
    rule: Optional[ProjectTraceRetentionRuleInput] = UNSET
    add_projects: Optional[list[GlobalID]] = UNSET
    remove_projects: Optional[list[GlobalID]] = UNSET

    def __post_init__(self) -> None:
        """Reject blank text fields and projects listed for both add and remove."""
        new_name = self.name
        if isinstance(new_name, str) and not new_name.strip():
            raise BadRequest("Name cannot be empty")

        new_cron = self.cron_expression
        if isinstance(new_cron, str) and not new_cron.strip():
            raise BadRequest("Cron expression cannot be empty")

        to_add, to_remove = self.add_projects, self.remove_projects
        if not (isinstance(to_add, list) and isinstance(to_remove, list)):
            # The overlap check only makes sense when both lists were given.
            return
        if not set(to_add).isdisjoint(set(to_remove)):
            raise BadRequest("A project cannot be in both add and remove lists")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@strawberry.input
class DeleteProjectTraceRetentionPolicyInput:
    """GraphQL input identifying the retention policy to delete."""

    # Global ID of the policy; the delete mutation decodes it as a
    # ProjectTraceRetentionPolicy node ID.
    id: GlobalID
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@strawberry.type
class ProjectTraceRetentionPolicyMutationPayload:
    """Result type returned by the project trace retention policy mutations."""

    # A new root Query instance is created for each payload via default_factory.
    query: Query = strawberry.field(default_factory=Query)
    # The policy the mutation created, patched, or deleted.
    node: ProjectTraceRetentionPolicy
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _project_rowids(project_gids: object) -> list[int]:
    """Decode Project global IDs into row IDs; a non-list value yields ``[]``."""
    if not isinstance(project_gids, list):
        return []
    return [
        from_global_id_with_expected_type(project_gid, Project.__name__)
        for project_gid in project_gids
    ]


@strawberry.type
class ProjectTraceRetentionPolicyMutationMixin:
    """GraphQL mutations that create, patch, and delete project trace retention policies."""

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled, IsLocked])  # type: ignore
    async def create_project_trace_retention_policy(
        self,
        info: Info[Context, None],
        input: CreateProjectTraceRetentionPolicyInput,
    ) -> ProjectTraceRetentionPolicyMutationPayload:
        """Create a retention policy and optionally attach projects to it.

        Args:
            info: Resolver info; ``info.context.db()`` supplies the DB session.
            input: Name, cron expression, rule, and optional projects to attach.

        Returns:
            A payload whose ``node`` wraps the newly created policy.
        """
        policy = models.ProjectTraceRetentionPolicy(
            name=input.name,
            cron_expression=TraceRetentionCronExpression.model_validate(input.cron_expression),
            rule=_gql_to_db_rule(input.rule),
        )
        add_project_ids = _project_rowids(input.add_projects)
        async with info.context.db() as session:
            session.add(policy)
            # Flush so that ``policy.id`` is populated before it is referenced below.
            await session.flush()
            if add_project_ids:
                stmt = (
                    sa.update(models.Project)
                    .where(models.Project.id.in_(set(add_project_ids)))
                    .values(trace_retention_policy_id=policy.id)
                )
                await session.execute(stmt)
        return ProjectTraceRetentionPolicyMutationPayload(
            node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled, IsLocked])  # type: ignore
    async def patch_project_trace_retention_policy(
        self,
        info: Info[Context, None],
        input: PatchProjectTraceRetentionPolicyInput,
    ) -> ProjectTraceRetentionPolicyMutationPayload:
        """Update fields of an existing policy and adjust its project assignments.

        Only fields supplied with a value of the expected type are applied.
        Projects in ``add_projects`` are pointed at this policy; projects in
        ``remove_projects`` that currently point at this policy have their
        ``trace_retention_policy_id`` set to ``None``.

        Raises:
            NotFound: No policy exists with the given ID.
            BadRequest: An attempt is made to rename the default policy.
        """
        id_ = from_global_id_with_expected_type(input.id, ProjectTraceRetentionPolicy.__name__)
        add_project_ids = _project_rowids(input.add_projects)
        remove_project_ids = _project_rowids(input.remove_projects)
        async with info.context.db() as session:
            policy = await session.get(models.ProjectTraceRetentionPolicy, id_)
            if not policy:
                raise NotFound(f"ProjectTraceRetentionPolicy with ID={input.id} not found")
            if isinstance(input.name, str) and input.name != policy.name:
                if id_ == DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID:
                    raise BadRequest(
                        "Cannot change the name of the default project trace retention policy"
                    )
                policy.name = input.name
            if isinstance(input.cron_expression, str):
                policy.cron_expression = TraceRetentionCronExpression(root=input.cron_expression)
            if isinstance(input.rule, ProjectTraceRetentionRuleInput):
                policy.rule = _gql_to_db_rule(input.rule)
            # ``session.dirty`` is a collection of modified instances, so this has
            # to be a membership test; the previous identity comparison
            # (``policy is session.dirty``) could never be true.
            if policy in session.dirty:
                await session.flush()
            if add_project_ids:
                stmt = (
                    sa.update(models.Project)
                    .where(models.Project.id.in_(set(add_project_ids)))
                    .values(trace_retention_policy_id=policy.id)
                )
                await session.execute(stmt)
            if remove_project_ids:
                # Only detach projects that are currently assigned to this policy.
                stmt = (
                    sa.update(models.Project)
                    .where(models.Project.trace_retention_policy_id == policy.id)
                    .where(models.Project.id.in_(set(remove_project_ids)))
                    .values(trace_retention_policy_id=None)
                )
                await session.execute(stmt)
        return ProjectTraceRetentionPolicyMutationPayload(
            node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
        )

    # NOTE(review): unlike create/patch, this mutation does not list IsLocked;
    # presumably deliberate so deletion stays available -- confirm.
    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdminIfAuthEnabled])  # type: ignore
    async def delete_project_trace_retention_policy(
        self,
        info: Info[Context, None],
        input: DeleteProjectTraceRetentionPolicyInput,
    ) -> ProjectTraceRetentionPolicyMutationPayload:
        """Delete a retention policy and return the deleted row.

        Raises:
            BadRequest: The ID refers to the default policy, which cannot be deleted.
            NotFound: No policy exists with the given ID.
        """
        id_ = from_global_id_with_expected_type(input.id, ProjectTraceRetentionPolicy.__name__)
        if id_ == DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID:
            raise BadRequest("Cannot delete the default project trace retention policy.")
        stmt = (
            sa.delete(models.ProjectTraceRetentionPolicy)
            .where(models.ProjectTraceRetentionPolicy.id == id_)
            .returning(models.ProjectTraceRetentionPolicy)
        )
        async with info.context.db() as session:
            # RETURNING yields the deleted row, or nothing when the ID matched no policy.
            policy = await session.scalar(stmt)
        if not policy:
            raise NotFound(f"ProjectTraceRetentionPolicy with ID={input.id} not found")
        return ProjectTraceRetentionPolicyMutationPayload(
            node=ProjectTraceRetentionPolicy(id=policy.id, db_policy=policy),
        )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _gql_to_db_rule(
    rule: ProjectTraceRetentionRuleInput,
) -> TraceRetentionRule:
    """Translate the GraphQL rule input into its database rule model.

    The fields are examined in the order ``max_days``, ``max_count``,
    ``max_days_or_count``; the first one holding an instance of its input type
    determines the result.

    Raises:
        ValueError: None of the fields holds an instance of its input type.
    """
    by_days = rule.max_days
    if isinstance(by_days, ProjectTraceRetentionRuleMaxDaysInput):
        return TraceRetentionRule(root=MaxDaysRule(max_days=by_days.max_days))

    by_count = rule.max_count
    if isinstance(by_count, ProjectTraceRetentionRuleMaxCountInput):
        return TraceRetentionRule(root=MaxCountRule(max_count=by_count.max_count))

    by_either = rule.max_days_or_count
    if isinstance(by_either, ProjectTraceRetentionRuleMaxDaysOrCountInput):
        combined = MaxDaysOrCountRule(
            max_days=by_either.max_days,
            max_count=by_either.max_count,
        )
        return TraceRetentionRule(root=combined)

    raise ValueError("Invalid rule input")
|
|
@@ -1,19 +1,28 @@
|
|
|
1
|
-
from
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
2
3
|
|
|
3
4
|
import strawberry
|
|
4
|
-
from sqlalchemy import delete, insert,
|
|
5
|
-
from
|
|
6
|
-
from strawberry
|
|
5
|
+
from sqlalchemy import delete, insert, select
|
|
6
|
+
from starlette.requests import Request
|
|
7
|
+
from strawberry import UNSET, Info
|
|
7
8
|
|
|
8
9
|
from phoenix.db import models
|
|
9
10
|
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
10
11
|
from phoenix.server.api.context import Context
|
|
11
|
-
from phoenix.server.api.
|
|
12
|
+
from phoenix.server.api.exceptions import BadRequest, NotFound, Unauthorized
|
|
13
|
+
from phoenix.server.api.helpers.annotations import get_user_identifier
|
|
14
|
+
from phoenix.server.api.input_types.CreateSpanAnnotationInput import (
|
|
15
|
+
CreateSpanAnnotationInput,
|
|
16
|
+
CreateSpanNoteInput,
|
|
17
|
+
)
|
|
12
18
|
from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
|
|
13
19
|
from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
|
|
14
20
|
from phoenix.server.api.queries import Query
|
|
21
|
+
from phoenix.server.api.types.AnnotationSource import AnnotationSource
|
|
22
|
+
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
15
23
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
16
24
|
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
|
|
25
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
17
26
|
from phoenix.server.dml_event import SpanAnnotationDeleteEvent, SpanAnnotationInsertEvent
|
|
18
27
|
|
|
19
28
|
|
|
@@ -29,33 +38,156 @@ class SpanAnnotationMutationMixin:
|
|
|
29
38
|
async def create_span_annotations(
    self, info: Info[Context, None], input: list[CreateSpanAnnotationInput]
) -> SpanAnnotationMutationPayload:
    """Insert span annotations, or update those that already exist.

    An annotation is treated as existing when a row matches the same span,
    name, and resolved identifier. The identifier is the one supplied on the
    input, else the current user's identifier for APP-sourced annotations when
    a user is known, else the empty string.

    Returns:
        A payload with the resulting annotations in input order.

    Raises:
        BadRequest: ``input`` is empty, contains an annotation named "note",
            or contains a span ID that cannot be decoded.
    """
    if not input:
        raise BadRequest("No span annotations provided.")

    if any(d.name == "note" for d in input):
        raise BadRequest("Span notes are not supported in this endpoint.")

    # Bind outside the assert: an assignment expression inside ``assert`` is
    # skipped under ``python -O``, which would leave ``request`` undefined.
    request = info.context.request
    assert isinstance(request, Request)
    user_id: Optional[int] = None
    if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
        user_id = int(user.identity)

    # Decode every span ID up front so a malformed ID fails before any DB work.
    span_rowids = []
    for idx, annotation_input in enumerate(input):
        try:
            span_rowid = from_global_id_with_expected_type(annotation_input.span_id, "Span")
        except ValueError:
            raise BadRequest(
                f"Invalid span ID for annotation at index {idx}: {annotation_input.span_id}"
            )
        span_rowids.append(span_rowid)

    async with info.context.db() as session:
        processed_annotations: list[models.SpanAnnotation] = []
        for span_rowid, annotation_input in zip(span_rowids, input):
            resolved_identifier = ""
            if isinstance(annotation_input.identifier, str):
                resolved_identifier = annotation_input.identifier
            elif annotation_input.source == AnnotationSource.APP and user_id is not None:
                resolved_identifier = get_user_identifier(user_id)
            values = {
                "span_rowid": span_rowid,
                "name": annotation_input.name,
                "label": annotation_input.label,
                "score": annotation_input.score,
                "explanation": annotation_input.explanation,
                "annotator_kind": annotation_input.annotator_kind.value,
                "metadata_": annotation_input.metadata,
                "identifier": resolved_identifier,
                "source": annotation_input.source.value,
                "user_id": user_id,
            }

            q = select(models.SpanAnnotation).where(
                models.SpanAnnotation.span_rowid == span_rowid,
                models.SpanAnnotation.name == annotation_input.name,
                models.SpanAnnotation.identifier == resolved_identifier,
            )
            existing_annotation = await session.scalar(q)

            if existing_annotation:
                # Update the matching row in place.
                existing_annotation.name = values["name"]
                existing_annotation.label = values["label"]
                existing_annotation.score = values["score"]
                existing_annotation.explanation = values["explanation"]
                existing_annotation.metadata_ = values["metadata_"]
                existing_annotation.annotator_kind = values["annotator_kind"]
                existing_annotation.source = values["source"]
                existing_annotation.user_id = values["user_id"]
                session.add(existing_annotation)
                processed_annotation = existing_annotation
            else:
                stmt = insert(models.SpanAnnotation).values(**values)
                stmt = stmt.returning(models.SpanAnnotation)
                result = await session.scalars(stmt)
                processed_annotation = result.one()

            processed_annotations.append(processed_annotation)

        # IDs of the rows that were inserted or updated, in input order
        processed_annotation_ids = [anno.id for anno in processed_annotations]

        # Flush pending changes to the DB; the commit happens further below
        await session.flush()

        # Re-fetch the annotations in a batch to get the final state including DB defaults
        final_annotations_result = await session.scalars(
            select(models.SpanAnnotation).where(
                models.SpanAnnotation.id.in_(processed_annotation_ids)
            )
        )
        final_annotations_by_id = {anno.id: anno for anno in final_annotations_result.all()}

        # Order the final annotations according to the input order
        ordered_final_annotations = [
            final_annotations_by_id[anno_id] for anno_id in processed_annotation_ids
        ]

        # Convert to GQL types while the session is still open
        returned_annotations = [
            to_gql_span_annotation(anno) for anno in ordered_final_annotations
        ]

        await session.commit()

        # Put event on queue only *after* the commit has succeeded
        if ordered_final_annotations:
            info.context.event_queue.put(
                SpanAnnotationInsertEvent(tuple(processed_annotation_ids))
            )

    return SpanAnnotationMutationPayload(
        span_annotations=returned_annotations,
        query=Query(),
    )
|
|
150
|
+
|
|
151
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
async def create_span_note(
    self, info: Info[Context, None], annotation_input: CreateSpanNoteInput
) -> SpanAnnotationMutationPayload:
    """Attach a free-text note to a span as an annotation named "note".

    The note text is stored in ``explanation``; the annotation is recorded as
    HUMAN / APP with no label or score.

    Returns:
        A payload containing the single inserted annotation.

    Raises:
        BadRequest: The span ID cannot be decoded.
    """
    # Bind outside the assert: an assignment expression inside ``assert`` is
    # skipped under ``python -O``, which would leave ``request`` undefined.
    request = info.context.request
    assert isinstance(request, Request)
    user_id: Optional[int] = None
    if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
        user_id = int(user.identity)

    try:
        span_rowid = from_global_id_with_expected_type(annotation_input.span_id, "Span")
    except ValueError:
        raise BadRequest(f"Invalid span ID: {annotation_input.span_id}")

    async with info.context.db() as session:
        # The identifier embeds the current (naive, local) timestamp so that
        # each note gets its own identifier instead of matching an earlier one.
        timestamp = datetime.now().isoformat()
        note_identifier = f"px-span-note:{timestamp}"
        values = {
            "span_rowid": span_rowid,
            "name": "note",
            "label": None,
            "score": None,
            "explanation": annotation_input.note,
            "annotator_kind": AnnotatorKind.HUMAN.value,
            "metadata_": dict(),
            "identifier": note_identifier,
            "source": AnnotationSource.APP.value,
            "user_id": user_id,
        }

        stmt = insert(models.SpanAnnotation).values(**values)
        stmt = stmt.returning(models.SpanAnnotation)
        result = await session.scalars(stmt)
        processed_annotation = result.one()

        info.context.event_queue.put(SpanAnnotationInsertEvent((processed_annotation.id,)))
        # Convert to the GQL type before committing, while the row is loaded.
        returned_annotation = to_gql_span_annotation(processed_annotation)
        await session.commit()
    return SpanAnnotationMutationPayload(
        span_annotations=[returned_annotation],
        query=Query(),
    )
|
|
61
193
|
|
|
@@ -63,66 +195,136 @@ class SpanAnnotationMutationMixin:
|
|
|
63
195
|
async def patch_span_annotations(
    self, info: Info[Context, None], input: list[PatchAnnotationInput]
) -> SpanAnnotationMutationPayload:
    """Apply partial updates to existing span annotations.

    Every targeted annotation must have a ``user_id`` equal to the current
    user's ID (``None`` when no authenticated user is present).

    Returns:
        A payload with the patched annotations in input order.

    Raises:
        BadRequest: ``input`` is empty, an annotation ID cannot be decoded, or
            the same annotation is patched more than once.
        Unauthorized: An annotation is not associated with the current user.
        NotFound: One or more annotation IDs do not exist.
    """
    if not input:
        raise BadRequest("No span annotations provided.")

    # Bind outside the assert: an assignment expression inside ``assert`` is
    # skipped under ``python -O``, which would leave ``request`` undefined.
    request = info.context.request
    assert isinstance(request, Request)
    user_id: Optional[int] = None
    if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
        user_id = int(user.identity)

    # Keyed by decoded annotation ID; dict order preserves the input order.
    patch_by_id = {}
    for patch in input:
        try:
            span_annotation_id = from_global_id_with_expected_type(
                patch.annotation_id, SpanAnnotation.__name__
            )
        except ValueError:
            raise BadRequest(f"Invalid span annotation ID: {patch.annotation_id}")
        if span_annotation_id in patch_by_id:
            raise BadRequest(f"Duplicate patch for span annotation ID: {span_annotation_id}")
        patch_by_id[span_annotation_id] = patch

    async with info.context.db() as session:
        span_annotations_by_id = {}
        for span_annotation in await session.scalars(
            select(models.SpanAnnotation).where(
                models.SpanAnnotation.id.in_(patch_by_id.keys())
            )
        ):
            if span_annotation.user_id != user_id:
                raise Unauthorized(
                    "At least one span annotation is not associated with the current user."
                )
            span_annotations_by_id[span_annotation.id] = span_annotation
        missing_span_annotation_ids = set(patch_by_id.keys()) - set(
            span_annotations_by_id.keys()
        )
        if missing_span_annotation_ids:
            raise NotFound(
                f"Could not find span annotations with IDs: {missing_span_annotation_ids}"
            )
        for span_annotation_id, patch in patch_by_id.items():
            span_annotation = span_annotations_by_id[span_annotation_id]
            # name / annotator_kind / source are applied only when truthy;
            # the remaining fields are applied whenever they are not UNSET.
            if patch.name:
                span_annotation.name = patch.name
            if patch.annotator_kind:
                span_annotation.annotator_kind = patch.annotator_kind.value
            if patch.label is not UNSET:
                span_annotation.label = patch.label
            if patch.score is not UNSET:
                span_annotation.score = patch.score
            if patch.explanation is not UNSET:
                span_annotation.explanation = patch.explanation
            if patch.metadata is not UNSET:
                assert isinstance(patch.metadata, dict)
                span_annotation.metadata_ = patch.metadata
            if patch.identifier is not UNSET:
                span_annotation.identifier = patch.identifier or ""
            if patch.source:
                span_annotation.source = patch.source.value
            session.add(span_annotation)

        # Return results in input order; the SELECT above has no ORDER BY, so
        # the fetch order of ``span_annotations_by_id`` is unspecified.
        patched_annotations = [
            to_gql_span_annotation(span_annotations_by_id[span_annotation_id])
            for span_annotation_id in patch_by_id
        ]

        info.context.event_queue.put(
            SpanAnnotationInsertEvent(tuple(span_annotations_by_id.keys()))
        )
        return SpanAnnotationMutationPayload(
            span_annotations=patched_annotations,
            query=Query(),
        )
|
|
101
270
|
|
|
102
271
|
@strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
async def delete_span_annotations(
    self, info: Info[Context, None], input: DeleteAnnotationsInput
) -> SpanAnnotationMutationPayload:
    """Delete span annotations by ID and return the deleted rows.

    Non-admin callers may only delete annotations whose ``user_id`` equals
    their own ID (``None`` when no authenticated user is present).

    Returns:
        A payload with the deleted annotations in input order.

    Raises:
        BadRequest: No IDs were given, an ID cannot be decoded, or an ID is
            listed more than once.
        Unauthorized: A non-admin tried to delete another user's annotation;
            the deletion is rolled back first.
        NotFound: One or more annotation IDs do not exist.
    """
    if not input.annotation_ids:
        raise BadRequest("No span annotation IDs provided.")

    # Bind outside the assert: an assignment expression inside ``assert`` is
    # skipped under ``python -O``, which would leave ``request`` undefined.
    request = info.context.request
    assert isinstance(request, Request)
    user_id: Optional[int] = None
    user_is_admin = False
    if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
        user_id = int(user.identity)
        user_is_admin = user.is_admin

    span_annotation_ids: dict[int, None] = {}  # use a dict to preserve ordering
    for annotation_gid in input.annotation_ids:
        try:
            span_annotation_id = from_global_id_with_expected_type(
                annotation_gid, SpanAnnotation.__name__
            )
        except ValueError:
            raise BadRequest(f"Invalid span annotation ID: {annotation_gid}")
        if span_annotation_id in span_annotation_ids:
            raise BadRequest(f"Duplicate span annotation ID: {span_annotation_id}")
        span_annotation_ids[span_annotation_id] = None

    async with info.context.db() as session:
        stmt = (
            delete(models.SpanAnnotation)
            .where(models.SpanAnnotation.id.in_(span_annotation_ids.keys()))
            .returning(models.SpanAnnotation)
        )
        result = await session.scalars(stmt)
        deleted_annotations_by_id = {annotation.id: annotation for annotation in result.all()}

        # Ownership is checked on the rows returned by the DELETE, so an
        # unauthorized deletion has to be rolled back explicitly.
        if not user_is_admin and any(
            annotation.user_id != user_id for annotation in deleted_annotations_by_id.values()
        ):
            await session.rollback()
            raise Unauthorized(
                "At least one span annotation is not associated with the current user."
            )

        missing_span_annotation_ids = set(span_annotation_ids.keys()) - set(
            deleted_annotations_by_id.keys()
        )
        if missing_span_annotation_ids:
            raise NotFound(
                f"Could not find span annotations with IDs: {missing_span_annotation_ids}"
            )

        deleted_annotations_gql = [
            to_gql_span_annotation(deleted_annotations_by_id[annotation_id])
            for annotation_id in span_annotation_ids
        ]
        info.context.event_queue.put(
            SpanAnnotationDeleteEvent(tuple(deleted_annotations_by_id.keys()))
        )
        return SpanAnnotationMutationPayload(
            span_annotations=deleted_annotations_gql, query=Query()
        )
|