arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (338)
  1. arize_phoenix-7.7.1.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.1.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -241
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +4 -112
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.1.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.1.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -619
  295. phoenix/core/traces.py +0 -96
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
@@ -0,0 +1,591 @@
1
+ import asyncio
2
+ from datetime import datetime
3
+ from typing import Any
4
+
5
+ import strawberry
6
+ from openinference.semconv.trace import (
7
+ SpanAttributes,
8
+ )
9
+ from sqlalchemy import and_, delete, distinct, func, insert, select, update
10
+ from strawberry import UNSET
11
+ from strawberry.types import Info
12
+
13
+ from phoenix.db import models
14
+ from phoenix.db.helpers import get_eval_trace_ids_for_datasets, get_project_names_for_datasets
15
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
16
+ from phoenix.server.api.context import Context
17
+ from phoenix.server.api.exceptions import BadRequest, NotFound
18
+ from phoenix.server.api.helpers.dataset_helpers import (
19
+ get_dataset_example_input,
20
+ get_dataset_example_output,
21
+ )
22
+ from phoenix.server.api.input_types.AddExamplesToDatasetInput import AddExamplesToDatasetInput
23
+ from phoenix.server.api.input_types.AddSpansToDatasetInput import AddSpansToDatasetInput
24
+ from phoenix.server.api.input_types.CreateDatasetInput import CreateDatasetInput
25
+ from phoenix.server.api.input_types.DeleteDatasetExamplesInput import DeleteDatasetExamplesInput
26
+ from phoenix.server.api.input_types.DeleteDatasetInput import DeleteDatasetInput
27
+ from phoenix.server.api.input_types.PatchDatasetExamplesInput import (
28
+ DatasetExamplePatch,
29
+ PatchDatasetExamplesInput,
30
+ )
31
+ from phoenix.server.api.input_types.PatchDatasetInput import PatchDatasetInput
32
+ from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
33
+ from phoenix.server.api.types.DatasetExample import DatasetExample
34
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
35
+ from phoenix.server.api.types.Span import Span
36
+ from phoenix.server.api.utils import delete_projects, delete_traces
37
+ from phoenix.server.dml_event import DatasetDeleteEvent, DatasetInsertEvent
38
+
39
+
40
# Result type shared by the dataset mutations: it carries the dataset that was
# created, modified, or deleted by the mutation.
@strawberry.type
class DatasetMutationPayload:
    # The dataset affected by the mutation, as a GraphQL Dataset object.
    dataset: Dataset
43
+
44
+
45
+ @strawberry.type
46
+ class DatasetMutationMixin:
47
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
async def create_dataset(
    self,
    info: Info[Context, None],
    input: CreateDatasetInput,
) -> DatasetMutationPayload:
    # Insert a new dataset row and return it wrapped in a DatasetMutationPayload.
    # An UNSET description is stored as NULL; name and metadata are written as
    # given. A DatasetInsertEvent for the new row is queued after the session
    # block has exited.
    if input.description is UNSET:
        description = None
    else:
        description = input.description
    statement = (
        insert(models.Dataset)
        .values(name=input.name, description=description, metadata_=input.metadata)
        .returning(models.Dataset)
    )
    async with info.context.db() as session:
        created = await session.scalar(statement)
        # INSERT ... RETURNING always yields the inserted row.
        assert created is not None
    info.context.event_queue.put(DatasetInsertEvent((created.id,)))
    return DatasetMutationPayload(dataset=to_gql_dataset(created))
69
+
70
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
async def patch_dataset(
    self,
    info: Info[Context, None],
    input: PatchDatasetInput,
) -> DatasetMutationPayload:
    # Partially update a dataset's name, description and/or metadata.
    #
    # A field left UNSET is not touched. An explicit None is written only to the
    # nullable description column; None for name or metadata is ignored.
    # Raises NotFound if no dataset has the given ID.
    dataset_id = from_global_id_with_expected_type(
        global_id=input.dataset_id, expected_type_name=Dataset.__name__
    )
    patch = {
        column.key: patch_value
        for column, patch_value, column_is_nullable in (
            (models.Dataset.name, input.name, False),
            (models.Dataset.description, input.description, True),
            (models.Dataset.metadata_, input.metadata, False),
        )
        if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
    }
    async with info.context.db() as session:
        if patch:
            dataset = await session.scalar(
                update(models.Dataset)
                .where(models.Dataset.id == dataset_id)
                .values(**patch)
                .returning(models.Dataset)
            )
        else:
            # Nothing to SET: skip the UPDATE entirely and just load the current row.
            dataset = await session.scalar(
                select(models.Dataset).where(models.Dataset.id == dataset_id)
            )
        if dataset is None:
            # Previously an assert; a missing row is a client error, not an invariant.
            raise NotFound(f"Unknown dataset: {input.dataset_id}")
    info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
    return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
