arize-phoenix 3.16.0__py3-none-any.whl → 7.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (338) hide show
  1. arize_phoenix-7.7.0.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.0.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -247
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +13 -107
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.0.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.0.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -617
  295. phoenix/core/traces.py +0 -100
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
@@ -0,0 +1,1017 @@
1
+ import csv
2
+ import gzip
3
+ import io
4
+ import json
5
+ import logging
6
+ import zlib
7
+ from asyncio import QueueFull
8
+ from collections import Counter
9
+ from collections.abc import Awaitable, Callable, Coroutine, Iterator, Mapping, Sequence
10
+ from datetime import datetime
11
+ from enum import Enum
12
+ from functools import partial
13
+ from typing import Any, Optional, Union, cast
14
+
15
+ import pandas as pd
16
+ import pyarrow as pa
17
+ from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
18
+ from fastapi.responses import PlainTextResponse, StreamingResponse
19
+ from sqlalchemy import and_, delete, func, select
20
+ from sqlalchemy.ext.asyncio import AsyncSession
21
+ from starlette.concurrency import run_in_threadpool
22
+ from starlette.datastructures import FormData, UploadFile
23
+ from starlette.requests import Request
24
+ from starlette.responses import Response
25
+ from starlette.status import (
26
+ HTTP_200_OK,
27
+ HTTP_204_NO_CONTENT,
28
+ HTTP_404_NOT_FOUND,
29
+ HTTP_409_CONFLICT,
30
+ HTTP_422_UNPROCESSABLE_ENTITY,
31
+ HTTP_429_TOO_MANY_REQUESTS,
32
+ )
33
+ from strawberry.relay import GlobalID
34
+ from typing_extensions import TypeAlias, assert_never
35
+
36
+ from phoenix.db import models
37
+ from phoenix.db.helpers import get_eval_trace_ids_for_datasets, get_project_names_for_datasets
38
+ from phoenix.db.insertion.dataset import (
39
+ DatasetAction,
40
+ DatasetExampleAdditionEvent,
41
+ ExampleContent,
42
+ add_dataset_examples,
43
+ )
44
+ from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
45
+ from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
46
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
47
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
48
+ from phoenix.server.api.utils import delete_projects, delete_traces
49
+ from phoenix.server.dml_event import DatasetInsertEvent
50
+
51
+ from .pydantic_compat import V1RoutesBaseModel
52
+ from .utils import (
53
+ PaginatedResponseBody,
54
+ ResponseBody,
55
+ add_errors_to_responses,
56
+ add_text_csv_content_to_responses,
57
+ )
58
+
59
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Relay node type names, taken from the GraphQL node classes, used below to
# build and validate GlobalIDs for datasets and dataset versions.
DATASET_NODE_NAME = DatasetNodeType.__name__
DATASET_VERSION_NODE_NAME = DatasetVersionNodeType.__name__


# Router on which the dataset routes in this module are registered; every
# route is tagged "datasets".
router = APIRouter(tags=["datasets"])
66
+
67
+
68
class Dataset(V1RoutesBaseModel):
    # Serialized dataset record returned by the dataset routes.
    # (Documented with comments rather than a docstring so the generated
    # schema description is left unchanged.)
    id: str  # string form of the dataset's GlobalID
    name: str
    description: Optional[str]
    metadata: dict[str, Any]
    created_at: datetime
    updated_at: datetime
75
+
76
+
77
class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
    # Paginated response body whose items are ``Dataset`` models. It adds no
    # fields of its own; the subclass only gives the specialization a name.
    pass
79
+
80
+
81
@router.get(
    "/datasets",
    operation_id="listDatasets",
    summary="List datasets",
    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
)
async def list_datasets(
    request: Request,
    cursor: Optional[str] = Query(
        default=None,
        description="Cursor for pagination",
    ),
    name: Optional[str] = Query(default=None, description="An optional dataset name to filter by"),
    limit: int = Query(
        default=10, description="The max number of datasets to return at a time.", gt=0
    ),
) -> ListDatasetsResponseBody:
    # Lists datasets in descending ID order, one page at a time.
    #
    # ``cursor`` is the GlobalID string of the first dataset of the requested
    # page (inclusive bound), ``name`` is an optional exact-match filter, and
    # ``limit`` caps the page size. A cursor that cannot be parsed is rejected
    # with a 422.
    async with request.app.state.db() as session:
        stmt = select(models.Dataset).order_by(models.Dataset.id.desc())

        if cursor:
            try:
                cursor_row_id = int(GlobalID.from_id(cursor).node_id)
            except ValueError:
                raise HTTPException(
                    detail=f"Invalid cursor format: {cursor}",
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                )
            stmt = stmt.filter(models.Dataset.id <= cursor_row_id)
        if name:
            stmt = stmt.filter(models.Dataset.name == name)

        # Fetch one row beyond the page size: its presence signals a further
        # page, and its ID becomes the cursor for that page.
        rows = (await session.execute(stmt.limit(limit + 1))).scalars().all()

        has_more = len(rows) == limit + 1
        next_cursor = str(GlobalID(DATASET_NODE_NAME, str(rows[-1].id))) if has_more else None
        page = rows[:-1] if has_more else rows

        return ListDatasetsResponseBody(
            next_cursor=next_cursor,
            data=[
                Dataset(
                    id=str(GlobalID(DATASET_NODE_NAME, str(row.id))),
                    name=row.name,
                    description=row.description,
                    metadata=row.metadata_,
                    created_at=row.created_at,
                    updated_at=row.updated_at,
                )
                for row in page
            ],
        )
