arize-phoenix 4.5.0__py3-none-any.whl → 4.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.
Files changed (123)
  1. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/METADATA +16 -8
  2. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/RECORD +122 -58
  3. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +42 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datetime_utils.py +4 -0
  10. phoenix/db/bulk_inserter.py +54 -14
  11. phoenix/db/insertion/dataset.py +237 -0
  12. phoenix/db/insertion/evaluation.py +10 -10
  13. phoenix/db/insertion/helpers.py +17 -14
  14. phoenix/db/insertion/span.py +3 -3
  15. phoenix/db/migrations/types.py +29 -0
  16. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  17. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  18. phoenix/db/models.py +236 -4
  19. phoenix/experiments/__init__.py +6 -0
  20. phoenix/experiments/evaluators/__init__.py +29 -0
  21. phoenix/experiments/evaluators/base.py +153 -0
  22. phoenix/experiments/evaluators/code_evaluators.py +99 -0
  23. phoenix/experiments/evaluators/llm_evaluators.py +244 -0
  24. phoenix/experiments/evaluators/utils.py +186 -0
  25. phoenix/experiments/functions.py +757 -0
  26. phoenix/experiments/tracing.py +85 -0
  27. phoenix/experiments/types.py +753 -0
  28. phoenix/experiments/utils.py +24 -0
  29. phoenix/inferences/fixtures.py +23 -23
  30. phoenix/inferences/inferences.py +7 -7
  31. phoenix/inferences/validation.py +1 -1
  32. phoenix/server/api/context.py +20 -0
  33. phoenix/server/api/dataloaders/__init__.py +20 -0
  34. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  35. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  36. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  37. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  38. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  39. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  40. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  41. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  42. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  43. phoenix/server/api/dataloaders/span_projects.py +33 -0
  44. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  45. phoenix/server/api/helpers/dataset_helpers.py +179 -0
  46. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  47. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  48. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  49. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  50. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  51. phoenix/server/api/input_types/DatasetSort.py +17 -0
  52. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  53. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  54. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  55. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  56. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  57. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  58. phoenix/server/api/mutations/__init__.py +13 -0
  59. phoenix/server/api/mutations/auth.py +11 -0
  60. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  61. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  62. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  63. phoenix/server/api/mutations/project_mutations.py +47 -0
  64. phoenix/server/api/openapi/__init__.py +0 -0
  65. phoenix/server/api/openapi/main.py +6 -0
  66. phoenix/server/api/openapi/schema.py +16 -0
  67. phoenix/server/api/queries.py +503 -0
  68. phoenix/server/api/routers/v1/__init__.py +77 -2
  69. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  70. phoenix/server/api/routers/v1/datasets.py +965 -0
  71. phoenix/server/api/routers/v1/evaluations.py +8 -13
  72. phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
  73. phoenix/server/api/routers/v1/experiment_runs.py +220 -0
  74. phoenix/server/api/routers/v1/experiments.py +302 -0
  75. phoenix/server/api/routers/v1/spans.py +9 -5
  76. phoenix/server/api/routers/v1/traces.py +1 -4
  77. phoenix/server/api/schema.py +2 -303
  78. phoenix/server/api/types/AnnotatorKind.py +10 -0
  79. phoenix/server/api/types/Cluster.py +19 -19
  80. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  81. phoenix/server/api/types/Dataset.py +282 -63
  82. phoenix/server/api/types/DatasetExample.py +85 -0
  83. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  84. phoenix/server/api/types/DatasetVersion.py +14 -0
  85. phoenix/server/api/types/Dimension.py +30 -29
  86. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  87. phoenix/server/api/types/Event.py +16 -16
  88. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  89. phoenix/server/api/types/Experiment.py +147 -0
  90. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  91. phoenix/server/api/types/ExperimentComparison.py +19 -0
  92. phoenix/server/api/types/ExperimentRun.py +91 -0
  93. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  94. phoenix/server/api/types/Inferences.py +80 -0
  95. phoenix/server/api/types/InferencesRole.py +23 -0
  96. phoenix/server/api/types/Model.py +43 -42
  97. phoenix/server/api/types/Project.py +26 -12
  98. phoenix/server/api/types/Span.py +79 -2
  99. phoenix/server/api/types/TimeSeries.py +6 -6
  100. phoenix/server/api/types/Trace.py +15 -4
  101. phoenix/server/api/types/UMAPPoints.py +1 -1
  102. phoenix/server/api/types/node.py +5 -111
  103. phoenix/server/api/types/pagination.py +10 -52
  104. phoenix/server/app.py +103 -49
  105. phoenix/server/main.py +49 -27
  106. phoenix/server/openapi/docs.py +3 -0
  107. phoenix/server/static/index.js +2300 -1294
  108. phoenix/server/templates/index.html +1 -0
  109. phoenix/services.py +15 -15
  110. phoenix/session/client.py +581 -22
  111. phoenix/session/session.py +47 -37
  112. phoenix/trace/exporter.py +14 -9
  113. phoenix/trace/fixtures.py +133 -7
  114. phoenix/trace/schemas.py +1 -2
  115. phoenix/trace/span_evaluations.py +3 -3
  116. phoenix/trace/trace_dataset.py +6 -6
  117. phoenix/utilities/json.py +61 -0
  118. phoenix/utilities/re.py +50 -0
  119. phoenix/version.py +1 -1
  120. phoenix/server/api/types/DatasetRole.py +0 -23
  121. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/mutations/dataset_mutations.py
@@ -0,0 +1,520 @@
+from datetime import datetime
+from typing import Any, Dict
+
+import strawberry
+from openinference.semconv.trace import (
+    SpanAttributes,
+)
+from sqlalchemy import and_, delete, distinct, func, insert, select, update
+from strawberry import UNSET
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.helpers.dataset_helpers import (
+    get_dataset_example_input,
+    get_dataset_example_output,
+)
+from phoenix.server.api.input_types.AddExamplesToDatasetInput import AddExamplesToDatasetInput
+from phoenix.server.api.input_types.AddSpansToDatasetInput import AddSpansToDatasetInput
+from phoenix.server.api.input_types.CreateDatasetInput import CreateDatasetInput
+from phoenix.server.api.input_types.DeleteDatasetExamplesInput import DeleteDatasetExamplesInput
+from phoenix.server.api.input_types.DeleteDatasetInput import DeleteDatasetInput
+from phoenix.server.api.input_types.PatchDatasetExamplesInput import (
+    DatasetExamplePatch,
+    PatchDatasetExamplesInput,
+)
+from phoenix.server.api.input_types.PatchDatasetInput import PatchDatasetInput
+from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.Span import Span
+
+
+@strawberry.type
+class DatasetMutationPayload:
+    dataset: Dataset
+
+
+@strawberry.type
+class DatasetMutationMixin:
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def create_dataset(
+        self,
+        info: Info[Context, None],
+        input: CreateDatasetInput,
+    ) -> DatasetMutationPayload:
+        name = input.name
+        description = input.description if input.description is not UNSET else None
+        metadata = input.metadata
+        async with info.context.db() as session:
+            dataset = await session.scalar(
+                insert(models.Dataset)
+                .values(
+                    name=name,
+                    description=description,
+                    metadata_=metadata,
+                )
+                .returning(models.Dataset)
+            )
+            assert dataset is not None
+            return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def patch_dataset(
+        self,
+        info: Info[Context, None],
+        input: PatchDatasetInput,
+    ) -> DatasetMutationPayload:
+        dataset_id = from_global_id_with_expected_type(
+            global_id=input.dataset_id, expected_type_name=Dataset.__name__
+        )
+        patch = {
+            column.key: patch_value
+            for column, patch_value, column_is_nullable in (
+                (models.Dataset.name, input.name, False),
+                (models.Dataset.description, input.description, True),
+                (models.Dataset.metadata_, input.metadata, False),
+            )
+            if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
+        }
+        async with info.context.db() as session:
+            dataset = await session.scalar(
+                update(models.Dataset)
+                .where(models.Dataset.id == dataset_id)
+                .returning(models.Dataset)
+                .values(**patch)
+            )
+            assert dataset is not None
+            return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def add_spans_to_dataset(
+        self,
+        info: Info[Context, None],
+        input: AddSpansToDatasetInput,
+    ) -> DatasetMutationPayload:
+        dataset_id = input.dataset_id
+        span_ids = input.span_ids
+        dataset_version_description = (
+            input.dataset_version_description
+            if isinstance(input.dataset_version_description, str)
+            else None
+        )
+        dataset_version_metadata = input.dataset_version_metadata
+        dataset_rowid = from_global_id_with_expected_type(
+            global_id=dataset_id, expected_type_name=Dataset.__name__
+        )
+        span_rowids = {
+            from_global_id_with_expected_type(global_id=span_id, expected_type_name=Span.__name__)
+            for span_id in set(span_ids)
+        }
+        async with info.context.db() as session:
+            if (
+                dataset := await session.scalar(
+                    select(models.Dataset).where(models.Dataset.id == dataset_rowid)
+                )
+            ) is None:
+                raise ValueError(
+                    f"Unknown dataset: {dataset_id}"
+                )  # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
+            dataset_version_rowid = await session.scalar(
+                insert(models.DatasetVersion)
+                .values(
+                    dataset_id=dataset_rowid,
+                    description=dataset_version_description,
+                    metadata_=dataset_version_metadata,
+                )
+                .returning(models.DatasetVersion.id)
+            )
+            spans = (
+                await session.execute(
+                    select(
+                        models.Span.id,
+                        models.Span.span_kind,
+                        models.Span.attributes,
+                        _span_attribute(INPUT_MIME_TYPE),
+                        _span_attribute(INPUT_VALUE),
+                        _span_attribute(OUTPUT_MIME_TYPE),
+                        _span_attribute(OUTPUT_VALUE),
+                        _span_attribute(LLM_PROMPT_TEMPLATE_VARIABLES),
+                        _span_attribute(LLM_INPUT_MESSAGES),
+                        _span_attribute(LLM_OUTPUT_MESSAGES),
+                        _span_attribute(RETRIEVAL_DOCUMENTS),
+                    )
+                    .select_from(models.Span)
+                    .where(models.Span.id.in_(span_rowids))
+                )
+            ).all()
+            if missing_span_rowids := span_rowids - {span.id for span in spans}:
+                raise ValueError(
+                    f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
+                )  # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221
+            DatasetExample = models.DatasetExample
+            dataset_example_rowids = (
+                await session.scalars(
+                    insert(DatasetExample).returning(DatasetExample.id),
+                    [
+                        {
+                            DatasetExample.dataset_id.key: dataset_rowid,
+                            DatasetExample.span_rowid.key: span.id,
+                        }
+                        for span in spans
+                    ],
+                )
+            ).all()
+            assert len(dataset_example_rowids) == len(spans)
+            assert all(map(lambda id: isinstance(id, int), dataset_example_rowids))
+            DatasetExampleRevision = models.DatasetExampleRevision
+            await session.execute(
+                insert(DatasetExampleRevision),
+                [
+                    {
+                        DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
+                        DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
+                        DatasetExampleRevision.input.key: get_dataset_example_input(span),
+                        DatasetExampleRevision.output.key: get_dataset_example_output(span),
+                        DatasetExampleRevision.metadata_.key: span.attributes,
+                        DatasetExampleRevision.revision_kind.key: "CREATE",
+                    }
+                    for dataset_example_rowid, span in zip(dataset_example_rowids, spans)
+                ],
+            )
+            return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def add_examples_to_dataset(
+        self, info: Info[Context, None], input: AddExamplesToDatasetInput
+    ) -> DatasetMutationPayload:
+        dataset_id = input.dataset_id
+        # Extract the span rowids from the input examples if they exist
+        span_ids = [example.span_id for example in input.examples if example.span_id]
+        span_rowids = {
+            from_global_id_with_expected_type(global_id=span_id, expected_type_name=Span.__name__)
+            for span_id in set(span_ids)
+        }
+        dataset_version_description = (
+            input.dataset_version_description if input.dataset_version_description else None
+        )
+        dataset_version_metadata = input.dataset_version_metadata
+        dataset_rowid = from_global_id_with_expected_type(
+            global_id=dataset_id, expected_type_name=Dataset.__name__
+        )
+        async with info.context.db() as session:
+            if (
+                dataset := await session.scalar(
+                    select(models.Dataset).where(models.Dataset.id == dataset_rowid)
+                )
+            ) is None:
+                raise ValueError(
+                    f"Unknown dataset: {dataset_id}"
+                )  # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
+            dataset_version_rowid = await session.scalar(
+                insert(models.DatasetVersion)
+                .values(
+                    dataset_id=dataset_rowid,
+                    description=dataset_version_description,
+                    metadata_=dataset_version_metadata,
+                )
+                .returning(models.DatasetVersion.id)
+            )
+            spans = (
+                await session.execute(
+                    select(models.Span.id)
+                    .select_from(models.Span)
+                    .where(models.Span.id.in_(span_rowids))
+                )
+            ).all()
+            # Validate that the number of returned spans matches the number of
+            # span_ids, i.e., that all of the span_ids are valid
+            assert len(spans) == len(span_rowids)
+            DatasetExample = models.DatasetExample
+            dataset_example_rowids = (
+                await session.scalars(
+                    insert(DatasetExample).returning(DatasetExample.id),
+                    [
+                        {
+                            DatasetExample.dataset_id.key: dataset_rowid,
+                            DatasetExample.span_rowid.key: from_global_id_with_expected_type(
+                                global_id=example.span_id,
+                                expected_type_name=Span.__name__,
+                            )
+                            if example.span_id
+                            else None,
+                        }
+                        for example in input.examples
+                    ],
+                )
+            ).all()
+            assert len(dataset_example_rowids) == len(input.examples)
+            assert all(map(lambda id: isinstance(id, int), dataset_example_rowids))
+            DatasetExampleRevision = models.DatasetExampleRevision
+            await session.execute(
+                insert(DatasetExampleRevision),
+                [
+                    {
+                        DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
+                        DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
+                        DatasetExampleRevision.input.key: example.input,
+                        DatasetExampleRevision.output.key: example.output,
+                        DatasetExampleRevision.metadata_.key: example.metadata,
+                        DatasetExampleRevision.revision_kind.key: "CREATE",
+                    }
+                    for dataset_example_rowid, example in zip(
+                        dataset_example_rowids, input.examples
+                    )
+                ],
+            )
+            return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def delete_dataset(
+        self,
+        info: Info[Context, None],
+        input: DeleteDatasetInput,
+    ) -> DatasetMutationPayload:
+        dataset_id = input.dataset_id
+        dataset_rowid = from_global_id_with_expected_type(
+            global_id=dataset_id, expected_type_name=Dataset.__name__
+        )
+
+        async with info.context.db() as session:
+            delete_result = await session.execute(
+                delete(models.Dataset)
+                .where(models.Dataset.id == dataset_rowid)
+                .returning(models.Dataset)
+            )
+            if not (datasets := delete_result.first()):
+                raise ValueError(f"Unknown dataset: {dataset_id}")
+
+            dataset = datasets[0]
+            return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def patch_dataset_examples(
+        self,
+        info: Info[Context, None],
+        input: PatchDatasetExamplesInput,
+    ) -> DatasetMutationPayload:
+        if not (patches := input.patches):
+            raise ValueError("Must provide examples to patch.")
+        by_numeric_id = [
+            (
+                from_global_id_with_expected_type(patch.example_id, DatasetExample.__name__),
+                index,
+                patch,
+            )
+            for index, patch in enumerate(patches)
+        ]
+        example_ids, _, patches = map(list, zip(*sorted(by_numeric_id)))
+        if len(set(example_ids)) < len(example_ids):
+            raise ValueError("Cannot patch the same example more than once per mutation.")
+        if any(patch.is_empty() for patch in patches):
+            raise ValueError("Received one or more empty patches that contain no fields to update.")
+        version_description = input.version_description or None
+        version_metadata = input.version_metadata
+        async with info.context.db() as session:
+            datasets = (
+                await session.scalars(
+                    select(models.Dataset)
+                    .where(
+                        models.Dataset.id.in_(
+                            select(distinct(models.DatasetExample.dataset_id))
+                            .where(models.DatasetExample.id.in_(example_ids))
+                            .scalar_subquery()
+                        )
+                    )
+                    .limit(2)
+                )
+            ).all()
+            if not datasets:
+                raise ValueError("No examples found.")
+            if len(set(ds.id for ds in datasets)) > 1:
+                raise ValueError("Examples must come from the same dataset.")
+            dataset = datasets[0]
+
+            revision_ids = (
+                select(func.max(models.DatasetExampleRevision.id))
+                .where(models.DatasetExampleRevision.dataset_example_id.in_(example_ids))
+                .group_by(models.DatasetExampleRevision.dataset_example_id)
+                .scalar_subquery()
+            )
+            revisions = (
+                await session.scalars(
+                    select(models.DatasetExampleRevision)
+                    .where(
+                        and_(
+                            models.DatasetExampleRevision.id.in_(revision_ids),
+                            models.DatasetExampleRevision.revision_kind != "DELETE",
+                        )
+                    )
+                    .order_by(
+                        models.DatasetExampleRevision.dataset_example_id
+                    )  # ensure the order of the revisions matches the order of the input patches
+                )
+            ).all()
+            if (num_missing_examples := len(example_ids) - len(revisions)) > 0:
+                raise ValueError(f"{num_missing_examples} example(s) could not be found.")
+
+            version_id = await session.scalar(
+                insert(models.DatasetVersion)
+                .returning(models.DatasetVersion.id)
+                .values(
+                    dataset_id=dataset.id,
+                    description=version_description,
+                    metadata_=version_metadata,
+                )
+            )
+            assert version_id is not None
+
+            await session.execute(
+                insert(models.DatasetExampleRevision),
+                [
+                    _to_orm_revision(
+                        existing_revision=revision,
+                        patch=patch,
+                        example_id=example_id,
+                        version_id=version_id,
+                    )
+                    for revision, patch, example_id in zip(revisions, patches, example_ids)
+                ],
+            )
+
+            return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def delete_dataset_examples(
+        self, info: Info[Context, None], input: DeleteDatasetExamplesInput
+    ) -> DatasetMutationPayload:
+        timestamp = datetime.now()
+        example_db_ids = [
+            from_global_id_with_expected_type(global_id, models.DatasetExample.__name__)
+            for global_id in input.example_ids
+        ]
+        # Guard against empty input
+        if not example_db_ids:
+            raise ValueError("Must provide examples to delete")
+        dataset_version_description = (
+            input.dataset_version_description
+            if isinstance(input.dataset_version_description, str)
+            else None
+        )
+        dataset_version_metadata = input.dataset_version_metadata
+        async with info.context.db() as session:
+            # Check that the examples come from a single dataset
+            datasets = (
+                await session.scalars(
+                    select(models.Dataset)
+                    .join(
+                        models.DatasetExample, models.Dataset.id == models.DatasetExample.dataset_id
+                    )
+                    .where(models.DatasetExample.id.in_(example_db_ids))
+                    .distinct()
+                    .limit(2)  # limit to 2 to check if there is more than 1 dataset
+                )
+            ).all()
+            if len(datasets) > 1:
+                raise ValueError("Examples must be from the same dataset")
+            elif not datasets:
+                raise ValueError("Examples not found")
+
+            dataset = datasets[0]
+
+            dataset_version_rowid = await session.scalar(
+                insert(models.DatasetVersion)
+                .values(
+                    dataset_id=dataset.id,
+                    description=dataset_version_description,
+                    metadata_=dataset_version_metadata,
+                    created_at=timestamp,
+                )
+                .returning(models.DatasetVersion.id)
+            )
+
+            # Abort the deletion if any of the examples already has a DELETE revision
+            existing_delete_revisions = (
+                await session.scalars(
+                    select(models.DatasetExampleRevision).where(
+                        models.DatasetExampleRevision.dataset_example_id.in_(example_db_ids),
+                        models.DatasetExampleRevision.revision_kind == "DELETE",
+                    )
+                )
+            ).all()
+
+            if existing_delete_revisions:
+                raise ValueError(
+                    "Provided examples contain already deleted examples. Delete aborted."
+                )
+
+            DatasetExampleRevision = models.DatasetExampleRevision
+            await session.execute(
+                insert(DatasetExampleRevision),
+                [
+                    {
+                        DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
+                        DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
+                        DatasetExampleRevision.input.key: {},
+                        DatasetExampleRevision.output.key: {},
+                        DatasetExampleRevision.metadata_.key: {},
+                        DatasetExampleRevision.revision_kind.key: "DELETE",
+                        DatasetExampleRevision.created_at.key: timestamp,
+                    }
+                    for dataset_example_rowid in example_db_ids
+                ],
+            )
+
+            return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+
+def _span_attribute(semconv: str) -> Any:
+    """
+    Extracts an attribute from the ORM span attributes column and labels the
+    result.
+
+    E.g., "input.value" -> Span.attributes["input"]["value"].label("input_value")
+    """
+    attribute_value: Any = models.Span.attributes
+    for key in semconv.split("."):
+        attribute_value = attribute_value[key]
+    return attribute_value.label(semconv.replace(".", "_"))
+
+
+def _to_orm_revision(
+    *,
+    existing_revision: models.DatasetExampleRevision,
+    patch: DatasetExamplePatch,
+    example_id: int,
+    version_id: int,
+) -> Dict[str, Any]:
+    """
+    Creates a new revision from an existing revision and a patch. The output is a
+    dictionary suitable for insertion into the database using the sqlalchemy
+    bulk insertion API.
+    """
+
+    db_rev = models.DatasetExampleRevision
+    input = patch.input if isinstance(patch.input, dict) else existing_revision.input
+    output = patch.output if isinstance(patch.output, dict) else existing_revision.output
+    metadata = patch.metadata if isinstance(patch.metadata, dict) else existing_revision.metadata_
+    return {
+        str(db_column.key): patch_value
+        for db_column, patch_value in (
+            (db_rev.dataset_example_id, example_id),
+            (db_rev.dataset_version_id, version_id),
+            (db_rev.input, input),
+            (db_rev.output, output),
+            (db_rev.metadata_, metadata),
+            (db_rev.revision_kind, "PATCH"),
+        )
+    }
+
+
+INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
+INPUT_VALUE = SpanAttributes.INPUT_VALUE
+OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
+OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
+LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
+LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
+LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
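The DatasetMutationMixin above is composed into the server's GraphQL schema (phoenix/server/api/mutations/__init__.py and schema.py are both touched in this release), so the new createDataset mutation can be exercised over HTTP. Below is a minimal sketch of such a call; the local address (Phoenix's usual port 6006), the /graphql endpoint path, and the camelCased field names are assumptions based on strawberry's conventions, not something this diff shows.

# Sketch: calling the new createDataset mutation on a local Phoenix server.
# Assumptions: server at localhost:6006, GraphQL served at /graphql, and
# strawberry's default snake_case -> camelCase field name conversion.
import requests

MUTATION = """
mutation ($input: CreateDatasetInput!) {
  createDataset(input: $input) {
    dataset {
      id
      name
    }
  }
}
"""

response = requests.post(
    "http://localhost:6006/graphql",
    json={
        "query": MUTATION,
        "variables": {
            "input": {
                "name": "my-dataset",
                "description": "created via the new mutation",
                "metadata": {},
            }
        },
    },
    timeout=10,
)
response.raise_for_status()
print(response.json()["data"]["createDataset"]["dataset"])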
phoenix/server/api/mutations/experiment_mutations.py
@@ -0,0 +1,65 @@
+from typing import List
+
+import strawberry
+from sqlalchemy import delete
+from strawberry.relay import GlobalID
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.DeleteExperimentsInput import DeleteExperimentsInput
+from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+@strawberry.type
+class ExperimentMutationPayload:
+    experiments: List[Experiment]
+
+
+@strawberry.type
+class ExperimentMutationMixin:
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def delete_experiments(
+        self,
+        info: Info[Context, None],
+        input: DeleteExperimentsInput,
+    ) -> ExperimentMutationPayload:
+        experiment_ids = [
+            from_global_id_with_expected_type(experiment_id, Experiment.__name__)
+            for experiment_id in input.experiment_ids
+        ]
+        async with info.context.db() as session:
+            savepoint = await session.begin_nested()
+            experiments = {
+                experiment.id: experiment
+                async for experiment in (
+                    await session.stream_scalars(
+                        delete(models.Experiment)
+                        .where(models.Experiment.id.in_(experiment_ids))
+                        .returning(models.Experiment)
+                    )
+                )
+            }
+            if unknown_experiment_ids := set(experiment_ids) - set(experiments.keys()):
+                await savepoint.rollback()
+                raise ValueError(
+                    "Failed to delete experiment(s), "
+                    "probably due to invalid input experiment ID(s): "
+                    + str(
+                        [
+                            str(GlobalID(Experiment.__name__, str(experiment_id)))
+                            for experiment_id in unknown_experiment_ids
+                        ]
+                    )
+                )
+            if project_names := set(filter(bool, (e.project_name for e in experiments.values()))):
+                await session.execute(
+                    delete(models.Project).where(models.Project.name.in_(project_names))
+                )
+        return ExperimentMutationPayload(
+            experiments=[
+                to_gql_experiment(experiments[experiment_id]) for experiment_id in experiment_ids
+            ]
+        )
phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py}
@@ -10,14 +10,16 @@ from strawberry.types import Info
 import phoenix.core.model_schema as ms
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.ClusterInput import ClusterInput
-from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
-from phoenix.server.api.types.Event import parse_event_ids_by_dataset_role, unpack_event_id
+from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.types.Event import parse_event_ids_by_inferences_role, unpack_event_id
 from phoenix.server.api.types.ExportedFile import ExportedFile
+from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole


 @strawberry.type
-class ExportEventsMutation:
+class ExportEventsMutationMixin:
     @strawberry.mutation(
+        permission_classes=[IsAuthenticated],
         description=(
             "Given a list of event ids, export the corresponding data subset in Parquet format."
             " File name is optional, but if specified, should be without file extension. By default"
@@ -32,11 +34,11 @@ class ExportEventsMutation:
     ) -> ExportedFile:
         if not isinstance(file_name, str):
             file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-        row_ids = parse_event_ids_by_dataset_role(event_ids)
+        row_ids = parse_event_ids_by_inferences_role(event_ids)
         exclude_corpus_row_ids = {}
-        for dataset_role in list(row_ids.keys()):
-            if isinstance(dataset_role, DatasetRole):
-                exclude_corpus_row_ids[dataset_role.value] = row_ids[dataset_role]
+        for inferences_role in list(row_ids.keys()):
+            if isinstance(inferences_role, InferencesRole):
+                exclude_corpus_row_ids[inferences_role.value] = row_ids[inferences_role]
         path = info.context.export_path
         with open(path / (file_name + ".parquet"), "wb") as fd:
             loop = asyncio.get_running_loop()
@@ -49,6 +51,7 @@ class ExportEventsMutation:
         return ExportedFile(file_name=file_name)

     @strawberry.mutation(
+        permission_classes=[IsAuthenticated],
         description=(
             "Given a list of clusters, export the corresponding data subset in Parquet format."
             " File name is optional, but if specified, should be without file extension. By default"
@@ -79,13 +82,13 @@ class ExportEventsMutation:

 def _unpack_clusters(
     clusters: List[ClusterInput],
-) -> Tuple[Dict[ms.DatasetRole, List[int]], Dict[ms.DatasetRole, Dict[int, str]]]:
-    row_numbers: Dict[ms.DatasetRole, List[int]] = defaultdict(list)
-    cluster_ids: Dict[ms.DatasetRole, Dict[int, str]] = defaultdict(dict)
+) -> Tuple[Dict[ms.InferencesRole, List[int]], Dict[ms.InferencesRole, Dict[int, str]]]:
+    row_numbers: Dict[ms.InferencesRole, List[int]] = defaultdict(list)
+    cluster_ids: Dict[ms.InferencesRole, Dict[int, str]] = defaultdict(dict)
     for i, cluster in enumerate(clusters):
-        for row_number, dataset_role in map(unpack_event_id, cluster.event_ids):
-            if isinstance(dataset_role, AncillaryDatasetRole):
+        for row_number, inferences_role in map(unpack_event_id, cluster.event_ids):
+            if isinstance(inferences_role, AncillaryInferencesRole):
                 continue
-            row_numbers[dataset_role.value].append(row_number)
-            cluster_ids[dataset_role.value][row_number] = cluster.id or str(i)
+            row_numbers[inferences_role.value].append(row_number)
+            cluster_ids[inferences_role.value][row_number] = cluster.id or str(i)
     return row_numbers, cluster_ids
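The _unpack_clusters helper above groups event rows per inferences role while remembering which cluster each row belongs to. The toy sketch below isolates that defaultdict pattern; the "row:role" string encoding and the literal role names are made up for illustration (the real unpack_event_id lives in phoenix.server.api.types.Event and yields InferencesRole members, not strings).

# Sketch: the grouping pattern used by _unpack_clusters, with toy inputs.
from collections import defaultdict
from typing import Dict, List, Tuple


def unpack(event_id: str) -> Tuple[int, str]:
    # assumed "row:role" encoding, purely for this sketch
    row, role = event_id.split(":")
    return int(row), role


row_numbers: Dict[str, List[int]] = defaultdict(list)
cluster_ids: Dict[str, Dict[int, str]] = defaultdict(dict)
clusters = {"c0": ["3:primary", "7:reference"], "c1": ["9:primary"]}

for cluster_id, event_ids in clusters.items():
    for row_number, role in map(unpack, event_ids):
        # rows are grouped per role, and each row remembers its cluster
        row_numbers[role].append(row_number)
        cluster_ids[role][row_number] = cluster_id

print(dict(row_numbers))  # {'primary': [3, 9], 'reference': [7]}
print({role: dict(rows) for role, rows in cluster_ids.items()})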