98
+
99
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
async def add_spans_to_dataset(
    self,
    info: Info[Context, None],
    input: AddSpansToDatasetInput,
) -> DatasetMutationPayload:
    # Add one dataset example per (deduplicated) span, recorded under a newly
    # created dataset version.
    #
    # Each example's CREATE revision takes its input/output from the span via
    # get_dataset_example_input / get_dataset_example_output. Its metadata holds
    # the span kind and, when the span has any, its annotations keyed by name.
    #
    # Raises NotFound if the dataset or any of the spans does not exist.
    dataset_id = input.dataset_id
    dataset_version_description = (
        input.dataset_version_description
        if isinstance(input.dataset_version_description, str)
        else None
    )
    dataset_version_metadata = input.dataset_version_metadata
    dataset_rowid = from_global_id_with_expected_type(
        global_id=dataset_id, expected_type_name=Dataset.__name__
    )
    span_rowids = {
        from_global_id_with_expected_type(global_id=span_id, expected_type_name=Span.__name__)
        for span_id in set(input.span_ids)
    }
    async with info.context.db() as session:
        dataset = await session.scalar(
            select(models.Dataset).where(models.Dataset.id == dataset_rowid)
        )
        if dataset is None:
            raise NotFound(f"Unknown dataset: {dataset_id}")
        dataset_version_rowid = await session.scalar(
            insert(models.DatasetVersion)
            .values(
                dataset_id=dataset_rowid,
                description=dataset_version_description,
                metadata_=dataset_version_metadata,
            )
            .returning(models.DatasetVersion.id)
        )
        spans = (
            await session.scalars(select(models.Span).where(models.Span.id.in_(span_rowids)))
        ).all()
        if missing_span_rowids := span_rowids - {span.id for span in spans}:
            raise NotFound(
                f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
            )

        span_annotations = (
            await session.scalars(
                select(models.SpanAnnotation).where(
                    models.SpanAnnotation.span_rowid.in_(span_rowids)
                )
            )
        ).all()
        span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
        for annotation in span_annotations:
            span_annotations_by_span.setdefault(annotation.span_rowid, {})[annotation.name] = {
                "label": annotation.label,
                "score": annotation.score,
                "explanation": annotation.explanation,
                "metadata": annotation.metadata_,
                "annotator_kind": annotation.annotator_kind,
            }

        # RETURNING rows from a multi-row INSERT are not guaranteed to arrive in
        # parameter order. Return span_rowid with each new id and correlate on it
        # instead of zipping by position. Span rowids are unique here because the
        # spans were selected from a set of rowids.
        inserted_examples = await session.execute(
            insert(models.DatasetExample).returning(
                models.DatasetExample.id, models.DatasetExample.span_rowid
            ),
            [
                {
                    models.DatasetExample.dataset_id.key: dataset_rowid,
                    models.DatasetExample.span_rowid.key: span.id,
                }
                for span in spans
            ],
        )
        example_rowid_by_span_rowid = {
            span_rowid: example_rowid for example_rowid, span_rowid in inserted_examples
        }
        assert len(example_rowid_by_span_rowid) == len(spans)
        assert all(isinstance(rowid, int) for rowid in example_rowid_by_span_rowid.values())

        Revision = models.DatasetExampleRevision
        await session.execute(
            insert(Revision),
            [
                {
                    Revision.dataset_example_id.key: example_rowid_by_span_rowid[span.id],
                    Revision.dataset_version_id.key: dataset_version_rowid,
                    Revision.input.key: get_dataset_example_input(span),
                    Revision.output.key: get_dataset_example_output(span),
                    Revision.metadata_.key: {
                        "span_kind": span.span_kind,
                        **(
                            {"annotations": annotations}
                            if (annotations := span_annotations_by_span[span.id])
                            else {}
                        ),
                    },
                    Revision.revision_kind.key: "CREATE",
                }
                for span in spans
            ],
        )
    info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
    return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