139
+
140
+
141
@router.delete(
    "/datasets/{id}",
    operation_id="deleteDatasetById",
    summary="Delete dataset by ID",
    status_code=HTTP_204_NO_CONTENT,
    responses=add_errors_to_responses(
        [
            {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
        ]
    ),
)
async def delete_dataset(
    request: Request, id: str = Path(description="The ID of the dataset to delete.")
) -> Response:
    # Deletes the dataset identified by the GlobalID string ``id`` and
    # schedules clean-up of the projects and evaluation traces associated
    # with it.
    #
    # Returns an empty 204 response. Raises a 422 HTTPException when the ID is
    # missing or is not a valid Dataset GlobalID, and a 404 HTTPException when
    # no dataset row was deleted.
    if id:
        try:
            dataset_id = from_global_id_with_expected_type(
                GlobalID.from_id(id),
                DATASET_NODE_NAME,
            )
        except ValueError:
            raise HTTPException(
                detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
            )
    else:
        raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    project_names_stmt = get_project_names_for_datasets(dataset_id)
    eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
    stmt = (
        delete(models.Dataset).where(models.Dataset.id == dataset_id).returning(models.Dataset.id)
    )
    async with request.app.state.db() as session:
        # Collect the clean-up targets before the dataset row is deleted, and
        # materialize them while the session is still open.
        project_names = (await session.scalars(project_names_stmt)).all()
        eval_trace_ids = (await session.scalars(eval_trace_ids_stmt)).all()
        if (await session.scalar(stmt)) is None:
            raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
    tasks = BackgroundTasks()
    tasks.add_task(delete_projects, request.app.state.db, *project_names)
    tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
    # A BackgroundTasks instance created inside the handler only runs if it is
    # attached to the response that is sent back; otherwise the queued
    # clean-up is silently dropped. Returning the 204 response explicitly
    # keeps the HTTP contract unchanged and lets the tasks run after it is
    # sent.
    return Response(status_code=HTTP_204_NO_CONTENT, background=tasks)
181
+
182
+
183
class DatasetWithExampleCount(Dataset):
    # ``Dataset`` payload extended with the dataset's example count.
    example_count: int
185
+
186
+
187
class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
    # Response body wrapping a single ``DatasetWithExampleCount``. It adds no
    # fields of its own; the subclass only gives the specialization a name.
    pass
189
+
190
+
191
@router.get(
    "/datasets/{id}",
    operation_id="getDataset",
    summary="Get dataset by ID",
    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
)
async def get_dataset(
    request: Request, id: str = Path(description="The ID of the dataset")
) -> GetDatasetResponseBody:
    # Returns the dataset identified by the GlobalID string ``id`` together
    # with its example count.
    #
    # Raises a 422 HTTPException when ``id`` cannot be parsed, and a 404
    # HTTPException when it names a different node type or no such dataset
    # exists.
    try:
        dataset_id = GlobalID.from_id(id)
    except ValueError:
        # A malformed ID is a client error, not an internal server error.
        raise HTTPException(
            detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
        )

    if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
        raise HTTPException(
            detail=f"ID {dataset_id} refers to a {type_name}", status_code=HTTP_404_NOT_FOUND
        )
    try:
        dataset_rowid = int(dataset_id.node_id)
    except ValueError:
        # The node ID part of a well-formed GlobalID may still be non-numeric.
        raise HTTPException(
            detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
        )

    async with request.app.state.db() as session:
        result = await session.execute(
            select(models.Dataset, models.Dataset.example_count).filter(
                models.Dataset.id == dataset_rowid
            )
        )
        row = result.first()
        if row is None:
            raise HTTPException(
                detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
            )
        dataset, example_count = row[0], row[1]

        return GetDatasetResponseBody(
            data=DatasetWithExampleCount(
                id=str(dataset_id),
                name=dataset.name,
                description=dataset.description,
                metadata=dataset.metadata_,
                created_at=dataset.created_at,
                updated_at=dataset.updated_at,
                example_count=example_count,
            )
        )
230
+
231
+
232
class DatasetVersion(V1RoutesBaseModel):
    # Serialized dataset version record returned by the versions route.
    # (Documented with comments rather than a docstring so the generated
    # schema description is left unchanged.)
    version_id: str  # string form of the dataset version's GlobalID
    description: Optional[str]
    metadata: dict[str, Any]
    created_at: datetime
237
+
238
+
239
class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
    # Paginated response body whose items are ``DatasetVersion`` models. It
    # adds no fields of its own; the subclass only names the specialization.
    pass
241
+
242
+
243
@router.get(
    "/datasets/{id}/versions",
    operation_id="listDatasetVersionsByDatasetId",
    summary="List dataset versions",
    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
)
async def list_dataset_versions(
    request: Request,
    id: str = Path(description="The ID of the dataset"),
    cursor: Optional[str] = Query(
        default=None,
        description="Cursor for pagination",
    ),
    limit: int = Query(
        default=10, description="The max number of dataset versions to return at a time", gt=0
    ),
) -> ListDatasetVersionsResponseBody:
    # Lists the versions of one dataset in descending version-ID order, one
    # page at a time.
    #
    # ``id`` is the dataset's GlobalID string; ``cursor`` is the GlobalID
    # string of the first version of the requested page (inclusive bound).
    # A missing or invalid dataset ID, or an invalid cursor, yields a 422.
    if not id:
        raise HTTPException(
            detail="Missing Dataset ID",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )
    try:
        dataset_id = from_global_id_with_expected_type(
            GlobalID.from_id(id),
            DATASET_NODE_NAME,
        )
    except ValueError:
        raise HTTPException(
            detail=f"Invalid Dataset ID: {id}",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )

    # One row beyond the page size is requested so a further page can be
    # detected.
    versions_stmt = (
        select(models.DatasetVersion)
        .where(models.DatasetVersion.dataset_id == dataset_id)
        .order_by(models.DatasetVersion.id.desc())
        .limit(limit + 1)
    )
    if cursor:
        try:
            cursor_version_id = from_global_id_with_expected_type(
                GlobalID.from_id(cursor), DATASET_VERSION_NODE_NAME
            )
        except ValueError:
            raise HTTPException(
                detail=f"Invalid cursor: {cursor}",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        # The bound is looked up through a subquery scoped to this dataset,
        # so the cursor only matches a version belonging to this dataset.
        upper_bound = (
            select(models.DatasetVersion.id)
            .where(models.DatasetVersion.id == cursor_version_id)
            .where(models.DatasetVersion.dataset_id == dataset_id)
        ).scalar_subquery()
        versions_stmt = versions_stmt.filter(models.DatasetVersion.id <= upper_bound)

    data = []
    async with request.app.state.db() as session:
        async for version in await session.stream_scalars(versions_stmt):
            data.append(
                DatasetVersion(
                    version_id=str(GlobalID(DATASET_VERSION_NODE_NAME, str(version.id))),
                    description=version.description,
                    metadata=version.metadata_,
                    created_at=version.created_at,
                )
            )

    # When the extra row came back, drop it from the page and use its ID as
    # the cursor for the next page.
    if len(data) == limit + 1:
        next_cursor = data.pop().version_id
    else:
        next_cursor = None
    return ListDatasetVersionsResponseBody(data=data, next_cursor=next_cursor)
310
+
311
+
312
class UploadDatasetData(V1RoutesBaseModel):
    # Payload carrying the ID of an uploaded dataset.
    # NOTE(review): presumably the string form of the dataset's GlobalID, as
    # with the other dataset routes; confirm against the upload handler.
    dataset_id: str
314
+
315
+
316
class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
    # Response body wrapping a single ``UploadDatasetData``. It adds no fields
    # of its own; the subclass only gives the specialization a name.
    pass
318
+
319
+
320
@router.post(
    "/datasets/upload",
    operation_id="uploadDataset",
    summary="Upload dataset from JSON, CSV, or PyArrow",
    responses=add_errors_to_responses(
        [
            {
                "status_code": HTTP_409_CONFLICT,
                "description": "Dataset of the same name already exists",
            },
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
        ]
    ),
    # FastAPI cannot generate the request body portion of the OpenAPI schema for
    # routes that accept multiple request content types, so we have to provide
    # this part of the schema manually. For context, see
    # https://github.com/tiangolo/fastapi/discussions/7786 and
    # https://github.com/tiangolo/fastapi/issues/990
    openapi_extra={
        "requestBody": {
            "content": {
                "application/json": {
                    "schema": {
                        "type": "object",
                        "required": ["name", "inputs"],
                        "properties": {
                            "action": {"type": "string", "enum": ["create", "append"]},
                            "name": {"type": "string"},
                            "description": {"type": "string"},
                            "inputs": {"type": "array", "items": {"type": "object"}},
                            "outputs": {"type": "array", "items": {"type": "object"}},
                            "metadata": {"type": "array", "items": {"type": "object"}},
                        },
                    }
                },
                "multipart/form-data": {
                    "schema": {
                        "type": "object",
                        "required": ["name", "input_keys[]", "output_keys[]", "file"],
                        "properties": {
                            "action": {"type": "string", "enum": ["create", "append"]},
                            "name": {"type": "string"},
                            "description": {"type": "string"},
                            "input_keys[]": {
                                "type": "array",
                                "items": {"type": "string"},
                                "uniqueItems": True,
                            },
                            "output_keys[]": {
                                "type": "array",
                                "items": {"type": "string"},
                                "uniqueItems": True,
                            },
                            "metadata_keys[]": {
                                "type": "array",
                                "items": {"type": "string"},
                                "uniqueItems": True,
                            },
                            "file": {"type": "string", "format": "binary"},
                        },
                    }
                },
            }
        },
    },
)
async def upload_dataset(
    request: Request,
    sync: bool = Query(
        default=False,
        description="If true, fulfill request synchronously and return JSON containing dataset_id.",
    ),
) -> Optional[UploadDatasetResponseBody]:
    # NOTE: documented with comments on purpose; a docstring on a route handler
    # would be published as the operation description in the OpenAPI schema.
    #
    # Accepts either a JSON body or a multipart form carrying a CSV / PyArrow file,
    # converts it into examples, and then either inserts them immediately
    # (sync=True, returns the dataset id) or enqueues the insertion (returns None).
    #
    # Raises HTTPException:
    #   409 - action is "create" and a dataset with that name already exists
    #   422 - missing/unsupported Content-Type or an invalid body
    #   429 - the background operation queue is full

    # A request without a Content-Type header must be answered with 422 like any
    # other unsupported content type, not crash with a KeyError (HTTP 500).
    request_content_type = request.headers.get("content-type", "")
    examples: Union[Examples, Awaitable[Examples]]
    if request_content_type.startswith("application/json"):
        try:
            examples, action, name, description = await run_in_threadpool(
                _process_json, await request.json()
            )
        except ValueError as e:
            raise HTTPException(
                detail=str(e),
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        if action is DatasetAction.CREATE:
            async with request.app.state.db() as session:
                if await _check_table_exists(session, name):
                    raise HTTPException(
                        detail=f"Dataset with the same name already exists: {name=}",
                        status_code=HTTP_409_CONFLICT,
                    )
    elif request_content_type.startswith("multipart/form-data"):
        async with request.form() as form:
            try:
                (
                    action,
                    name,
                    description,
                    input_keys,
                    output_keys,
                    metadata_keys,
                    file,
                ) = await _parse_form_data(form)
            except ValueError as e:
                raise HTTPException(
                    detail=str(e),
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                )
            if action is DatasetAction.CREATE:
                async with request.app.state.db() as session:
                    if await _check_table_exists(session, name):
                        raise HTTPException(
                            detail=f"Dataset with the same name already exists: {name=}",
                            status_code=HTTP_409_CONFLICT,
                        )
            # Read the upload while the form (and its underlying file) is still open.
            content = await file.read()
        try:
            file_content_type = FileContentType(file.content_type)
            if file_content_type is FileContentType.CSV:
                encoding = FileContentEncoding(file.headers.get("content-encoding"))
                examples = await _process_csv(
                    content, encoding, input_keys, output_keys, metadata_keys
                )
            elif file_content_type is FileContentType.PYARROW:
                examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
            else:
                assert_never(file_content_type)
        except ValueError as e:
            raise HTTPException(
                detail=str(e),
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
    else:
        raise HTTPException(
            detail="Invalid request Content-Type",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )
    # Bind everything except the session so the insertion can run either right
    # here or later from the operation queue.
    operation = cast(
        Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
        partial(
            add_dataset_examples,
            examples=examples,
            action=action,
            name=name,
            description=description,
        ),
    )
    if sync:
        async with request.app.state.db() as session:
            dataset_id = (await operation(session)).dataset_id
        request.state.event_queue.put(DatasetInsertEvent((dataset_id,)))
        return UploadDatasetResponseBody(
            data=UploadDatasetData(dataset_id=str(GlobalID(Dataset.__name__, str(dataset_id))))
        )
    try:
        request.state.enqueue_operation(operation)
    except QueueFull:
        # The PyArrow path hands over an un-awaited coroutine; close it so it is
        # not left dangling when the operation is rejected.
        if isinstance(examples, Coroutine):
            examples.close()
        raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
    return None
482
+
483
+
484
class FileContentType(Enum):
    """Media types accepted for the uploaded dataset file.

    Lookup by value is case-insensitive for ASCII strings; anything else
    raises ``ValueError``.
    """

    CSV = "text/csv"
    PYARROW = "application/x-pandas-pyarrow"

    @classmethod
    def _missing_(cls, v: Any) -> "FileContentType":
        # Retry with the lower-cased value, but only when lower-casing actually
        # changes the string. Testing ``not v.islower()`` is not enough: strings
        # without cased characters (e.g. "*/*") are "not lower" yet unchanged by
        # ``lower()``, which would make this hook call itself forever.
        if isinstance(v, str) and v.isascii() and (lowered := v.lower()) != v:
            return cls(lowered)
        raise ValueError(f"Invalid file content type: {v}")
493
+
494
+
495
class FileContentEncoding(Enum):
    """Content encodings accepted for an uploaded CSV file.

    ``None`` (no encoding given) maps to ``NONE``. Lookup by value is
    case-insensitive for ASCII strings; anything else raises ``ValueError``.
    """

    NONE = "none"
    GZIP = "gzip"
    DEFLATE = "deflate"

    @classmethod
    def _missing_(cls, v: Any) -> "FileContentEncoding":
        if v is None:
            return cls("none")
        # Retry with the lower-cased value only when it differs from the input;
        # otherwise a string with no cased characters (e.g. "*") would make this
        # hook call itself forever instead of raising ValueError.
        if isinstance(v, str) and v.isascii() and (lowered := v.lower()) != v:
            return cls(lowered)
        raise ValueError(f"Invalid file content encoding: {v}")
507
+
508
+
509
# Aliases used in the signatures of the upload helpers in this module.

# Dataset identity and free-text description.
Name: TypeAlias = str
Description: TypeAlias = Optional[str]

# Column names selecting which fields of a tabular upload go where.
InputKeys: TypeAlias = frozenset[str]
OutputKeys: TypeAlias = frozenset[str]
MetadataKeys: TypeAlias = frozenset[str]

# Integer dataset id (as opposed to its stringified GlobalID).
DatasetId: TypeAlias = int

# Stream of parsed examples handed to the insertion routine.
Examples: TypeAlias = Iterator[ExampleContent]
516
+
517
+
518
def _process_json(
    data: Mapping[str, Any],
) -> tuple[Examples, DatasetAction, Name, Description]:
    """Validate a JSON upload body and turn it into dataset examples.

    Args:
        data: The decoded JSON request body. Expected keys: ``name`` and
            ``inputs`` (required), ``outputs``, ``metadata``, ``description``
            and ``action`` (optional, defaults to ``"create"``).

    Returns:
        A tuple of (iterator over examples, action, dataset name, description).
        The description is an empty string when none was supplied.

    Raises:
        ValueError: If the body is not a JSON object, a required field is
            missing, a list field is malformed, or the action is unknown.
    """
    # The body arrives straight from ``request.json()`` and may be any JSON
    # value. Reject non-objects here so the caller's ``except ValueError`` can
    # turn them into a 422 instead of an AttributeError on ``.get``.
    if not isinstance(data, Mapping):
        raise ValueError("Request body should be a JSON object")
    name = data.get("name")
    if not name:
        raise ValueError("Dataset name is required")
    description = data.get("description") or ""
    inputs = data.get("inputs")
    if not inputs:
        raise ValueError("input is required")
    if not isinstance(inputs, list) or not _is_all_dict(inputs):
        raise ValueError("Input should be a list containing only dictionary objects")
    outputs, metadata = data.get("outputs"), data.get("metadata")
    # Optional columns: when present they must line up one-to-one with inputs.
    for k, v in {"outputs": outputs, "metadata": metadata}.items():
        if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
            raise ValueError(
                f"{k} should be a list of same length as input containing only dictionary objects"
            )
    examples = [
        ExampleContent(
            input=obj,
            output=outputs[i] if outputs else {},
            metadata=metadata[i] if metadata else {},
        )
        for i, obj in enumerate(inputs)
    ]
    action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
    return iter(examples), action, name, description
546
+
547
+
548
async def _process_csv(
    content: bytes,
    content_encoding: FileContentEncoding,
    input_keys: InputKeys,
    output_keys: OutputKeys,
    metadata_keys: MetadataKeys,
) -> Examples:
    """Parse an uploaded CSV file into a lazy stream of examples.

    Args:
        content: Raw bytes of the uploaded file.
        content_encoding: Compression applied to ``content``.
        input_keys: Column headers that make up each example's input.
        output_keys: Column headers that make up each example's output.
        metadata_keys: Column headers that make up each example's metadata.

    Returns:
        A generator yielding one ``ExampleContent`` per CSV row.

    Raises:
        ValueError: If the payload cannot be decompressed or decoded, the
            header row is missing or has duplicates, or a requested key is
            not among the column headers.
    """
    try:
        if content_encoding is FileContentEncoding.GZIP:
            content = await run_in_threadpool(gzip.decompress, content)
        elif content_encoding is FileContentEncoding.DEFLATE:
            content = await run_in_threadpool(zlib.decompress, content)
        elif content_encoding is not FileContentEncoding.NONE:
            assert_never(content_encoding)
    except (OSError, EOFError, zlib.error) as e:
        # Corrupt or truncated compressed data is a client error; re-raise it as
        # ValueError, which is the error type the upload route reports as 422.
        raise ValueError(f"File is not valid {content_encoding.value} content") from e
    reader = await run_in_threadpool(lambda c: csv.DictReader(io.StringIO(c.decode())), content)
    # ``fieldnames`` is None for an empty file and [] when the first line is blank.
    if not reader.fieldnames:
        raise ValueError("Missing CSV column header")
    (header, freq), *_ = Counter(reader.fieldnames).most_common(1)
    if freq > 1:
        raise ValueError(f"Duplicated column header in CSV file: {header}")
    column_headers = frozenset(reader.fieldnames)
    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
    # Rows are pulled from the reader only as the generator is consumed.
    return (
        ExampleContent(
            input={k: row.get(k) for k in input_keys},
            output={k: row.get(k) for k in output_keys},
            metadata={k: row.get(k) for k in metadata_keys},
        )
        for row in iter(reader)
    )
577
+
578
+
579
async def _process_pyarrow(
    content: bytes,
    input_keys: InputKeys,
    output_keys: OutputKeys,
    metadata_keys: MetadataKeys,
) -> Awaitable[Examples]:
    """Open an uploaded PyArrow IPC stream and prepare its examples.

    The schema is validated eagerly; row conversion is deferred. The value
    returned is the (not yet awaited) ``run_in_threadpool`` call that produces
    the example iterator.

    Raises:
        ValueError: If the bytes are not a valid Arrow stream, or a requested
            key is not among the schema's column names.
    """
    try:
        stream = pa.ipc.open_stream(content)
    except pa.ArrowInvalid as err:
        raise ValueError("File is not valid pyarrow") from err
    _check_keys_exist(frozenset(stream.schema.names), input_keys, output_keys, metadata_keys)

    def get_examples() -> Iterator[ExampleContent]:
        records = stream.read_pandas().to_dict(orient="records")
        for record in records:
            yield ExampleContent(
                input={key: record.get(key) for key in input_keys},
                output={key: record.get(key) for key in output_keys},
                metadata={key: record.get(key) for key in metadata_keys},
            )

    return run_in_threadpool(get_examples)
601
+
602
+
603
async def _check_table_exists(session: AsyncSession, name: str) -> bool:
    """Return True when a ``Dataset`` row with the given name exists."""
    probe = select(1).select_from(models.Dataset).where(models.Dataset.name == name)
    hit = await session.scalar(probe)
    return bool(hit)
609
+
610
+
611
def _check_keys_exist(
    column_headers: frozenset[str],
    input_keys: InputKeys,
    output_keys: OutputKeys,
    metadata_keys: MetadataKeys,
) -> None:
    """Ensure every requested key is one of the available column headers.

    The groups are checked in the order input, output, metadata; the first
    group with unknown keys triggers the error.

    Raises:
        ValueError: Naming the group and the keys that were not found.
    """
    groups = {
        "input": input_keys,
        "output": output_keys,
        "metadata": metadata_keys,
    }
    for label, wanted in groups.items():
        if not wanted:
            continue
        missing = wanted.difference(column_headers)
        if missing:
            raise ValueError(f"{label} keys not found in column headers: {missing}")
624
+
625
+
626
async def _parse_form_data(
    form: FormData,
) -> tuple[
    DatasetAction,
    Name,
    Description,
    InputKeys,
    OutputKeys,
    MetadataKeys,
    UploadFile,
]:
    """Extract and validate the fields of a multipart dataset upload.

    Args:
        form: The parsed multipart form.

    Returns:
        A tuple of (action, name, description, input keys, output keys,
        metadata keys, uploaded file). The action defaults to ``"create"``,
        the description falls back to the uploaded file's name, and empty
        strings are dropped from the key lists.

    Raises:
        ValueError: If the name is empty, the action is unknown, or the
            ``file`` field is missing or is not an uploaded file.
    """
    name = cast(Optional[str], form.get("name"))
    if not name:
        raise ValueError("Dataset name must not be empty")
    action = DatasetAction(cast(Optional[str], form.get("action")) or "create")
    # Use ``get`` rather than indexing: a form without a "file" part must fail
    # with ValueError like every other validation error here, not KeyError.
    file = form.get("file")
    if not isinstance(file, UploadFile):
        raise ValueError("Malformed file in form data.")
    description = cast(Optional[str], form.get("description")) or file.filename
    input_keys = frozenset(filter(bool, cast(list[str], form.getlist("input_keys[]"))))
    output_keys = frozenset(filter(bool, cast(list[str], form.getlist("output_keys[]"))))
    metadata_keys = frozenset(filter(bool, cast(list[str], form.getlist("metadata_keys[]"))))
    return (
        action,
        name,
        description,
        input_keys,
        output_keys,
        metadata_keys,
        file,
    )
657
+
658
+
659
# One dataset example as returned by the examples listing route.
# (Comments instead of a docstring so the generated schema is unchanged.)
class DatasetExample(V1RoutesBaseModel):
    # Stringified GlobalID of the example.
    id: str
    # Content taken from the example revision that was selected.
    input: dict[str, Any]
    output: dict[str, Any]
    metadata: dict[str, Any]
    # Creation time of the revision the content was taken from.
    updated_at: datetime
665
+
666
+
667
# Payload of the examples listing route.
class ListDatasetExamplesData(V1RoutesBaseModel):
    # Stringified GlobalIDs of the dataset and of the version that was resolved.
    dataset_id: str
    version_id: str
    # Examples in ascending order of example id.
    examples: list[DatasetExample]
671
+
672
+
673
# Response envelope for the examples listing; its payload is a ListDatasetExamplesData.
class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
    ...
675
+
676
+
677
@router.get(
    "/datasets/{id}/examples",
    operation_id="getDatasetExamples",
    summary="Get examples from a dataset",
    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
)
async def get_dataset_examples(
    request: Request,
    id: str = Path(description="The ID of the dataset"),
    version_id: Optional[str] = Query(
        default=None,
        description=(
            "The ID of the dataset version (if omitted, returns data from the latest version)"
        ),
    ),
) -> ListDatasetExamplesResponseBody:
    # NOTE: documented with comments on purpose; a docstring on a route handler
    # would be published as the operation description in the OpenAPI schema.
    #
    # Returns the non-deleted examples of a dataset as of a given version
    # (or the latest version when none is given). Responds with 404 when an ID
    # has the wrong type, or the dataset / version cannot be found.
    dataset_gid = GlobalID.from_id(id)
    version_gid = GlobalID.from_id(version_id) if version_id else None

    dataset_type = dataset_gid.type_name
    if dataset_type != "Dataset":
        raise HTTPException(
            detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
        )

    if version_gid:
        version_type = version_gid.type_name
        if version_type != "DatasetVersion":
            raise HTTPException(
                detail=f"ID {version_gid} refers to a {version_type}",
                status_code=HTTP_404_NOT_FOUND,
            )

    async with request.app.state.db() as session:
        resolved_dataset_id = await session.scalar(
            select(models.Dataset.id).where(models.Dataset.id == int(dataset_gid.node_id))
        )
        if resolved_dataset_id is None:
            raise HTTPException(
                detail=f"No dataset with id {dataset_gid} can be found.",
                status_code=HTTP_404_NOT_FOUND,
            )

        # Per example, pick the revision with the largest id.
        newest_revision_ids = select(
            func.max(models.DatasetExampleRevision.id).label("max_id"),
        ).group_by(models.DatasetExampleRevision.dataset_example_id)

        if version_gid:
            version_lookup = select(models.DatasetVersion.id).where(
                and_(
                    models.DatasetVersion.dataset_id == resolved_dataset_id,
                    models.DatasetVersion.id == int(version_gid.node_id),
                )
            )
            resolved_version_id = await session.scalar(version_lookup)
            if resolved_version_id is None:
                raise HTTPException(
                    detail=f"No dataset version with id {version_id} can be found.",
                    status_code=HTTP_404_NOT_FOUND,
                )
            # With an explicit version, only revisions made at or before that
            # version take part in the "newest revision" selection.
            newest_revision_ids = newest_revision_ids.filter(
                models.DatasetExampleRevision.dataset_version_id <= resolved_version_id
            )
        else:
            latest_version_lookup = select(func.max(models.DatasetVersion.id)).where(
                models.DatasetVersion.dataset_id == resolved_dataset_id
            )
            resolved_version_id = await session.scalar(latest_version_lookup)
            if resolved_version_id is None:
                raise HTTPException(
                    detail="Dataset has no versions.",
                    status_code=HTTP_404_NOT_FOUND,
                )

        newest = newest_revision_ids.subquery()
        # Pair each example of this dataset with its selected revision and drop
        # examples whose selected revision is a deletion.
        query = (
            select(models.DatasetExample, models.DatasetExampleRevision)
            .join(
                models.DatasetExampleRevision,
                models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
            )
            .join(
                newest,
                (newest.c.max_id == models.DatasetExampleRevision.id),
            )
            .filter(models.DatasetExample.dataset_id == resolved_dataset_id)
            .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
            .order_by(models.DatasetExample.id.asc())
        )
        examples = [
            DatasetExample(
                id=str(GlobalID("DatasetExample", str(example.id))),
                input=revision.input,
                output=revision.output,
                metadata=revision.metadata_,
                updated_at=revision.created_at,
            )
            async for example, revision in await session.stream(query)
        ]
    payload = ListDatasetExamplesData(
        dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
        version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
        examples=examples,
    )
    return ListDatasetExamplesResponseBody(data=payload)
788
+
789
+
790
@router.get(
    "/datasets/{id}/csv",
    operation_id="getDatasetCsv",
    summary="Download dataset examples as CSV file",
    response_class=StreamingResponse,
    status_code=HTTP_200_OK,
    responses={
        **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
        **add_text_csv_content_to_responses(HTTP_200_OK),
    },
)
async def get_dataset_csv(
    request: Request,
    response: Response,
    id: str = Path(description="The ID of the dataset"),
    version_id: Optional[str] = Query(
        default=None,
        description=(
            "The ID of the dataset version (if omitted, returns data from the latest version)"
        ),
    ),
) -> Response:
    # NOTE: documented with comments on purpose; a docstring on a route handler
    # would be published as the operation description in the OpenAPI schema.
    #
    # Loads the dataset's example revisions, renders them as CSV in a worker
    # thread and returns them as an attachment named after the dataset.
    # A ValueError from the lookup is reported as 422.
    open_session = request.app.state.db
    try:
        async with open_session() as session:
            dataset_name, revisions = await _get_db_examples(
                session=session, id=id, version_id=version_id
            )
    except ValueError as err:
        raise HTTPException(detail=str(err), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    csv_bytes = await run_in_threadpool(_get_content_csv, revisions)
    download_headers = {
        "content-disposition": f'attachment; filename="{dataset_name}.csv"',
        "content-type": "text/csv",
    }
    return Response(content=csv_bytes, headers=download_headers)
827
+
828
+
829
@router.get(
    "/datasets/{id}/jsonl/openai_ft",
    operation_id="getDatasetJSONLOpenAIFineTuning",
    summary="Download dataset examples as OpenAI fine-tuning JSONL file",
    response_class=PlainTextResponse,
    responses=add_errors_to_responses(
        [
            {
                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
                "description": "Invalid dataset or version ID",
            }
        ]
    ),
)
async def get_dataset_jsonl_openai_ft(
    request: Request,
    response: Response,
    id: str = Path(description="The ID of the dataset"),
    version_id: Optional[str] = Query(
        default=None,
        description=(
            "The ID of the dataset version (if omitted, returns data from the latest version)"
        ),
    ),
) -> bytes:
    # NOTE: documented with comments on purpose; a docstring on a route handler
    # would be published as the operation description in the OpenAPI schema.
    #
    # Loads the dataset's example revisions, renders them as fine-tuning JSONL
    # in a worker thread and marks the response as a ".jsonl" attachment.
    # A ValueError from the lookup is reported as 422.
    open_session = request.app.state.db
    try:
        async with open_session() as session:
            dataset_name, revisions = await _get_db_examples(
                session=session, id=id, version_id=version_id
            )
    except ValueError as err:
        raise HTTPException(detail=str(err), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    jsonl_bytes = await run_in_threadpool(_get_content_jsonl_openai_ft, revisions)
    disposition = f'attachment; filename="{dataset_name}.jsonl"'
    response.headers["content-disposition"] = disposition
    return jsonl_bytes
864
+
865
+
866
@router.get(
    "/datasets/{id}/jsonl/openai_evals",
    operation_id="getDatasetJSONLOpenAIEvals",
    summary="Download dataset examples as OpenAI evals JSONL file",
    response_class=PlainTextResponse,
    responses=add_errors_to_responses(
        [
            {
                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
                "description": "Invalid dataset or version ID",
            }
        ]
    ),
)
async def get_dataset_jsonl_openai_evals(
    request: Request,
    response: Response,
    id: str = Path(description="The ID of the dataset"),
    version_id: Optional[str] = Query(
        default=None,
        description=(
            "The ID of the dataset version (if omitted, returns data from the latest version)"
        ),
    ),
) -> bytes:
    # NOTE: documented with comments on purpose; a docstring on a route handler
    # would be published as the operation description in the OpenAPI schema.
    #
    # Loads the dataset's example revisions, renders them as evals JSONL in a
    # worker thread and marks the response as a ".jsonl" attachment.
    # A ValueError from the lookup is reported as 422.
    open_session = request.app.state.db
    try:
        async with open_session() as session:
            dataset_name, revisions = await _get_db_examples(
                session=session, id=id, version_id=version_id
            )
    except ValueError as err:
        raise HTTPException(detail=str(err), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    jsonl_bytes = await run_in_threadpool(_get_content_jsonl_openai_evals, revisions)
    disposition = f'attachment; filename="{dataset_name}.jsonl"'
    response.headers["content-disposition"] = disposition
    return jsonl_bytes
901
+
902
+
903
def _get_content_csv(examples: list[models.DatasetExampleRevision]) -> bytes:
    """Render example revisions as encoded CSV text.

    Each revision becomes one row: an ``example_id`` column followed by the
    revision's input, output and metadata entries, whose column names are
    prefixed with ``input_``, ``output_`` and ``metadata_`` respectively.
    """
    rows = []
    for revision in examples:
        row = {
            "example_id": GlobalID(
                type_name=DatasetExampleNodeType.__name__,
                node_id=str(revision.dataset_example_id),
            ),
        }
        row.update((f"input_{key}", value) for key, value in revision.input.items())
        row.update((f"output_{key}", value) for key, value in revision.output.items())
        row.update((f"metadata_{key}", value) for key, value in revision.metadata_.items())
        rows.append(row)
    frame = pd.DataFrame.from_records(rows)
    return str(frame.to_csv(index=False)).encode()
917
+
918
+
919
def _get_content_jsonl_openai_ft(examples: list[models.DatasetExampleRevision]) -> bytes:
    """Render example revisions as JSONL, one ``{"messages": [...]}`` object per line.

    A line's messages are the input's ``messages`` list followed by the
    output's ``messages`` list; a side whose ``messages`` value is not a list
    contributes nothing.
    """
    lines = []
    for revision in examples:
        prompt = revision.input.get("messages")
        if not isinstance(prompt, list):
            prompt = []
        completion = revision.output.get("messages")
        if not isinstance(completion, list):
            completion = []
        record = {"messages": prompt + completion}
        lines.append(json.dumps(record, ensure_ascii=False) + "\n")
    return "".join(lines).encode()
938
+
939
+
940
def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision]) -> bytes:
    """Render example revisions as JSONL with ``messages`` and ``ideal`` per line.

    ``messages`` is the input's ``messages`` list (or ``[]`` when it is not a
    list). ``ideal`` is the string ``content`` of the last output message, or
    ``""`` when there is no such message or its content is not a string.
    """
    lines = []
    for revision in examples:
        prompt = revision.input.get("messages")
        if not isinstance(prompt, list):
            prompt = []
        ideal = ""
        completion = revision.output.get("messages")
        if isinstance(completion, list) and completion:
            final_message = completion[-1]
            # Only mapping-like messages can carry a "content" entry.
            if hasattr(final_message, "get"):
                answer = final_message.get("content")
                if isinstance(answer, str):
                    ideal = answer
        record = {"messages": prompt, "ideal": ideal}
        lines.append(json.dumps(record, ensure_ascii=False) + "\n")
    return "".join(lines).encode()
965
+
966
+
967
async def _get_db_examples(
    *, session: Any, id: str, version_id: Optional[str]
) -> tuple[str, list[models.DatasetExampleRevision]]:
    """Load a dataset's name and its non-deleted example revisions.

    For every example of the dataset the revision with the highest dataset
    version id is selected; when ``version_id`` is given, only revisions at or
    before that version of this dataset are considered.

    Args:
        session: Database session used to run the queries.
        id: Global ID string of the dataset.
        version_id: Optional global ID string of a dataset version.

    Returns:
        The dataset name and the selected revisions ordered by example id.

    Raises:
        ValueError: If no name is found for the dataset ("Dataset does not exist.").
    """
    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
    dataset_version_id: Optional[int] = None
    if version_id:
        version_gid = GlobalID.from_id(version_id)
        dataset_version_id = from_global_id_with_expected_type(
            version_gid, DATASET_VERSION_NODE_NAME
        )

    revision = models.DatasetExampleRevision
    # Highest version id per example, restricted to this dataset's examples.
    latest_version = (
        select(
            revision.dataset_example_id,
            func.max(revision.dataset_version_id).label("dataset_version_id"),
        )
        .group_by(revision.dataset_example_id)
        .join(models.DatasetExample)
        .where(models.DatasetExample.dataset_id == dataset_id)
    )
    if dataset_version_id is not None:
        # Resolves to NULL when the version does not belong to this dataset.
        version_cap = (
            select(models.DatasetVersion.id)
            .where(models.DatasetVersion.id == dataset_version_id)
            .where(models.DatasetVersion.dataset_id == dataset_id)
        ).scalar_subquery()
        latest_version = latest_version.where(revision.dataset_version_id <= version_cap)
    latest = latest_version.subquery("latest_version")

    same_example_and_version = and_(
        revision.dataset_example_id == latest.c.dataset_example_id,
        revision.dataset_version_id == latest.c.dataset_version_id,
    )
    stmt = (
        select(revision)
        .join(latest, onclause=same_example_and_version)
        .where(revision.revision_kind != "DELETE")
        .order_by(revision.dataset_example_id)
    )

    name_lookup = select(models.Dataset.name).where(models.Dataset.id == dataset_id)
    dataset_name: Optional[str] = await session.scalar(name_lookup)
    if not dataset_name:
        raise ValueError("Dataset does not exist.")
    rows = await session.stream_scalars(stmt)
    examples = [row async for row in rows]
    return dataset_name, examples
1014
+
1015
+
1016
def _is_all_dict(seq: Sequence[Any]) -> bool:
    """Return True when every element of ``seq`` is a ``dict`` (also for an empty ``seq``)."""
    return all(isinstance(item, dict) for item in seq)