207
+
208
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
async def add_examples_to_dataset(
    self, info: Info[Context, None], input: AddExamplesToDatasetInput
) -> DatasetMutationPayload:
    # Add caller-supplied examples to a dataset under a newly created version.
    #
    # Each example's CREATE revision stores the given input/output. Its metadata
    # is the example's own metadata plus an "annotations" entry holding the
    # annotations (keyed by name) of the example's span. The entry is empty when
    # the example has no span or the span has no annotations.
    #
    # Raises NotFound if the dataset or any referenced span does not exist.
    dataset_id = input.dataset_id
    # Decode each example's span global ID exactly once; None when it has no span.
    example_span_rowids = [
        from_global_id_with_expected_type(
            global_id=example.span_id,
            expected_type_name=Span.__name__,
        )
        if example.span_id
        else None
        for example in input.examples
    ]
    span_rowids = {rowid for rowid in example_span_rowids if rowid is not None}
    dataset_version_description = (
        input.dataset_version_description if input.dataset_version_description else None
    )
    dataset_version_metadata = input.dataset_version_metadata
    dataset_rowid = from_global_id_with_expected_type(
        global_id=dataset_id, expected_type_name=Dataset.__name__
    )
    async with info.context.db() as session:
        dataset = await session.scalar(
            select(models.Dataset).where(models.Dataset.id == dataset_rowid)
        )
        if dataset is None:
            raise NotFound(f"Unknown dataset: {dataset_id}")

        # Reject unknown spans up front, as add_spans_to_dataset does.
        found_span_rowids = set(
            (
                await session.scalars(
                    select(models.Span.id).where(models.Span.id.in_(span_rowids))
                )
            ).all()
        )
        if missing_span_rowids := span_rowids - found_span_rowids:
            raise NotFound(
                f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
            )

        dataset_version_rowid = await session.scalar(
            insert(models.DatasetVersion)
            .values(
                dataset_id=dataset_rowid,
                description=dataset_version_description,
                metadata_=dataset_version_metadata,
            )
            .returning(models.DatasetVersion.id)
        )

        span_annotations = (
            await session.execute(
                select(
                    models.SpanAnnotation.span_rowid,
                    models.SpanAnnotation.name,
                    models.SpanAnnotation.label,
                    models.SpanAnnotation.score,
                    models.SpanAnnotation.explanation,
                    models.SpanAnnotation.metadata_,
                    models.SpanAnnotation.annotator_kind,
                )
                .select_from(models.SpanAnnotation)
                .where(models.SpanAnnotation.span_rowid.in_(span_rowids))
            )
        ).all()
        annotations_by_span_rowid: dict[int, dict[Any, Any]] = {}
        for annotation in span_annotations:
            annotations_by_span_rowid.setdefault(annotation.span_rowid, {})[annotation.name] = {
                "label": annotation.label,
                "score": annotation.score,
                "explanation": annotation.explanation,
                "metadata": annotation.metadata_,
                "annotator_kind": annotation.annotator_kind,
            }

        dataset_example_rowids = (
            await session.scalars(
                insert(models.DatasetExample).returning(models.DatasetExample.id),
                [
                    {
                        models.DatasetExample.dataset_id.key: dataset_rowid,
                        models.DatasetExample.span_rowid.key: span_rowid,
                    }
                    for span_rowid in example_span_rowids
                ],
            )
        ).all()
        # NOTE(review): new example ids are matched to input.examples by position.
        # SQLAlchemy does not guarantee that RETURNING rows from a multi-row INSERT
        # arrive in parameter order unless returning(..., sort_by_parameter_order=True)
        # is used (SQLAlchemy >= 2.0.10). Confirm the supported SQLAlchemy range and
        # switch to it.
        assert len(dataset_example_rowids) == len(input.examples)
        assert all(isinstance(rowid, int) for rowid in dataset_example_rowids)

        Revision = models.DatasetExampleRevision
        await session.execute(
            insert(Revision),
            [
                {
                    Revision.dataset_example_id.key: example_rowid,
                    Revision.dataset_version_id.key: dataset_version_rowid,
                    Revision.input.key: example.input,
                    Revision.output.key: example.output,
                    Revision.metadata_.key: {
                        **(example.metadata or {}),
                        "annotations": annotations_by_span_rowid.get(span_rowid, {}),
                    },
                    Revision.revision_kind.key: "CREATE",
                }
                for example_rowid, example, span_rowid in zip(
                    dataset_example_rowids, input.examples, example_span_rowids
                )
            ],
        )
    info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
    return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
333
+
334
@strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
async def delete_dataset(
    self,
    info: Info[Context, None],
    input: DeleteDatasetInput,
) -> DatasetMutationPayload:
    """
    Delete a dataset, then clean up the projects and evaluation traces tied to it.

    The project names and evaluation trace IDs are read in the same session that
    deletes the dataset row, before the delete statement runs. After the session
    block exits, the projects and traces are deleted concurrently. Failures in
    that cleanup are collected by ``asyncio.gather(..., return_exceptions=True)``
    and are not re-raised, so they do not fail the mutation.

    Raises:
        NotFound: if ``input.dataset_id`` cannot be decoded as a dataset ID, or
            no dataset with that ID exists.
    """
    try:
        dataset_rowid = from_global_id_with_expected_type(
            global_id=input.dataset_id,
            expected_type_name=Dataset.__name__,
        )
    except ValueError:
        raise NotFound(f"Unknown dataset: {input.dataset_id}")

    # Statements for the dependents are built up front and executed before the
    # delete, so they still see the dataset's rows.
    lookup_project_names = get_project_names_for_datasets(dataset_rowid)
    lookup_eval_trace_ids = get_eval_trace_ids_for_datasets(dataset_rowid)
    delete_dataset_row = (
        delete(models.Dataset)
        .where(models.Dataset.id == dataset_rowid)
        .returning(models.Dataset)
    )

    async with info.context.db() as session:
        project_names = await session.scalars(lookup_project_names)
        eval_trace_ids = await session.scalars(lookup_eval_trace_ids)
        dataset = await session.scalar(delete_dataset_row)
        if not dataset:
            raise NotFound(f"Unknown dataset: {input.dataset_id}")

    # Best-effort cleanup: exceptions are returned by gather rather than raised.
    await asyncio.gather(
        delete_projects(info.context.db, *project_names),
        delete_traces(info.context.db, *eval_trace_ids),
        return_exceptions=True,
    )
    info.context.event_queue.put(DatasetDeleteEvent((dataset.id,)))
    return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
364
+
365
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
async def patch_dataset_examples(
    self,
    info: Info[Context, None],
    input: PatchDatasetExamplesInput,
) -> DatasetMutationPayload:
    """
    Apply patches to existing dataset examples, recorded as a new dataset version.

    Patches are sorted by numeric example ID (ties broken by input position) so
    that they line up one-to-one with the revisions fetched below, which are
    ordered by ``dataset_example_id``. Each example's revision with the highest
    ID is used as the base. One PATCH revision per example is written under a
    single newly inserted ``DatasetVersion``.

    Raises:
        BadRequest: if no patches are given, an example is patched more than
            once, any patch is empty, or the examples belong to more than one
            dataset.
        NotFound: if no dataset is found for the examples, or some examples have
            no usable (non-DELETE) latest revision.
    """
    requested_patches = input.patches
    if not requested_patches:
        raise BadRequest("Must provide examples to patch.")

    keyed_patches = [
        (
            from_global_id_with_expected_type(requested.example_id, DatasetExample.__name__),
            position,
            requested,
        )
        for position, requested in enumerate(requested_patches)
    ]
    # Sort by (example ID, input position); positions are unique, so the patch
    # objects themselves are never compared.
    keyed_patches.sort(key=lambda entry: (entry[0], entry[1]))
    example_ids = [entry[0] for entry in keyed_patches]
    patches = [entry[2] for entry in keyed_patches]

    if len(example_ids) > len(set(example_ids)):
        raise BadRequest("Cannot patch the same example more than once per mutation.")
    if any(candidate.is_empty() for candidate in patches):
        raise BadRequest("Received one or more empty patches that contain no fields to update.")

    version_description = input.version_description or None
    version_metadata = input.version_metadata

    async with info.context.db() as session:
        owning_dataset_ids = (
            select(distinct(models.DatasetExample.dataset_id))
            .where(models.DatasetExample.id.in_(example_ids))
            .scalar_subquery()
        )
        # Two rows are enough to tell "one dataset" from "several".
        dataset_query = (
            select(models.Dataset).where(models.Dataset.id.in_(owning_dataset_ids)).limit(2)
        )
        datasets = (await session.scalars(dataset_query)).all()
        if not datasets:
            raise NotFound("No examples found.")
        if len({candidate.id for candidate in datasets}) > 1:
            raise BadRequest("Examples must come from the same dataset.")
        dataset = datasets[0]

        latest_revision_ids = (
            select(func.max(models.DatasetExampleRevision.id))
            .where(models.DatasetExampleRevision.dataset_example_id.in_(example_ids))
            .group_by(models.DatasetExampleRevision.dataset_example_id)
            .scalar_subquery()
        )
        revision_query = (
            select(models.DatasetExampleRevision)
            .where(
                and_(
                    models.DatasetExampleRevision.id.in_(latest_revision_ids),
                    models.DatasetExampleRevision.revision_kind != "DELETE",
                )
            )
            # Same ordering as example_ids / patches, so the zip below pairs correctly.
            .order_by(models.DatasetExampleRevision.dataset_example_id)
        )
        revisions = (await session.scalars(revision_query)).all()
        num_missing_examples = len(example_ids) - len(revisions)
        if num_missing_examples > 0:
            raise NotFound(f"{num_missing_examples} example(s) could not be found.")

        insert_version = (
            insert(models.DatasetVersion)
            .returning(models.DatasetVersion.id)
            .values(
                dataset_id=dataset.id,
                description=version_description,
                metadata_=version_metadata,
            )
        )
        version_id = await session.scalar(insert_version)
        assert version_id is not None

        new_revision_rows = [
            _to_orm_revision(
                existing_revision=base_revision,
                patch=example_patch,
                example_id=patched_example_id,
                version_id=version_id,
            )
            for base_revision, example_patch, patched_example_id in zip(
                revisions, patches, example_ids
            )
        ]
        await session.execute(insert(models.DatasetExampleRevision), new_revision_rows)

    info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
    return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
456
+
457
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
async def delete_dataset_examples(
    self, info: Info[Context, None], input: DeleteDatasetExamplesInput
) -> DatasetMutationPayload:
    """
    Mark dataset examples as deleted by writing one DELETE revision per example.

    All DELETE revisions are recorded under one newly inserted ``DatasetVersion``.
    The version and its revisions share a single ``created_at`` timestamp.
    Every check runs before anything is written.

    Raises:
        ValueError: if no examples are given, the examples belong to more than
            one dataset, any example does not exist, or any example already has
            a DELETE revision.
    """
    timestamp = datetime.now()
    # Deduplicate while keeping first-seen order. Otherwise an example listed
    # twice would get two DELETE revisions in the same version.
    example_db_ids = list(
        dict.fromkeys(
            from_global_id_with_expected_type(global_id, models.DatasetExample.__name__)
            for global_id in input.example_ids
        )
    )
    # Guard against empty input
    if not example_db_ids:
        raise ValueError("Must provide examples to delete")
    dataset_version_description = (
        input.dataset_version_description
        if isinstance(input.dataset_version_description, str)
        else None
    )
    dataset_version_metadata = input.dataset_version_metadata
    async with info.context.db() as session:
        # Check if the examples are from a single dataset
        datasets = (
            await session.scalars(
                select(models.Dataset)
                .join(
                    models.DatasetExample, models.Dataset.id == models.DatasetExample.dataset_id
                )
                .where(models.DatasetExample.id.in_(example_db_ids))
                .distinct()
                .limit(2)  # limit to 2 to check if there are more than 1 dataset
            )
        ).all()
        if len(datasets) > 1:
            raise ValueError("Examples must be from the same dataset")
        elif not datasets:
            raise ValueError("Examples not found")

        dataset = datasets[0]

        # The dataset lookup only proves that *some* of the examples exist. Check
        # all of them before writing revisions that reference them.
        found_example_ids = set(
            await session.scalars(
                select(models.DatasetExample.id).where(
                    models.DatasetExample.id.in_(example_db_ids)
                )
            )
        )
        if (num_missing := len(example_db_ids) - len(found_example_ids)) > 0:
            raise ValueError(f"{num_missing} example(s) could not be found")

        # If any example already has a delete revision, abort before writing anything.
        already_deleted_revision_id = await session.scalar(
            select(models.DatasetExampleRevision.id)
            .where(
                models.DatasetExampleRevision.dataset_example_id.in_(example_db_ids),
                models.DatasetExampleRevision.revision_kind == "DELETE",
            )
            .limit(1)
        )
        if already_deleted_revision_id is not None:
            raise ValueError(
                "Provided examples contain already deleted examples. Delete aborted."
            )

        dataset_version_rowid = await session.scalar(
            insert(models.DatasetVersion)
            .values(
                dataset_id=dataset.id,
                description=dataset_version_description,
                metadata_=dataset_version_metadata,
                created_at=timestamp,
            )
            .returning(models.DatasetVersion.id)
        )

        DatasetExampleRevision = models.DatasetExampleRevision
        await session.execute(
            insert(DatasetExampleRevision),
            [
                {
                    DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
                    DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
                    DatasetExampleRevision.input.key: {},
                    DatasetExampleRevision.output.key: {},
                    DatasetExampleRevision.metadata_.key: {},
                    DatasetExampleRevision.revision_kind.key: "DELETE",
                    DatasetExampleRevision.created_at.key: timestamp,
                }
                for dataset_example_rowid in example_db_ids
            ],
        )
    info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
    return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
539
+
540
+
541
def _span_attribute(semconv: str) -> Any:
    """
    Build a labeled SQL expression that reads a nested key of the span
    ``attributes`` column.

    The dotted ``semconv`` path is walked one segment at a time. The result is
    labeled with the same segments joined by underscores, e.g.
    ``"input.value"`` -> ``Span.attributes["input"]["value"].label("input_value")``.
    """
    path_segments = semconv.split(".")
    expression: Any = models.Span.attributes
    for segment in path_segments:
        expression = expression[segment]
    # Joining the split segments with "_" is the same as replacing "." with "_".
    return expression.label("_".join(path_segments))
552
+
553
+
554
def _to_orm_revision(
    *,
    existing_revision: models.DatasetExampleRevision,
    patch: DatasetExamplePatch,
    example_id: int,
    version_id: int,
) -> dict[str, Any]:
    """
    Build the row for a PATCH revision from an existing revision and a patch.

    Each of ``input``, ``output`` and ``metadata`` comes from ``patch`` when the
    patch supplies a dict for it. Otherwise it is carried over from
    ``existing_revision``. The returned mapping of column key to value is the
    row shape used with SQLAlchemy's bulk ``session.execute(insert(...), rows)``.
    """
    revision_table = models.DatasetExampleRevision

    def prefer_patch(patched_value: Any, current_value: Any) -> Any:
        # Only a dict counts as a replacement; any other value keeps the current one.
        return patched_value if isinstance(patched_value, dict) else current_value

    row: dict[str, Any] = {}
    row[str(revision_table.dataset_example_id.key)] = example_id
    row[str(revision_table.dataset_version_id.key)] = version_id
    row[str(revision_table.input.key)] = prefer_patch(patch.input, existing_revision.input)
    row[str(revision_table.output.key)] = prefer_patch(patch.output, existing_revision.output)
    row[str(revision_table.metadata_.key)] = prefer_patch(
        patch.metadata, existing_revision.metadata_
    )
    row[str(revision_table.revision_kind.key)] = "PATCH"
    return row
582
+
583
+
584
# Short module-level aliases for attribute keys defined on ``SpanAttributes``.

# Span input/output payloads and their MIME types.
INPUT_VALUE = SpanAttributes.INPUT_VALUE
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE

# LLM-related keys.
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES

# Retrieval-related keys.
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
@@ -0,0 +1,75 @@
1
+ import asyncio
2
+
3
+ import strawberry
4
+ from sqlalchemy import delete
5
+ from strawberry.relay import GlobalID
6
+ from strawberry.types import Info
7
+
8
+ from phoenix.db import models
9
+ from phoenix.db.helpers import get_eval_trace_ids_for_experiments, get_project_names_for_experiments
10
+ from phoenix.server.api.auth import IsNotReadOnly
11
+ from phoenix.server.api.context import Context
12
+ from phoenix.server.api.exceptions import CustomGraphQLError
13
+ from phoenix.server.api.input_types.DeleteExperimentsInput import DeleteExperimentsInput
14
+ from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
15
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
16
+ from phoenix.server.api.utils import delete_projects, delete_traces
17
+ from phoenix.server.dml_event import ExperimentDeleteEvent
18
+
19
+
20
@strawberry.type
class ExperimentMutationPayload:
    """GraphQL payload returned by experiment mutations."""

    # The experiments the mutation acted on.
    experiments: list[Experiment]
23
+
24
+
25
@strawberry.type
class ExperimentMutationMixin:
    """GraphQL mutations that operate on experiments."""

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def delete_experiments(
        self,
        info: Info[Context, None],
        input: DeleteExperimentsInput,
    ) -> ExperimentMutationPayload:
        """
        Delete the given experiments, then clean up their projects and eval traces.

        Project names and evaluation trace IDs are read before the experiments
        are deleted. The delete runs inside a nested transaction (savepoint). If
        any requested experiment was not deleted, the savepoint is rolled back
        and an error is raised. After the session block exits, the collected
        projects and traces are deleted concurrently. Failures there are
        collected by ``asyncio.gather(..., return_exceptions=True)`` and are not
        re-raised.

        Returns:
            The deleted experiments, in the order of ``input.experiment_ids``.

        Raises:
            CustomGraphQLError: if one or more requested experiment IDs were not
                deleted (e.g. the IDs are unknown).
        """
        requested_rowids = [
            from_global_id_with_expected_type(global_id, Experiment.__name__)
            for global_id in input.experiment_ids
        ]
        lookup_project_names = get_project_names_for_experiments(*requested_rowids)
        lookup_eval_trace_ids = get_eval_trace_ids_for_experiments(*requested_rowids)

        async with info.context.db() as session:
            project_names = await session.scalars(lookup_project_names)
            eval_trace_ids = await session.scalars(lookup_eval_trace_ids)
            savepoint = await session.begin_nested()
            deleted_stream = await session.stream_scalars(
                delete(models.Experiment)
                .where(models.Experiment.id.in_(requested_rowids))
                .returning(models.Experiment)
            )
            deleted_by_id = {}
            async for deleted_experiment in deleted_stream:
                deleted_by_id[deleted_experiment.id] = deleted_experiment
            unknown_experiment_ids = set(requested_rowids) - set(deleted_by_id.keys())
            if unknown_experiment_ids:
                # Undo the partial delete before reporting the bad IDs.
                await savepoint.rollback()
                unknown_global_ids = [
                    str(GlobalID(Experiment.__name__, str(unknown_rowid)))
                    for unknown_rowid in unknown_experiment_ids
                ]
                raise CustomGraphQLError(
                    "Failed to delete experiment(s), "
                    "probably due to invalid input experiment ID(s): " + str(unknown_global_ids)
                )

        # Best-effort cleanup: exceptions are returned by gather rather than raised.
        await asyncio.gather(
            delete_projects(info.context.db, *project_names),
            delete_traces(info.context.db, *eval_trace_ids),
            return_exceptions=True,
        )
        info.context.event_queue.put(ExperimentDeleteEvent(tuple(deleted_by_id)))
        deleted_in_request_order = [
            to_gql_experiment(deleted_by_id[rowid]) for rowid in requested_rowids
        ]
        return ExperimentMutationPayload(experiments=deleted_in_request_order)